mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-06 09:07:03 +08:00
ensure static scales for ChannelWiseTorchScaledMMLinearKernel; remove comment
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
This commit is contained in:
parent
e47d55b80f
commit
edb6d43a37
@ -209,10 +209,19 @@ class RowWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
|
||||
class ChannelWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
|
||||
@classmethod
|
||||
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
||||
is_static = c.activation_quant_key.scale.static
|
||||
|
||||
per_tensor_activation_scales = (
|
||||
c.activation_quant_key.scale.group_shape.is_per_tensor()
|
||||
)
|
||||
per_tensor_weight_scales = c.weight_quant_key.scale.group_shape.is_per_tensor()
|
||||
|
||||
if not is_static:
|
||||
return (
|
||||
False,
|
||||
"ChannelWiseTorchScaledMMLinearKernel requires static scales",
|
||||
)
|
||||
|
||||
if per_tensor_activation_scales and per_tensor_weight_scales:
|
||||
return (
|
||||
False,
|
||||
|
||||
@ -89,7 +89,6 @@ if current_platform.is_rocm():
|
||||
class ROCmScaledMMLinearKernel(FP8ScaledMMLinearKernel):
|
||||
@classmethod
|
||||
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
||||
# TODO: check if this causes an issue on non-ROCM platforms
|
||||
from vllm.platforms.rocm import on_mi3xx
|
||||
|
||||
per_tensor_activation_scales = (
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user