ensure static scales for ChannelWiseTorchScaledMMLinearKernel; remove comment

Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
Author: vllmellm <vllm.ellm@embeddedllm.com>
Date:   2025-11-07 12:14:54 +00:00
Parent: e47d55b80f
Commit: edb6d43a37

2 changed files with 9 additions and 1 deletion


@@ -209,10 +209,19 @@ class RowWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
 class ChannelWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
     @classmethod
     def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
+        is_static = c.activation_quant_key.scale.static
         per_tensor_activation_scales = (
             c.activation_quant_key.scale.group_shape.is_per_tensor()
         )
         per_tensor_weight_scales = c.weight_quant_key.scale.group_shape.is_per_tensor()
+        if not is_static:
+            return (
+                False,
+                "ChannelWiseTorchScaledMMLinearKernel requires static scales",
+            )
         if per_tensor_activation_scales and per_tensor_weight_scales:
             return (
                 False,

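For context, the new guard rejects dynamic activation scales before the existing per-tensor check runs. Below is a minimal, self-contained sketch of that ordering; the Fake* dataclasses and the second rejection message are simplified stand-ins for vLLM's config types, not the real definitions.

    from dataclasses import dataclass

    @dataclass
    class FakeScale:
        static: bool      # True if scales are fixed at load time
        per_tensor: bool  # stand-in for group_shape.is_per_tensor()

    @dataclass
    class FakeQuantKey:
        scale: FakeScale

    @dataclass
    class FakeConfig:
        activation_quant_key: FakeQuantKey
        weight_quant_key: FakeQuantKey

    def can_implement(c: FakeConfig) -> tuple[bool, str | None]:
        # Reject dynamic activation scales first, mirroring the new check.
        if not c.activation_quant_key.scale.static:
            return False, "ChannelWiseTorchScaledMMLinearKernel requires static scales"
        # Per-tensor x per-tensor belongs to a different kernel, so reject it next.
        if c.activation_quant_key.scale.per_tensor and c.weight_quant_key.scale.per_tensor:
            return False, "not channel-wise: both scales are per-tensor"
        return True, None

    dynamic = FakeConfig(
        FakeQuantKey(FakeScale(static=False, per_tensor=False)),
        FakeQuantKey(FakeScale(static=True, per_tensor=False)),
    )
    print(can_implement(dynamic))  # (False, '...requires static scales')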

@@ -89,7 +89,6 @@ if current_platform.is_rocm():
 class ROCmScaledMMLinearKernel(FP8ScaledMMLinearKernel):
     @classmethod
     def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
-        # TODO: check if this causes an issue on non-ROCM platforms
         from vllm.platforms.rocm import on_mi3xx
         per_tensor_activation_scales = (
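
The deleted TODO asked whether the ROCm import could break non-ROCm platforms; it cannot, because the import is deferred to call time inside can_implement rather than executed at module import. Below is a hedged sketch of that deferred-import pattern; the fallback branch and the way on_mi3xx is called here are illustrative assumptions, not vLLM's actual code path.

    def rocm_can_implement() -> tuple[bool, str | None]:
        try:
            # Deferred, ROCm-only import: evaluated only when this function
            # is called, so merely importing the enclosing module is safe
            # on non-ROCm platforms.
            from vllm.platforms.rocm import on_mi3xx
        except ImportError:
            return False, "ROCm platform support is unavailable"
        # Assumed usage: treat on_mi3xx as a boolean predicate for MI3xx GPUs.
        return (True, None) if on_mi3xx() else (False, "requires an MI3xx GPU")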