diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py index f4ec97804fd48..1d0e36a3fc551 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py @@ -86,7 +86,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): act_q_group_shape=self.act_q_group_shape, weight_quant_strategy=weight_quant_strategy, out_dtype=self.out_dtype, - module_name=self.__class__.__name__ + module_name=self.__class__.__name__, ) @classmethod diff --git a/vllm/model_executor/layers/quantization/fbgemm_fp8.py b/vllm/model_executor/layers/quantization/fbgemm_fp8.py index 6ba18e59e4d54..fb16681f03a0c 100644 --- a/vllm/model_executor/layers/quantization/fbgemm_fp8.py +++ b/vllm/model_executor/layers/quantization/fbgemm_fp8.py @@ -18,6 +18,12 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase, ) +from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( + init_fp8_linear_kernel, +) +from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( + ScaledMMLinearQuantStrategy, +) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin, @@ -96,6 +102,14 @@ class FBGEMMFp8LinearMethod(LinearMethodBase): ) self.out_dtype = torch.get_default_dtype() + self.fp8_linear_kernel = init_fp8_linear_kernel( + act_q_static=False, + act_q_group_shape=GroupShape.PER_TOKEN, + weight_quant_strategy=ScaledMMLinearQuantStrategy.CHANNEL, + out_dtype=self.out_dtype, + module_name=self.__class__.__name__, + ) + def create_weights( self, layer: torch.nn.Module, @@ -184,12 +198,4 @@ class FBGEMMFp8LinearMethod(LinearMethodBase): bias=bias, ) - return self.fp8_linear.apply( - input=x, - weight=layer.weight, - weight_scale=layer.weight_scale, - out_dtype=self.out_dtype, - input_scale=None, - input_scale_ub=layer.input_scale_ub, - bias=bias, - ) + return self.fp8_linear_kernel.apply_weights(layer, x, bias) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 91988a18fda71..484a8d7ab3af7 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -398,7 +398,7 @@ class Fp8LinearMethod(LinearMethodBase): act_q_group_shape=self.act_q_group_shape, weight_quant_strategy=ScaledMMLinearQuantStrategy.TENSOR, out_dtype=self.out_dtype, - module_name=self.__class__.__name__ + module_name=self.__class__.__name__, ) def create_weights( diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py index 629ed790b9666..26baba602945d 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py @@ -15,6 +15,11 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm.cpu import ( from vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass import ( CutlassScaledMMLinearKernel, ) +from vllm.model_executor.layers.quantization.kernels.scaled_mm.pytorch import ( + ChannelWiseTorchScaledMMLinearKernel, + PerTensorTorchScaledMMLinearKernel, + RowWiseTorchScaledMMLinearKernel, +) from vllm.model_executor.layers.quantization.kernels.scaled_mm.rocm import ( ROCmScaledMMLinearKernel, ) @@ -25,19 +30,14 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKer ScaledMMLinearLayerConfig, ScaledMMLinearQuantStrategy, ) -from vllm.model_executor.layers.quantization.kernels.scaled_mm.pytorch import ( - ChannelWiseTorchScaledMMLinearKernel, - PerTensorTorchScaledMMLinearKernel, - RowWiseTorchScaledMMLinearKernel, -) from vllm.model_executor.layers.quantization.kernels.scaled_mm.triton import ( TritonScaledMMLinearKernel, ) from vllm.model_executor.layers.quantization.kernels.scaled_mm.xla import ( XLAScaledMMLinearKernel, ) -from vllm.platforms import PlatformEnum, current_platform from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape +from vllm.platforms import PlatformEnum, current_platform logger = init_logger(__name__) @@ -133,7 +133,7 @@ def choose_scaled_mm_linear_kernel( def init_fp8_linear_kernel( act_q_static: bool, act_q_group_shape: GroupShape, - weight_quant_strategy: ScaledMMLinearQuantStrategy, + weight_quant_strategy: ScaledMMLinearQuantStrategy, out_dtype: torch.dtype, module_name: str, ) -> FP8ScaledMMLinearKernel: diff --git a/vllm/model_executor/layers/quantization/ptpc_fp8.py b/vllm/model_executor/layers/quantization/ptpc_fp8.py index 5352ba9c45009..2634bbd4bd87e 100644 --- a/vllm/model_executor/layers/quantization/ptpc_fp8.py +++ b/vllm/model_executor/layers/quantization/ptpc_fp8.py @@ -108,10 +108,9 @@ class PTPCFp8LinearMethod(Fp8LinearMethod): act_q_group_shape=GroupShape.PER_TOKEN, weight_quant_strategy=ScaledMMLinearQuantStrategy.CHANNEL, out_dtype=self.out_dtype, - module_name=self.__class__.__name__ + module_name=self.__class__.__name__, ) - def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False) @@ -136,6 +135,4 @@ class PTPCFp8LinearMethod(Fp8LinearMethod): x: torch.Tensor, bias: torch.Tensor | None = None, ) -> torch.Tensor: - return self.fp8_linear_kernel.apply_weights( - layer, x, bias - ) + return self.fp8_linear_kernel.apply_weights(layer, x, bias) diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py index 94a90747e2814..f32c14e27f68f 100644 --- a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py @@ -178,7 +178,7 @@ class QuarkW8A8Fp8(QuarkScheme): act_q_group_shape=self.act_quant_group_shape, weight_quant_strategy=weight_quant_strategy, out_dtype=self.out_dtype, - module_name=self.__class__.__name__ + module_name=self.__class__.__name__, ) def apply_weights(