From edb6d43a371ea3a7e425c273950ab4af1ccff0f8 Mon Sep 17 00:00:00 2001 From: vllmellm Date: Fri, 7 Nov 2025 12:14:54 +0000 Subject: [PATCH] ensure static scales for ChannelWiseTorchScaledMMLinearKernel; remove comment Signed-off-by: vllmellm --- .../layers/quantization/kernels/scaled_mm/pytorch.py | 9 +++++++++ .../layers/quantization/kernels/scaled_mm/rocm.py | 1 - 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/pytorch.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/pytorch.py index 8c0f0e1d57fb3..1736f145de02b 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/pytorch.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/pytorch.py @@ -209,10 +209,19 @@ class RowWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel): class ChannelWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel): @classmethod def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: + is_static = c.activation_quant_key.scale.static + per_tensor_activation_scales = ( c.activation_quant_key.scale.group_shape.is_per_tensor() ) per_tensor_weight_scales = c.weight_quant_key.scale.group_shape.is_per_tensor() + + if not is_static: + return ( + False, + "ChannelWiseTorchScaledMMLinearKernel requires static scales", + ) + if per_tensor_activation_scales and per_tensor_weight_scales: return ( False, diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/rocm.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/rocm.py index bcc92ef209af5..493507ba4313a 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/rocm.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/rocm.py @@ -89,7 +89,6 @@ if current_platform.is_rocm(): class ROCmScaledMMLinearKernel(FP8ScaledMMLinearKernel): @classmethod def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: - # TODO: check if this causes an issue on non-ROCM platforms from vllm.platforms.rocm import on_mi3xx per_tensor_activation_scales = (