Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2026-05-10 08:08:10 +08:00)
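Revert input scale upper bounds (restore `input_scale_ub` / `x_s_ub` plumbing through the FP8 scaled-mm linear kernels and `apply_weights_fp8`)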
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
This commit is contained in:
parent
7d361487f7
commit
1f65cd56e5
@ -48,6 +48,7 @@ FP8ParamsT = tuple[
|
|||||||
torch.Tensor, # weight
|
torch.Tensor, # weight
|
||||||
torch.Tensor, # weight_scale
|
torch.Tensor, # weight_scale
|
||||||
torch.Tensor | None, # input_scale,
|
torch.Tensor | None, # input_scale,
|
||||||
|
torch.Tensor | None, # input_scale_ub,
|
||||||
]
|
]
|
||||||
Int8ParamsT = tuple[
|
Int8ParamsT = tuple[
|
||||||
torch.Tensor, # weight
|
torch.Tensor, # weight
|
||||||
@ -122,11 +123,12 @@ class FP8ScaledMMLinearKernel(
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
def _get_layer_params(self, layer) -> FP8ParamsT:
|
def _get_layer_params(self, layer) -> FP8ParamsT:
|
||||||
w, w_s, x_s = self.layer_param_names
|
w, w_s, x_s, x_s_ub = self.layer_param_names
|
||||||
return (
|
return (
|
||||||
getattr(layer, w),
|
getattr(layer, w),
|
||||||
getattr(layer, w_s),
|
getattr(layer, w_s),
|
||||||
getattr(layer, x_s),
|
getattr(layer, x_s),
|
||||||
|
getattr(layer, x_s_ub),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -180,7 +180,7 @@ class CutlassFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
|
|||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
bias: torch.Tensor | None = None,
|
bias: torch.Tensor | None = None,
|
||||||
):
|
):
|
||||||
w, w_s, x_s = self._get_layer_params(layer)
|
w, w_s, x_s, x_s_ub = self._get_layer_params(layer)
|
||||||
return apply_weights_fp8(
|
return apply_weights_fp8(
|
||||||
cutlass_w8a8_scaled_mm_fp8,
|
cutlass_w8a8_scaled_mm_fp8,
|
||||||
self.quant_fp8,
|
self.quant_fp8,
|
||||||
@ -189,5 +189,6 @@ class CutlassFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
|
|||||||
w_s,
|
w_s,
|
||||||
x_s,
|
x_s,
|
||||||
bias,
|
bias,
|
||||||
|
x_s_ub,
|
||||||
self.config.out_dtype,
|
self.config.out_dtype,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -77,7 +77,7 @@ class FlashInferScaledMMLinearKernel(FP8ScaledMMLinearKernel):
|
|||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
bias: torch.Tensor | None = None,
|
bias: torch.Tensor | None = None,
|
||||||
):
|
):
|
||||||
w, w_s, x_s = self._get_layer_params(layer)
|
w, w_s, x_s, x_s_ub = self._get_layer_params(layer)
|
||||||
return apply_weights_fp8(
|
return apply_weights_fp8(
|
||||||
flashinfer_w8a8_scaled_mm,
|
flashinfer_w8a8_scaled_mm,
|
||||||
self.quant_fp8,
|
self.quant_fp8,
|
||||||
@ -86,5 +86,6 @@ class FlashInferScaledMMLinearKernel(FP8ScaledMMLinearKernel):
|
|||||||
w_s,
|
w_s,
|
||||||
x_s,
|
x_s,
|
||||||
bias,
|
bias,
|
||||||
|
x_s_ub,
|
||||||
self.config.out_dtype,
|
self.config.out_dtype,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -162,7 +162,7 @@ class PerTensorTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
|
|||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
bias: torch.Tensor | None = None,
|
bias: torch.Tensor | None = None,
|
||||||
):
|
):
|
||||||
w, w_s, x_s = self._get_layer_params(layer)
|
w, w_s, x_s, x_s_ub = self._get_layer_params(layer)
|
||||||
return apply_weights_fp8(
|
return apply_weights_fp8(
|
||||||
torch_per_tensor_w8a8_scaled_mm,
|
torch_per_tensor_w8a8_scaled_mm,
|
||||||
self.quant_fp8,
|
self.quant_fp8,
|
||||||
@ -171,6 +171,7 @@ class PerTensorTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
|
|||||||
w_s,
|
w_s,
|
||||||
x_s,
|
x_s,
|
||||||
bias,
|
bias,
|
||||||
|
x_s_ub,
|
||||||
self.config.out_dtype,
|
self.config.out_dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -215,7 +216,7 @@ class RowWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
|
|||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
bias: torch.Tensor | None = None,
|
bias: torch.Tensor | None = None,
|
||||||
):
|
):
|
||||||
w, w_s, x_s = self._get_layer_params(layer)
|
w, w_s, x_s, x_s_ub = self._get_layer_params(layer)
|
||||||
return apply_weights_fp8(
|
return apply_weights_fp8(
|
||||||
torch_row_wise_w8a8_scaled_mm,
|
torch_row_wise_w8a8_scaled_mm,
|
||||||
self.quant_fp8,
|
self.quant_fp8,
|
||||||
@ -224,6 +225,7 @@ class RowWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
|
|||||||
w_s,
|
w_s,
|
||||||
x_s,
|
x_s,
|
||||||
bias,
|
bias,
|
||||||
|
x_s_ub,
|
||||||
self.config.out_dtype,
|
self.config.out_dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -255,7 +257,7 @@ class ChannelWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
|
|||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
bias: torch.Tensor | None = None,
|
bias: torch.Tensor | None = None,
|
||||||
):
|
):
|
||||||
w, w_s, x_s = self._get_layer_params(layer)
|
w, w_s, x_s, x_s_ub = self._get_layer_params(layer)
|
||||||
return apply_weights_fp8(
|
return apply_weights_fp8(
|
||||||
torch_channelwise_w8a8_scaled_mm,
|
torch_channelwise_w8a8_scaled_mm,
|
||||||
self.quant_fp8,
|
self.quant_fp8,
|
||||||
@ -264,5 +266,6 @@ class ChannelWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
|
|||||||
w_s,
|
w_s,
|
||||||
x_s,
|
x_s,
|
||||||
bias,
|
bias,
|
||||||
|
x_s_ub,
|
||||||
self.config.out_dtype,
|
self.config.out_dtype,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -132,7 +132,7 @@ class ROCmScaledMMLinearKernel(FP8ScaledMMLinearKernel):
|
|||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
bias: torch.Tensor | None = None,
|
bias: torch.Tensor | None = None,
|
||||||
):
|
):
|
||||||
w, w_s, x_s = self._get_layer_params(layer)
|
w, w_s, x_s, x_s_ub = self._get_layer_params(layer)
|
||||||
return apply_weights_fp8(
|
return apply_weights_fp8(
|
||||||
rocm_per_tensor_float_w8a8_scaled_mm,
|
rocm_per_tensor_float_w8a8_scaled_mm,
|
||||||
self.quant_fp8,
|
self.quant_fp8,
|
||||||
@ -141,5 +141,6 @@ class ROCmScaledMMLinearKernel(FP8ScaledMMLinearKernel):
|
|||||||
w_s,
|
w_s,
|
||||||
x_s,
|
x_s,
|
||||||
bias,
|
bias,
|
||||||
|
x_s_ub,
|
||||||
self.config.out_dtype,
|
self.config.out_dtype,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -18,6 +18,7 @@ def apply_weights_fp8(
|
|||||||
w_s: torch.Tensor,
|
w_s: torch.Tensor,
|
||||||
x_s: torch.Tensor,
|
x_s: torch.Tensor,
|
||||||
bias: torch.Tensor,
|
bias: torch.Tensor,
|
||||||
|
x_s_ub: torch.Tensor | None,
|
||||||
maybe_out_dtype: torch.dtype | None,
|
maybe_out_dtype: torch.dtype | None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
# ops.scaled_fp8_quant supports both dynamic and static quant.
|
# ops.scaled_fp8_quant supports both dynamic and static quant.
|
||||||
@ -36,6 +37,7 @@ def apply_weights_fp8(
|
|||||||
x_2d_q, x_s = quant_fp8_func(
|
x_2d_q, x_s = quant_fp8_func(
|
||||||
x_2d,
|
x_2d,
|
||||||
x_s,
|
x_s,
|
||||||
|
x_s_ub,
|
||||||
)
|
)
|
||||||
|
|
||||||
return scaled_mm_func(
|
return scaled_mm_func(
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user