diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/pytorch.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/pytorch.py
index 8c0f0e1d57fb3..1736f145de02b 100644
--- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/pytorch.py
+++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/pytorch.py
@@ -209,10 +209,19 @@ class RowWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
 class ChannelWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
     @classmethod
     def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
+        is_static = c.activation_quant_key.scale.static
+
         per_tensor_activation_scales = (
             c.activation_quant_key.scale.group_shape.is_per_tensor()
         )
         per_tensor_weight_scales = c.weight_quant_key.scale.group_shape.is_per_tensor()
+
+        if not is_static:
+            return (
+                False,
+                "ChannelWiseTorchScaledMMLinearKernel requires static scales",
+            )
+
         if per_tensor_activation_scales and per_tensor_weight_scales:
             return (
                 False,
diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/rocm.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/rocm.py
index bcc92ef209af5..493507ba4313a 100644
--- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/rocm.py
+++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/rocm.py
@@ -89,7 +89,6 @@ if current_platform.is_rocm():
 class ROCmScaledMMLinearKernel(FP8ScaledMMLinearKernel):
     @classmethod
     def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
-        # TODO: check if this causes an issue on non-ROCM platforms
         from vllm.platforms.rocm import on_mi3xx
 
         per_tensor_activation_scales = (
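For context, a minimal sketch of how a caller might consume the can_implement() contract shown in the diff: each kernel class returns an (ok, reason) tuple, and a selector can fall back through candidates, collecting rejection reasons such as the new "requires static scales" message. The pick_kernel helper and candidate list below are hypothetical illustrations, not vLLM's actual selection code.

# Hypothetical selector sketch (not vLLM's actual code): shows how the
# (ok, reason) tuple returned by can_implement() above could drive fallback.
def pick_kernel(candidates, config):
    """Return the first kernel class whose can_implement(config) passes."""
    reasons = {}
    for kernel_cls in candidates:
        ok, reason = kernel_cls.can_implement(config)
        if ok:
            return kernel_cls
        # e.g. "ChannelWiseTorchScaledMMLinearKernel requires static scales"
        reasons[kernel_cls.__name__] = reason
    raise ValueError(f"no scaled_mm kernel supports this config: {reasons}")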