diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py
index 3d084516bf9a2..93d238a0524d8 100644
--- a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py
+++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py
@@ -11,7 +11,6 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
     marlin_make_workspace_new,
     marlin_permute_bias,
     marlin_permute_scales,
-    marlin_quant_input,
     should_use_atomic_add_reduce,
 )
 from vllm.model_executor.utils import replace_parameter
@@ -63,13 +62,11 @@ def apply_fp8_marlin_linear(
     inputs = reshaped_x
     a_scales = None
     if input_dtype is not None and input_dtype.itemsize == 1:
-        if input_dtype != torch.float8_e4m3fn:
-            raise RuntimeError("FP8 weight + INT8 activation is not supported.")
-
-        inputs, a_scales = marlin_quant_input(inputs, torch.float8_e4m3fn)
+        # inputs, a_scales = marlin_quant_input(inputs, torch.float8_e4m3fn)
+        raise RuntimeError("Marlin W8A8 is not supported.")
 
     output = ops.gptq_marlin_gemm(
-        a=reshaped_x,
+        a=inputs,
         c=None,
         b_q_weight=weight,
         b_bias=bias,