[Quantization] fix marlin w8a8 check (#30961)
Signed-off-by: Jinzhen Lin <jinzhen.ljz@antgroup.com>
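As the diff below shows, this change disables the partially wired Marlin W8A8 (8-bit weight, 8-bit activation) path: the now-unused marlin_quant_input import is removed, apply_fp8_marlin_linear rejects any 1-byte activation dtype with an explicit RuntimeError instead of only rejecting non-FP8 dtypes, and the GEMM consumes inputs rather than reshaped_x.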
commit 5fbfa8d9ef
parent 23a1946e3b
@@ -11,7 +11,6 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
     marlin_make_workspace_new,
     marlin_permute_bias,
     marlin_permute_scales,
-    marlin_quant_input,
     should_use_atomic_add_reduce,
 )
 from vllm.model_executor.utils import replace_parameter
@@ -63,13 +62,11 @@ def apply_fp8_marlin_linear(
     inputs = reshaped_x
     a_scales = None
     if input_dtype is not None and input_dtype.itemsize == 1:
-        if input_dtype != torch.float8_e4m3fn:
-            raise RuntimeError("FP8 weight + INT8 activation is not supported.")
-
-        inputs, a_scales = marlin_quant_input(inputs, torch.float8_e4m3fn)
+        # inputs, a_scales = marlin_quant_input(inputs, torch.float8_e4m3fn)
+        raise RuntimeError("Marlin W8A8 is not supported.")
 
     output = ops.gptq_marlin_gemm(
-        a=reshaped_x,
+        a=inputs,
         c=None,
         b_q_weight=weight,
         b_bias=bias,
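For context, a minimal runnable sketch of the behavioral change. The helper name check_marlin_activation_dtype is hypothetical (in vLLM the check is written inline in apply_fp8_marlin_linear), and it assumes a PyTorch build with float8 dtypes (torch >= 2.1):

import torch

# Hypothetical standalone helper mirroring the new guard; in vLLM the
# check lives inline in apply_fp8_marlin_linear.
def check_marlin_activation_dtype(input_dtype):
    # Any 1-byte activation dtype (fp8 and int8 alike) is now rejected up
    # front. Previously only non-fp8 1-byte dtypes raised, while fp8
    # activations were quantized via marlin_quant_input.
    if input_dtype is not None and input_dtype.itemsize == 1:
        raise RuntimeError("Marlin W8A8 is not supported.")

check_marlin_activation_dtype(None)           # weight-only path: no-op
check_marlin_activation_dtype(torch.float16)  # 2-byte activations: no-op
try:
    check_marlin_activation_dtype(torch.float8_e4m3fn)  # previously accepted, now raises
except RuntimeError as e:
    print(e)  # Marlin W8A8 is not supported.

Because the 1-byte branch now raises unconditionally, inputs always remains equal to reshaped_x, so the accompanying gptq_marlin_gemm change from a=reshaped_x to a=inputs preserves behavior; it simply makes the GEMM consume the variable that the (commented-out) quantization path would populate if it were ever re-enabled.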