From 432870829d5143840c45296b8c1f34e5f561fa85 Mon Sep 17 00:00:00 2001 From: Lucia Fang <116399278+luccafong@users.noreply.github.com> Date: Sun, 6 Jul 2025 12:08:30 +0800 Subject: [PATCH] [Bugfix] Fix missing per_act_token parameter in compressed_tensors_moe (#20509) Signed-off-by: Lu Fang --- vllm/model_executor/layers/fused_moe/cutlass_moe.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index 431fb290b294b..0f41414c4896d 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -322,7 +322,7 @@ def cutlass_moe_fp8( topk_ids: torch.Tensor, w1_scale: torch.Tensor, w2_scale: torch.Tensor, - per_act_token: bool, + per_act_token: Optional[bool] = None, activation: str = "silu", a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, @@ -366,6 +366,9 @@ def cutlass_moe_fp8( Returns: - torch.Tensor: The fp16 output tensor after applying the MoE layer. """ + if per_act_token is None: + per_act_token = a1_scale.numel() != 1 if a1_scale is not None else ( + a2_scale.numel() != 1 if a2_scale is not None else False) per_out_ch = w1_scale.numel() != w1_q.size(0) num_experts = global_num_experts if global_num_experts != -1 else w1_q.size(