From e33ee23ee3cde5aa69912d9d03aa31421851662c Mon Sep 17 00:00:00 2001
From: vllmellm
Date: Sat, 18 Oct 2025 02:51:10 +0800
Subject: [PATCH] [Bugfix] [AITER] [ROCm] Fix Quark MoE Quant Config and AITER Fused MoE quant type logic (#27029)

Signed-off-by: vllmellm
---
 .../layers/fused_moe/rocm_aiter_fused_moe.py |  2 ++
 .../layers/quantization/quark/quark_moe.py   |  3 ++-
 .../layers/quantization/utils/w8a8_utils.py  | 12 ++++++++++--
 3 files changed, 14 insertions(+), 3 deletions(-)

diff --git a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
index 820c0af71cbc7..350713698aeef 100644
--- a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
@@ -492,6 +492,8 @@ def rocm_aiter_fused_experts(
         assert quant_config.w1_scale is not None
         assert quant_config.w2_scale is not None
         quant_method = QuantMethod.BLOCK_128x128.value
+    elif quant_config.use_fp8_w8a8 and quant_config.per_out_ch_quant:
+        quant_method = QuantMethod.PER_TOKEN.value
     elif quant_config.use_fp8_w8a8:
         # Currently only per tensor quantization method is enabled.
         quant_method = QuantMethod.PER_TENSOR.value
diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py
index 5cab6e205ca08..a8f4b1b0db68d 100644
--- a/vllm/model_executor/layers/quantization/quark/quark_moe.py
+++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py
@@ -342,7 +342,8 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
             w2_scale=layer.w2_weight_scale,
             a1_scale=layer.w13_input_scale,
             a2_scale=layer.w2_input_scale,
-            per_act_token_quant=self.weight_qscheme == "per_channel",
+            per_act_token_quant=self.input_qscheme == "per_channel",
+            per_out_ch_quant=self.weight_qscheme == "per_channel",
         )

     def apply(
diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
index 4fda4d76a9808..17da125d5eb77 100644
--- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
@@ -464,8 +464,16 @@ class Fp8LinearOp:
         else:
             qinput, x_scale = input_2d, input_scale

-        per_tensor_weights = weight_scale.numel() == 1
-        per_tensor_activations = x_scale.numel() == 1
+        # Must also check dim():
+        # in the per-token quant scenario, when the number of tokens is 1,
+        # the scale only has 1 element.
+        # Without checking dim(),
+        # we cannot distinguish between per-tensor and per-token quant.
+        # Example:
+        # when the number of tokens is 1, the per-token scale is [[1]],
+        # while a per-tensor scale is [1] or ().
+        per_tensor_weights = (weight_scale.numel() == 1) and weight_scale.dim() < 2
+        per_tensor_activations = (x_scale.numel() == 1) and x_scale.dim() < 2

         # TODO(luka) do this dispatch during init (after ScaledMM refactor)
         w8a8_scaled_mm_func = dispatch_w8a8_scaled_mm(
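
Note for reviewers (not part of the patch): a minimal sketch, assuming PyTorch, of why
the dim() check is needed in addition to numel(). The helper name and tensor values
below are illustrative only.

import torch

def is_per_tensor_scale(scale: torch.Tensor) -> bool:
    # A per-tensor scale is a scalar: shape () or (1,).
    # A per-token scale keeps one row per token, so with a single token it is
    # [[s]]: numel() == 1 but dim() == 2, which a numel()-only check would
    # misclassify as per-tensor.
    return scale.numel() == 1 and scale.dim() < 2

per_tensor = torch.tensor(0.02)        # shape (), per-tensor quant
one_token = torch.tensor([[0.02]])     # shape (1, 1), per-token quant with one token

assert is_per_tensor_scale(per_tensor)
assert not is_per_tensor_scale(one_token)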