add CPU kernels; fix fp8 quant type selection

Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
vllmellm 2025-11-14 07:46:37 +00:00
parent 9784a0c414
commit 10eebd4896
2 changed files with 7 additions and 4 deletions

@@ -77,9 +77,8 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     GroupShape,
     is_layer_skipped,
-    kFp8DynamicTensorSym,
+    kFp8DynamicTokenSym,
     kFp8StaticTensorSym,
-    kFp8StaticTokenSym,
 )
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     all_close_1d,
@@ -381,10 +380,10 @@ class Fp8LinearMethod(LinearMethodBase):
         # Use per-token quantization for better perf if dynamic and cutlass
         if not self.act_q_static and cutlass_fp8_supported():
             self.act_q_group_shape = GroupShape.PER_TOKEN
-            self.activation_quant_key = kFp8StaticTokenSym
+            self.activation_quant_key = kFp8DynamicTokenSym
         else:
             self.act_q_group_shape = GroupShape.PER_TENSOR
-            self.activation_quant_key = kFp8DynamicTensorSym
+            self.activation_quant_key = kFp8StaticTensorSym
         if self.block_quant:
             assert not self.act_q_static
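
The swap above matters because the key names pair a scale granularity with when scales are computed: the per-token CUTLASS path is dynamic (scales derived from each activation at runtime), while the per-tensor fallback uses the static key. Below is a minimal sketch of the corrected selection logic; the QuantKey dataclass and its fields are simplified stand-ins for vLLM's actual definitions, not the real API:

from dataclasses import dataclass
from enum import Enum


class GroupShape(Enum):
    PER_TENSOR = "per_tensor"
    PER_TOKEN = "per_token"


@dataclass(frozen=True)
class QuantKey:
    # Simplified stand-in for vLLM's quant key constants: each pairs a
    # scale granularity with static (precomputed) vs. dynamic (runtime)
    # scale computation.
    group_shape: GroupShape
    static: bool


# Illustrative equivalents of the two constants the fix selects.
kFp8DynamicTokenSym = QuantKey(GroupShape.PER_TOKEN, static=False)
kFp8StaticTensorSym = QuantKey(GroupShape.PER_TENSOR, static=True)


def select_activation_quant_key(act_q_static: bool, cutlass_fp8: bool) -> QuantKey:
    # Mirrors the fixed branch above: a dynamic scheme with CUTLASS support
    # gets per-token dynamic quantization; everything else falls back to
    # the per-tensor key.
    if not act_q_static and cutlass_fp8:
        return kFp8DynamicTokenSym
    return kFp8StaticTensorSym

Before the fix, the first branch returned the static-token key and the second the dynamic-tensor key, so the recorded quantization type contradicted the group shape chosen on the line above it.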

@@ -69,6 +69,10 @@ _POSSIBLE_FP8_KERNELS: dict[PlatformEnum, list[type[FP8ScaledMMLinearKernel]]] =
         RowWiseTorchScaledMMLinearKernel,
         ChannelWiseTorchScaledMMLinearKernel,
     ],
+    PlatformEnum.CPU: [
+        PerTensorTorchScaledMMLinearKernel,
+        ChannelWiseTorchScaledMMLinearKernel,
+    ],
 }
 _KernelT = TypeVar("_KernelT", bound=ScaledMMLinearKernel)
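
For context, _POSSIBLE_FP8_KERNELS maps each platform to an ordered preference list, and the new CPU entry makes the Torch-based per-tensor and channel-wise kernels available there. Below is a sketch of how such a candidate list is typically consumed, assuming a can_implement-style compatibility check; the names and signatures are illustrative, not necessarily vLLM's exact selection code:

# Hypothetical first-match selection over a per-platform candidate list
# such as _POSSIBLE_FP8_KERNELS[current_platform].

class ScaledMMLinearKernel:
    @classmethod
    def can_implement(cls, config: object) -> tuple[bool, str | None]:
        # Assumed interface: returns (supported, reason-if-not).
        raise NotImplementedError


def choose_kernel(
    candidates: list[type[ScaledMMLinearKernel]], config: object
) -> type[ScaledMMLinearKernel]:
    failures: list[str] = []
    for kernel_cls in candidates:
        # Walk the list in preference order; the first compatible kernel wins.
        ok, reason = kernel_cls.can_implement(config)
        if ok:
            return kernel_cls
        failures.append(f"{kernel_cls.__name__}: {reason}")
    raise ValueError(
        "No compatible FP8 scaled_mm kernel for this platform; tried: "
        + "; ".join(failures)
    )

Under that scheme, the order of the new CPU entry expresses a preference for the per-tensor Torch kernel, with the channel-wise kernel as the fallback.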