mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-22 17:24:28 +08:00
ensure static scales for ChannelWiseTorchScaledMMLinearKernel; remove comment
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
This commit is contained in:
parent
e47d55b80f
commit
edb6d43a37
@ -209,10 +209,19 @@ class RowWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
|
|||||||
class ChannelWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
|
class ChannelWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
|
||||||
@classmethod
|
@classmethod
|
||||||
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
||||||
|
is_static = c.activation_quant_key.scale.static
|
||||||
|
|
||||||
per_tensor_activation_scales = (
|
per_tensor_activation_scales = (
|
||||||
c.activation_quant_key.scale.group_shape.is_per_tensor()
|
c.activation_quant_key.scale.group_shape.is_per_tensor()
|
||||||
)
|
)
|
||||||
per_tensor_weight_scales = c.weight_quant_key.scale.group_shape.is_per_tensor()
|
per_tensor_weight_scales = c.weight_quant_key.scale.group_shape.is_per_tensor()
|
||||||
|
|
||||||
|
if not is_static:
|
||||||
|
return (
|
||||||
|
False,
|
||||||
|
"ChannelWiseTorchScaledMMLinearKernel requires static scales",
|
||||||
|
)
|
||||||
|
|
||||||
if per_tensor_activation_scales and per_tensor_weight_scales:
|
if per_tensor_activation_scales and per_tensor_weight_scales:
|
||||||
return (
|
return (
|
||||||
False,
|
False,
|
||||||
|
|||||||
@ -89,7 +89,6 @@ if current_platform.is_rocm():
|
|||||||
class ROCmScaledMMLinearKernel(FP8ScaledMMLinearKernel):
|
class ROCmScaledMMLinearKernel(FP8ScaledMMLinearKernel):
|
||||||
@classmethod
|
@classmethod
|
||||||
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
||||||
# TODO: check if this causes an issue on non-ROCM platforms
|
|
||||||
from vllm.platforms.rocm import on_mi3xx
|
from vllm.platforms.rocm import on_mi3xx
|
||||||
|
|
||||||
per_tensor_activation_scales = (
|
per_tensor_activation_scales = (
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user