From 7c73ceb5812ace65e0d1b6ada3622b8b9f0400c0 Mon Sep 17 00:00:00 2001
From: Jinzhen Lin
Date: Sun, 21 Dec 2025 05:58:11 +0800
Subject: [PATCH] [Quantization] add marlin w4a8/w8a8 check (#31061)

Signed-off-by: Jinzhen Lin
---
 .../layers/quantization/utils/marlin_utils.py     | 12 ++++++++++++
 .../layers/quantization/utils/marlin_utils_fp4.py | 12 ++++++++++++
 .../layers/quantization/utils/marlin_utils_fp8.py |  4 ++++
 3 files changed, 28 insertions(+)

diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils.py b/vllm/model_executor/layers/quantization/utils/marlin_utils.py
index 66e979b505f0d..3de2b6509e460 100644
--- a/vllm/model_executor/layers/quantization/utils/marlin_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/marlin_utils.py
@@ -594,9 +594,15 @@ def apply_awq_marlin_linear(
 
     a_scales = None
     if input_dtype == torch.int8:
+        assert quant_type == scalar_types.uint4, (
+            "W8A8-INT8 is not supported by marlin kernel."
+        )
         reshaped_x, a_scales = marlin_quant_input(reshaped_x, input_dtype)
         a_scales = a_scales * input_global_scale
     elif input_dtype == torch.float8_e4m3fn:
+        assert quant_type == scalar_types.uint4, (
+            "INT8 weight + FP8 activation is not supported."
+        )
         reshaped_x, a_scales = marlin_quant_input(reshaped_x, input_dtype)
 
     output = ops.gptq_marlin_gemm(
@@ -649,9 +655,15 @@ def apply_rtn_marlin_linear(
 
     a_scales = None
     if input_dtype == torch.int8:
+        assert quant_type == scalar_types.uint4b8, (
+            "W8A8-INT8 is not supported by marlin kernel."
+        )
         reshaped_x, a_scales = marlin_quant_input(reshaped_x, input_dtype)
         a_scales = a_scales * input_global_scale
     elif input_dtype == torch.float8_e4m3fn:
+        assert quant_type == scalar_types.uint4b8, (
+            "INT8 weight + FP8 activation is not supported."
+        )
         reshaped_x, a_scales = marlin_quant_input(reshaped_x, input_dtype)
 
     output = ops.gptq_marlin_gemm(
diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py
index 876c724bf972d..4d0a34c3be119 100644
--- a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py
+++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py
@@ -154,6 +154,12 @@ def prepare_fp4_layer_for_marlin(
     )
     is_nvfp4 = hasattr(layer, "weight_scale_2")
 
+    if input_dtype is not None and input_dtype.itemsize == 1:
+        if is_nvfp4:
+            raise RuntimeError("NVFP4 weight + INT8/FP8 activation is not supported.")
+        elif input_dtype != torch.float8_e4m3fn:
+            raise RuntimeError("MXFP4 weight + INT8 activation is not supported.")
+
     group_size = 16 if is_nvfp4 else 32
 
     part_size_n = layer.output_size_per_partition
@@ -231,6 +237,12 @@ def prepare_moe_fp4_layer_for_marlin(
     )
     is_nvfp4 = hasattr(layer, "w13_weight_scale_2")
 
+    if input_dtype is not None and input_dtype.itemsize == 1:
+        if is_nvfp4:
+            raise RuntimeError("NVFP4 weight + INT8/FP8 activation is not supported.")
+        elif input_dtype != torch.float8_e4m3fn:
+            raise RuntimeError("MXFP4 weight + INT8 activation is not supported.")
+
     group_size = 16 if is_nvfp4 else 32
 
     e = layer.num_experts
diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py
index 1fb5223b07d76..4d2f2fd71ad36 100644
--- a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py
+++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py
@@ -99,6 +99,8 @@ def prepare_fp8_layer_for_marlin(
         "be used leveraging the Marlin kernel. This may degrade "
         "performance for compute-heavy workloads."
     )
+    if input_dtype is not None and input_dtype.itemsize == 1:
+        raise RuntimeError("Marlin W8A8 is not supported.")
 
     part_size_n = layer.output_size_per_partition
     part_size_k = layer.input_size_per_partition
@@ -206,6 +208,8 @@ def prepare_moe_fp8_layer_for_marlin(
         "be used leveraging the Marlin kernel. This may degrade "
         "performance for compute-heavy workloads."
     )
+    if input_dtype is not None and input_dtype.itemsize == 1:
+        raise RuntimeError("Marlin W8A8 is not supported.")
 
     e = layer.num_experts
     k = layer.hidden_size