diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py
index b9acd89f69d82..9798f88b140a5 100644
--- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py
+++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py
@@ -48,6 +48,7 @@ FP8ParamsT = tuple[
     torch.Tensor,  # weight
     torch.Tensor,  # weight_scale
     torch.Tensor | None,  # input_scale,
+    torch.Tensor | None,  # input_scale_ub,
 ]
 Int8ParamsT = tuple[
     torch.Tensor,  # weight
@@ -122,11 +123,12 @@ class FP8ScaledMMLinearKernel(
         pass
 
     def _get_layer_params(self, layer) -> FP8ParamsT:
-        w, w_s, x_s = self.layer_param_names
+        w, w_s, x_s, x_s_ub = self.layer_param_names
         return (
             getattr(layer, w),
             getattr(layer, w_s),
             getattr(layer, x_s),
+            getattr(layer, x_s_ub),
         )
 
 
diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py
index 28348f50fc273..fc8893cb7e1b0 100644
--- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py
+++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py
@@ -180,7 +180,7 @@ class CutlassFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
         x: torch.Tensor,
         bias: torch.Tensor | None = None,
     ):
-        w, w_s, x_s = self._get_layer_params(layer)
+        w, w_s, x_s, x_s_ub = self._get_layer_params(layer)
         return apply_weights_fp8(
             cutlass_w8a8_scaled_mm_fp8,
             self.quant_fp8,
@@ -189,5 +189,6 @@ class CutlassFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
             w_s,
             x_s,
             bias,
+            x_s_ub,
             self.config.out_dtype,
         )
diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/flash_infer.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/flash_infer.py
index 8fd2c88857cab..e33b305322043 100644
--- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/flash_infer.py
+++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/flash_infer.py
@@ -77,7 +77,7 @@ class FlashInferScaledMMLinearKernel(FP8ScaledMMLinearKernel):
         x: torch.Tensor,
         bias: torch.Tensor | None = None,
     ):
-        w, w_s, x_s = self._get_layer_params(layer)
+        w, w_s, x_s, x_s_ub = self._get_layer_params(layer)
         return apply_weights_fp8(
             flashinfer_w8a8_scaled_mm,
             self.quant_fp8,
@@ -86,5 +86,6 @@ class FlashInferScaledMMLinearKernel(FP8ScaledMMLinearKernel):
             w_s,
             x_s,
             bias,
+            x_s_ub,
             self.config.out_dtype,
         )
diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/pytorch.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/pytorch.py
index c2a8474ac5b47..c0466e840fc08 100644
--- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/pytorch.py
+++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/pytorch.py
@@ -162,7 +162,7 @@ class PerTensorTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
         x: torch.Tensor,
         bias: torch.Tensor | None = None,
     ):
-        w, w_s, x_s = self._get_layer_params(layer)
+        w, w_s, x_s, x_s_ub = self._get_layer_params(layer)
         return apply_weights_fp8(
             torch_per_tensor_w8a8_scaled_mm,
             self.quant_fp8,
@@ -171,6 +171,7 @@ class PerTensorTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
             w_s,
             x_s,
             bias,
+            x_s_ub,
             self.config.out_dtype,
         )
 
@@ -215,7 +216,7 @@ class RowWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
         x: torch.Tensor,
         bias: torch.Tensor | None = None,
     ):
-        w, w_s, x_s = self._get_layer_params(layer)
+        w, w_s, x_s, x_s_ub = self._get_layer_params(layer)
         return apply_weights_fp8(
             torch_row_wise_w8a8_scaled_mm,
             self.quant_fp8,
@@ -224,6 +225,7 @@ class RowWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
             w_s,
             x_s,
             bias,
+            x_s_ub,
             self.config.out_dtype,
         )
 
@@ -255,7 +257,7 @@ class ChannelWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
         x: torch.Tensor,
         bias: torch.Tensor | None = None,
     ):
-        w, w_s, x_s = self._get_layer_params(layer)
+        w, w_s, x_s, x_s_ub = self._get_layer_params(layer)
         return apply_weights_fp8(
             torch_channelwise_w8a8_scaled_mm,
             self.quant_fp8,
@@ -264,5 +266,6 @@ class ChannelWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
             w_s,
             x_s,
             bias,
+            x_s_ub,
             self.config.out_dtype,
         )
diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/rocm.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/rocm.py
index 6144a94b7fb91..63744337a7e5a 100644
--- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/rocm.py
+++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/rocm.py
@@ -132,7 +132,7 @@ class ROCmScaledMMLinearKernel(FP8ScaledMMLinearKernel):
         x: torch.Tensor,
         bias: torch.Tensor | None = None,
     ):
-        w, w_s, x_s = self._get_layer_params(layer)
+        w, w_s, x_s, x_s_ub = self._get_layer_params(layer)
         return apply_weights_fp8(
             rocm_per_tensor_float_w8a8_scaled_mm,
             self.quant_fp8,
@@ -141,5 +141,6 @@ class ROCmScaledMMLinearKernel(FP8ScaledMMLinearKernel):
             w_s,
             x_s,
             bias,
+            x_s_ub,
             self.config.out_dtype,
         )
diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/utils.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/utils.py
index ca1a2c5b4f29b..8323690817d62 100644
--- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/utils.py
+++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/utils.py
@@ -18,6 +18,7 @@ def apply_weights_fp8(
     w_s: torch.Tensor,
     x_s: torch.Tensor,
     bias: torch.Tensor,
+    x_s_ub: torch.Tensor | None,
     maybe_out_dtype: torch.dtype | None,
 ) -> torch.Tensor:
     # ops.scaled_fp8_quant supports both dynamic and static quant.
@@ -36,6 +37,7 @@ def apply_weights_fp8(
     x_2d_q, x_s = quant_fp8_func(
         x_2d,
         x_s,
+        x_s_ub,
     )
     return scaled_mm_func(
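
For context: `x_s_ub` is an optional upper bound on the input (activation) scale, now carried in `FP8ParamsT` and threaded from every kernel backend through `apply_weights_fp8` into `quant_fp8_func`. The sketch below shows the typical role such a bound plays in dynamic per-token FP8 quantization; the function name `dynamic_fp8_quant_sketch` and the exact clamping details are illustrative assumptions, not vLLM's actual `QuantFP8` implementation.

import torch

FP8_DTYPE = torch.float8_e4m3fn
FP8_MAX = torch.finfo(FP8_DTYPE).max  # 448.0 for e4m3fn


def dynamic_fp8_quant_sketch(
    x: torch.Tensor,
    scale_ub: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    # Hypothetical stand-in for quant_fp8_func: dynamic per-token FP8
    # quantization with an optional cap on the observed activation range.
    amax = x.abs().amax(dim=-1, keepdim=True).float()
    if scale_ub is not None:
        # Clamp the per-token absmax so that a few outlier tokens cannot
        # inflate the scale (and crush quantization resolution).
        amax = torch.minimum(amax, scale_ub)
    # Guard against all-zero rows producing a zero scale.
    x_s = (amax / FP8_MAX).clamp(min=1e-12)
    x_q = (x / x_s).clamp(-FP8_MAX, FP8_MAX).to(FP8_DTYPE)
    return x_q, x_s

Since the new tuple element is `torch.Tensor | None`, call sites can forward it unconditionally; layers without an upper bound would presumably pass `None` and keep the previous unclamped behavior.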