mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-06 17:07:04 +08:00
Add dynamic per-tensor fallback
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
This commit is contained in:
parent
10eebd4896
commit
cbfcff373b
@@ -77,6 +77,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape,
|
||||
is_layer_skipped,
|
||||
kFp8DynamicTensorSym,
|
||||
kFp8DynamicTokenSym,
|
||||
kFp8StaticTensorSym,
|
||||
)
|
||||
@@ -381,9 +382,12 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
if not self.act_q_static and cutlass_fp8_supported():
|
||||
self.act_q_group_shape = GroupShape.PER_TOKEN
|
||||
self.activation_quant_key = kFp8DynamicTokenSym
|
||||
- else:
|
||||
+ elif self.act_q_static:
|
||||
self.act_q_group_shape = GroupShape.PER_TENSOR
|
||||
self.activation_quant_key = kFp8StaticTensorSym
|
||||
else:
|
||||
self.act_q_group_shape = GroupShape.PER_TENSOR
|
||||
self.activation_quant_key = kFp8DynamicTensorSym
|
||||
|
||||
if self.block_quant:
|
||||
assert not self.act_q_static
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user