diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 37b682984fc35..f478cd319e667 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -37,6 +37,12 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase, ) +from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( + init_fp8_linear_kernel, +) +from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501 + ScaledMMLinearQuantStrategy, +) from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( build_flashinfer_fp4_cutlass_moe_prepare_finalize, @@ -68,7 +74,6 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( swizzle_blockscale, ) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - Fp8LinearOp, requantize_with_max_scale, ) from vllm.model_executor.parameter import ModelWeightParameter, PerTensorScaleParameter @@ -254,8 +259,12 @@ class ModelOptFp8LinearMethod(LinearMethodBase): def __init__(self, quant_config: ModelOptFp8Config) -> None: self.quant_config = quant_config - self.fp8_linear = Fp8LinearOp( - act_quant_static=True, act_quant_group_shape=GroupShape.PER_TENSOR + self.fp8_linear = init_fp8_linear_kernel( + act_q_static=True, + act_q_group_shape=GroupShape.PER_TENSOR, + weight_quant_strategy=ScaledMMLinearQuantStrategy.TENSOR, + out_dtype=None, + module_name=self.__class__.__name__, ) def create_weights( @@ -323,13 +332,7 @@ class ModelOptFp8LinearMethod(LinearMethodBase): x: torch.Tensor, bias: torch.Tensor | None = None, ) -> torch.Tensor: - return self.fp8_linear.apply( - input=x, - weight=layer.weight, - weight_scale=layer.weight_scale, - input_scale=layer.input_scale, - bias=bias, - ) + return self.fp8_linear.apply_weights(layer, x, bias) class ModelOptFp8MoEMethod(FusedMoEMethodBase):