diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index e08ed8fa886f7..6753a19250b3b 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -511,13 +511,19 @@ def cutlass_moe_fp8( assert quant_config is not None if quant_config.a1_scale is not None: - assert quant_config.per_act_token_quant == quant_config.a1_scale.numel() != 1 + assert quant_config.per_act_token_quant == (quant_config.a1_scale.numel() != 1) if quant_config.a2_scale is not None: - assert quant_config.per_act_token_quant == quant_config.a2_scale.numel() != 1 + assert quant_config.per_act_token_quant == (quant_config.a2_scale.numel() != 1) - assert quant_config.w1_scale is None or ( - quant_config.per_out_ch_quant == (quant_config.w1_scale.size(1) == w1_q.size(1)) - ) + if quant_config.w1_scale is not None: + if quant_config.per_out_ch_quant: + assert quant_config.w1_scale.dim() > 1 and quant_config.w1_scale.size( + 1 + ) == w1_q.size(1) + else: + assert ( + quant_config.w1_scale.dim() == 1 or quant_config.w1_scale.size(1) == 1 + ) num_experts = global_num_experts if global_num_experts != -1 else w1_q.size(0)