diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py
index b36b77109e922..67d0772895786 100644
--- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py
+++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py
@@ -17,6 +17,10 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass import (
     CutlassFP8ScaledMMLinearKernel,
     CutlassScaledMMLinearKernel,
 )
+
+from vllm.model_executor.layers.quantization.kernels.scaled_mm.flashinfer import (
+    FlashInferScaledMMLinearKernel
+)
 from vllm.model_executor.layers.quantization.kernels.scaled_mm.pytorch import (
     ChannelWiseTorchScaledMMLinearKernel,
     PerTensorTorchScaledMMLinearKernel,
@@ -54,7 +58,13 @@ _POSSIBLE_INT8_KERNELS: dict[PlatformEnum, list[type[Int8ScaledMMLinearKernel]]]
 
 # in priority/performance order (when available)
 _POSSIBLE_FP8_KERNELS: dict[PlatformEnum, list[type[FP8ScaledMMLinearKernel]]] = {
-    PlatformEnum.CUDA: [CutlassFP8ScaledMMLinearKernel],
+    PlatformEnum.CUDA: [
+        FlashInferScaledMMLinearKernel,
+        CutlassFP8ScaledMMLinearKernel,
+        PerTensorTorchScaledMMLinearKernel,
+        RowWiseTorchScaledMMLinearKernel,
+        ChannelWiseTorchScaledMMLinearKernel,
+    ],
     PlatformEnum.ROCM: [
         ROCmScaledMMLinearKernel,
         PerTensorTorchScaledMMLinearKernel,
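Note on the change: per the existing comment, `_POSSIBLE_FP8_KERNELS` lists candidates "in priority/performance order (when available)", so placing `FlashInferScaledMMLinearKernel` first makes it the preferred CUDA FP8 path, with the CUTLASS and Torch kernels as fallbacks. The sketch below only illustrates that "first available candidate wins" selection pattern; `is_supported` and `select_fp8_kernel` are hypothetical names for this example, not the actual vLLM kernel-selection API.

```python
# Illustrative sketch of priority-ordered kernel selection, assuming each
# candidate can report whether it is usable on the current platform.
# These class and function names are hypothetical, not vLLM's API.
from typing import Optional


class _FakeKernel:
    """Stand-in for an FP8 scaled-MM kernel class."""

    # Pretend capability flag; a real kernel would inspect the platform/config.
    available = False

    @classmethod
    def is_supported(cls) -> bool:
        return cls.available


class _FlashInfer(_FakeKernel):
    available = False  # e.g. flashinfer not installed


class _Cutlass(_FakeKernel):
    available = True


# Candidates listed in priority/performance order, mirroring the dict above.
_PRIORITY: list[type[_FakeKernel]] = [_FlashInfer, _Cutlass]


def select_fp8_kernel(
    candidates: list[type[_FakeKernel]],
) -> Optional[type[_FakeKernel]]:
    """Return the first candidate that reports support, else None."""
    for kernel_cls in candidates:
        if kernel_cls.is_supported():
            return kernel_cls
    return None


# FlashInfer is preferred but unavailable here, so CUTLASS is chosen.
assert select_fp8_kernel(_PRIORITY) is _Cutlass
```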