diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index a0159c39092c5..67e5b65de6010 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -77,9 +77,8 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     GroupShape,
     is_layer_skipped,
-    kFp8DynamicTensorSym,
+    kFp8DynamicTokenSym,
     kFp8StaticTensorSym,
-    kFp8StaticTokenSym,
 )
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     all_close_1d,
@@ -381,10 +380,10 @@ class Fp8LinearMethod(LinearMethodBase):
         # Use per-token quantization for better perf if dynamic and cutlass
         if not self.act_q_static and cutlass_fp8_supported():
             self.act_q_group_shape = GroupShape.PER_TOKEN
-            self.activation_quant_key = kFp8StaticTokenSym
+            self.activation_quant_key = kFp8DynamicTokenSym
         else:
             self.act_q_group_shape = GroupShape.PER_TENSOR
-            self.activation_quant_key = kFp8DynamicTensorSym
+            self.activation_quant_key = kFp8StaticTensorSym
 
         if self.block_quant:
             assert not self.act_q_static
diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py
index b033cc7905e4e..36e4a16c01683 100644
--- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py
+++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py
@@ -69,6 +69,10 @@ _POSSIBLE_FP8_KERNELS: dict[PlatformEnum, list[type[FP8ScaledMMLinearKernel]]] =
         RowWiseTorchScaledMMLinearKernel,
         ChannelWiseTorchScaledMMLinearKernel,
     ],
+    PlatformEnum.CPU: [
+        PerTensorTorchScaledMMLinearKernel,
+        ChannelWiseTorchScaledMMLinearKernel,
+    ],
 }
 
 _KernelT = TypeVar("_KernelT", bound=ScaledMMLinearKernel)
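
The fp8.py hunk swaps two activation-quantization keys that were crossed: the dynamic, CUTLASS-backed path now records the per-token dynamic key, while the fallback path records the per-tensor static key. Below is a minimal standalone sketch of the selection logic as it reads after the patch; GroupShape and the kFp8* constants here are placeholder stand-ins, not vLLM's actual objects (those live in vllm.model_executor.layers.quantization.utils.quant_utils).

    from enum import Enum

    # Placeholder stand-ins for vLLM's GroupShape and QuantKey constants.
    class GroupShape(Enum):
        PER_TOKEN = "per_token"
        PER_TENSOR = "per_tensor"

    kFp8DynamicTokenSym = "fp8_dynamic_per_token_sym"   # placeholder value
    kFp8StaticTensorSym = "fp8_static_per_tensor_sym"   # placeholder value

    def select_activation_quant(act_q_static: bool, cutlass_fp8_supported: bool):
        # Mirrors the patched branch in Fp8LinearMethod: dynamic activation
        # quantization with CUTLASS available uses per-token scales and the
        # dynamic-token key; every other case falls back to per-tensor scales
        # and the static-tensor key.
        if not act_q_static and cutlass_fp8_supported:
            return GroupShape.PER_TOKEN, kFp8DynamicTokenSym
        return GroupShape.PER_TENSOR, kFp8StaticTensorSym

    assert select_activation_quant(False, True) == (GroupShape.PER_TOKEN, kFp8DynamicTokenSym)
    assert select_activation_quant(True, False) == (GroupShape.PER_TENSOR, kFp8StaticTensorSym)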