mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-04 02:09:08 +08:00
add CPU kernels; fix fp8 quant type selection
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
This commit is contained in:
parent
9784a0c414
commit
10eebd4896
@ -77,9 +77,8 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
|||||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||||
GroupShape,
|
GroupShape,
|
||||||
is_layer_skipped,
|
is_layer_skipped,
|
||||||
kFp8DynamicTensorSym,
|
kFp8DynamicTokenSym,
|
||||||
kFp8StaticTensorSym,
|
kFp8StaticTensorSym,
|
||||||
kFp8StaticTokenSym,
|
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||||
all_close_1d,
|
all_close_1d,
|
||||||
@ -381,10 +380,10 @@ class Fp8LinearMethod(LinearMethodBase):
|
|||||||
# Use per-token quantization for better perf if dynamic and cutlass
|
# Use per-token quantization for better perf if dynamic and cutlass
|
||||||
if not self.act_q_static and cutlass_fp8_supported():
|
if not self.act_q_static and cutlass_fp8_supported():
|
||||||
self.act_q_group_shape = GroupShape.PER_TOKEN
|
self.act_q_group_shape = GroupShape.PER_TOKEN
|
||||||
self.activation_quant_key = kFp8StaticTokenSym
|
self.activation_quant_key = kFp8DynamicTokenSym
|
||||||
else:
|
else:
|
||||||
self.act_q_group_shape = GroupShape.PER_TENSOR
|
self.act_q_group_shape = GroupShape.PER_TENSOR
|
||||||
self.activation_quant_key = kFp8DynamicTensorSym
|
self.activation_quant_key = kFp8StaticTensorSym
|
||||||
|
|
||||||
if self.block_quant:
|
if self.block_quant:
|
||||||
assert not self.act_q_static
|
assert not self.act_q_static
|
||||||
|
|||||||
@ -69,6 +69,10 @@ _POSSIBLE_FP8_KERNELS: dict[PlatformEnum, list[type[FP8ScaledMMLinearKernel]]] =
|
|||||||
RowWiseTorchScaledMMLinearKernel,
|
RowWiseTorchScaledMMLinearKernel,
|
||||||
ChannelWiseTorchScaledMMLinearKernel,
|
ChannelWiseTorchScaledMMLinearKernel,
|
||||||
],
|
],
|
||||||
|
PlatformEnum.CPU: [
|
||||||
|
PerTensorTorchScaledMMLinearKernel,
|
||||||
|
ChannelWiseTorchScaledMMLinearKernel,
|
||||||
|
],
|
||||||
}
|
}
|
||||||
|
|
||||||
_KernelT = TypeVar("_KernelT", bound=ScaledMMLinearKernel)
|
_KernelT = TypeVar("_KernelT", bound=ScaledMMLinearKernel)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user