[Quantization] add marlin w4a8/w8a8 check (#31061)

Signed-off-by: Jinzhen Lin <jinzhen.ljz@antgroup.com>
This commit is contained in:
Jinzhen Lin 2025-12-21 05:58:11 +08:00 committed by GitHub
parent ae0770fa6b
commit 7c73ceb581
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 28 additions and 0 deletions

View File

@ -594,9 +594,15 @@ def apply_awq_marlin_linear(
a_scales = None
if input_dtype == torch.int8:
assert quant_type == scalar_types.uint4, (
"W8A8-INT8 is not supported by marlin kernel."
)
reshaped_x, a_scales = marlin_quant_input(reshaped_x, input_dtype)
a_scales = a_scales * input_global_scale
elif input_dtype == torch.float8_e4m3fn:
assert quant_type == scalar_types.uint4, (
"INT8 weight + FP8 activation is not supported."
)
reshaped_x, a_scales = marlin_quant_input(reshaped_x, input_dtype)
output = ops.gptq_marlin_gemm(
@ -649,9 +655,15 @@ def apply_rtn_marlin_linear(
a_scales = None
if input_dtype == torch.int8:
assert quant_type == scalar_types.uint4b8, (
"W8A8-INT8 is not supported by marlin kernel."
)
reshaped_x, a_scales = marlin_quant_input(reshaped_x, input_dtype)
a_scales = a_scales * input_global_scale
elif input_dtype == torch.float8_e4m3fn:
assert quant_type == scalar_types.uint4b8, (
"INT8 weight + FP8 activation is not supported."
)
reshaped_x, a_scales = marlin_quant_input(reshaped_x, input_dtype)
output = ops.gptq_marlin_gemm(

View File

@ -154,6 +154,12 @@ def prepare_fp4_layer_for_marlin(
)
is_nvfp4 = hasattr(layer, "weight_scale_2")
if input_dtype is not None and input_dtype.itemsize == 1:
if is_nvfp4:
raise RuntimeError("NVFP4 weight + INT8/FP8 activation is not supported.")
elif input_dtype != torch.float8_e4m3fn:
raise RuntimeError("MXFP4 weight + INT8 activation is not supported.")
group_size = 16 if is_nvfp4 else 32
part_size_n = layer.output_size_per_partition
@ -231,6 +237,12 @@ def prepare_moe_fp4_layer_for_marlin(
)
is_nvfp4 = hasattr(layer, "w13_weight_scale_2")
if input_dtype is not None and input_dtype.itemsize == 1:
if is_nvfp4:
raise RuntimeError("NVFP4 weight + INT8/FP8 activation is not supported.")
elif input_dtype != torch.float8_e4m3fn:
raise RuntimeError("MXFP4 weight + INT8 activation is not supported.")
group_size = 16 if is_nvfp4 else 32
e = layer.num_experts

View File

@ -99,6 +99,8 @@ def prepare_fp8_layer_for_marlin(
"be used leveraging the Marlin kernel. This may degrade "
"performance for compute-heavy workloads."
)
if input_dtype is not None and input_dtype.itemsize == 1:
raise RuntimeError("Marlin W8A8 is not supported.")
part_size_n = layer.output_size_per_partition
part_size_k = layer.input_size_per_partition
@ -206,6 +208,8 @@ def prepare_moe_fp8_layer_for_marlin(
"be used leveraging the Marlin kernel. This may degrade "
"performance for compute-heavy workloads."
)
if input_dtype is not None and input_dtype.itemsize == 1:
raise RuntimeError("Marlin W8A8 is not supported.")
e = layer.num_experts
k = layer.hidden_size