[Quantization] add marlin w4a8/w8a8 check (#31061)

Signed-off-by: Jinzhen Lin <jinzhen.ljz@antgroup.com>

parent ae0770fa6b
commit 7c73ceb581
@@ -594,9 +594,15 @@ def apply_awq_marlin_linear(

    a_scales = None
    if input_dtype == torch.int8:
        assert quant_type == scalar_types.uint4, (
            "W8A8-INT8 is not supported by marlin kernel."
        )
        reshaped_x, a_scales = marlin_quant_input(reshaped_x, input_dtype)
        a_scales = a_scales * input_global_scale
    elif input_dtype == torch.float8_e4m3fn:
        assert quant_type == scalar_types.uint4, (
            "INT8 weight + FP8 activation is not supported."
        )
        reshaped_x, a_scales = marlin_quant_input(reshaped_x, input_dtype)

    output = ops.gptq_marlin_gemm(
@@ -649,9 +655,15 @@ def apply_rtn_marlin_linear(

    a_scales = None
    if input_dtype == torch.int8:
        assert quant_type == scalar_types.uint4b8, (
            "W8A8-INT8 is not supported by marlin kernel."
        )
        reshaped_x, a_scales = marlin_quant_input(reshaped_x, input_dtype)
        a_scales = a_scales * input_global_scale
    elif input_dtype == torch.float8_e4m3fn:
        assert quant_type == scalar_types.uint4b8, (
            "INT8 weight + FP8 activation is not supported."
        )
        reshaped_x, a_scales = marlin_quant_input(reshaped_x, input_dtype)

    output = ops.gptq_marlin_gemm(
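The two hunks above add the same guard to the AWQ and RTN Marlin apply paths: when the activation is quantized to int8 or fp8, the weight quant type must be 4-bit (scalar_types.uint4 for the AWQ path, scalar_types.uint4b8 for the RTN path), so W8A8 combinations are rejected before the GEMM runs. Below is a minimal stand-alone sketch of that check, assuming hypothetical names (check_marlin_a8_weight_type, WEIGHT_4BIT_TYPES) that are not vLLM APIs.

# Hypothetical stand-alone sketch of the weight/activation compatibility check
# added above; names here are illustrative, not part of vLLM.
import torch

# 4-bit weight encodings accepted on the Marlin w4a8 path
# (the diff asserts scalar_types.uint4 for AWQ and scalar_types.uint4b8 for RTN).
WEIGHT_4BIT_TYPES = {"uint4", "uint4b8"}


def check_marlin_a8_weight_type(quant_type_name, input_dtype):
    """Reject weight/activation combinations the Marlin kernel cannot run."""
    if input_dtype == torch.int8 and quant_type_name not in WEIGHT_4BIT_TYPES:
        raise AssertionError("W8A8-INT8 is not supported by marlin kernel.")
    if input_dtype == torch.float8_e4m3fn and quant_type_name not in WEIGHT_4BIT_TYPES:
        raise AssertionError("INT8 weight + FP8 activation is not supported.")


# W4A8 combinations pass; W8A8 combinations raise.
check_marlin_a8_weight_type("uint4", torch.int8)              # ok
check_marlin_a8_weight_type("uint4b8", torch.float8_e4m3fn)   # ok
try:
    check_marlin_a8_weight_type("uint8", torch.int8)          # rejected
except AssertionError as e:
    print(e)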
@@ -154,6 +154,12 @@ def prepare_fp4_layer_for_marlin(
    )

    is_nvfp4 = hasattr(layer, "weight_scale_2")
    if input_dtype is not None and input_dtype.itemsize == 1:
        if is_nvfp4:
            raise RuntimeError("NVFP4 weight + INT8/FP8 activation is not supported.")
        elif input_dtype != torch.float8_e4m3fn:
            raise RuntimeError("MXFP4 weight + INT8 activation is not supported.")

    group_size = 16 if is_nvfp4 else 32

    part_size_n = layer.output_size_per_partition
@@ -231,6 +237,12 @@ def prepare_moe_fp4_layer_for_marlin(
    )

    is_nvfp4 = hasattr(layer, "w13_weight_scale_2")
    if input_dtype is not None and input_dtype.itemsize == 1:
        if is_nvfp4:
            raise RuntimeError("NVFP4 weight + INT8/FP8 activation is not supported.")
        elif input_dtype != torch.float8_e4m3fn:
            raise RuntimeError("MXFP4 weight + INT8 activation is not supported.")

    group_size = 16 if is_nvfp4 else 32

    e = layer.num_experts
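Both FP4 preparation hunks apply the same rule: a layer is treated as NVFP4 when it carries a second-level weight scale (weight_scale_2, or w13_weight_scale_2 for MoE), otherwise as MXFP4; any 1-byte activation dtype is rejected for NVFP4, and MXFP4 only accepts torch.float8_e4m3fn as an 8-bit activation. The sketch below mirrors that logic in isolation; Fp4LayerStub and validate_fp4_activation are illustrative names, not vLLM APIs.

# Minimal sketch of the FP4-layer activation check added above.
import torch


class Fp4LayerStub:
    """NVFP4 layers carry a second-level scale ("weight_scale_2"); MXFP4 layers do not."""

    def __init__(self, nvfp4):
        if nvfp4:
            self.weight_scale_2 = torch.tensor(1.0)


def validate_fp4_activation(layer, input_dtype):
    """Return the FP4 group size, rejecting unsupported 8-bit activations."""
    is_nvfp4 = hasattr(layer, "weight_scale_2")
    if input_dtype is not None and input_dtype.itemsize == 1:
        if is_nvfp4:
            raise RuntimeError("NVFP4 weight + INT8/FP8 activation is not supported.")
        elif input_dtype != torch.float8_e4m3fn:
            raise RuntimeError("MXFP4 weight + INT8 activation is not supported.")
    return 16 if is_nvfp4 else 32


# NVFP4 with 16-bit activations is fine (group size 16);
# MXFP4 with FP8 activations is fine (group size 32).
print(validate_fp4_activation(Fp4LayerStub(nvfp4=True), None))
print(validate_fp4_activation(Fp4LayerStub(nvfp4=False), torch.float8_e4m3fn))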
@@ -99,6 +99,8 @@ def prepare_fp8_layer_for_marlin(
        "be used leveraging the Marlin kernel. This may degrade "
        "performance for compute-heavy workloads."
    )
    if input_dtype is not None and input_dtype.itemsize == 1:
        raise RuntimeError("Marlin W8A8 is not supported.")

    part_size_n = layer.output_size_per_partition
    part_size_k = layer.input_size_per_partition
@@ -206,6 +208,8 @@ def prepare_moe_fp8_layer_for_marlin(
        "be used leveraging the Marlin kernel. This may degrade "
        "performance for compute-heavy workloads."
    )
    if input_dtype is not None and input_dtype.itemsize == 1:
        raise RuntimeError("Marlin W8A8 is not supported.")

    e = layer.num_experts
    k = layer.hidden_size
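The FP8 preparation hunks add the simplest guard of the set: if an activation dtype is given and it is 1 byte wide (int8 or fp8), the FP8-weight Marlin path raises, because Marlin W8A8 is not supported. A hedged stand-alone sketch follows, with check_marlin_fp8_activation as an illustrative name rather than a vLLM function.

# Sketch of the FP8-layer guard added above: any 1-byte activation dtype is rejected.
import torch


def check_marlin_fp8_activation(input_dtype):
    if input_dtype is not None and input_dtype.itemsize == 1:
        raise RuntimeError("Marlin W8A8 is not supported.")


check_marlin_fp8_activation(torch.bfloat16)  # ok: 16-bit activations
check_marlin_fp8_activation(None)            # ok: no activation dtype specified
for dt in (torch.int8, torch.float8_e4m3fn):
    try:
        check_marlin_fp8_activation(dt)      # rejected: W8A8
    except RuntimeError as e:
        print(f"{dt}: {e}")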