From dd38ba3a2682d6f73b02bc983a5b0157ed3e5498 Mon Sep 17 00:00:00 2001
From: Jee Jee Li
Date: Thu, 4 Dec 2025 12:51:15 +0800
Subject: [PATCH] [Bugfix] Fix adapter_enabled IMA (#29977)

Signed-off-by: Jee Jee Li
---
 vllm/lora/ops/triton_ops/fused_moe_lora_op.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py
index 413ee8ecbbf9..34383cdf1767 100644
--- a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py
+++ b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py
@@ -96,10 +96,14 @@ def _fused_moe_lora_kernel(
     slice_id = tl.program_id(axis=1)
     lora_idx = tl.program_id(axis=2)
     lora_id = tl.load(lora_ids + lora_idx)
-    moe_enabled = tl.load(adapter_enabled + lora_id)
-    if lora_id == -1 or moe_enabled == 0:
+
+    if lora_id == -1:
         # Early exit for the no-lora case.
         return
+    moe_enabled = tl.load(adapter_enabled + lora_id)
+    if moe_enabled == 0:
+        # Early exit for the no moe lora case.
+        return
     max_loras = tl.num_programs(axis=2)
 
     grid_k = tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)
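
Note (not part of the patch): below is a minimal, self-contained Triton sketch of the guard ordering this fix introduces, to show why the reorder matters. Before the fix, adapter_enabled + lora_id was loaded before the lora_id == -1 sentinel check, so the -1 index read outside the adapter_enabled buffer (the IMA in the subject). The kernel name _guard_order_demo_kernel, the flat one-axis grid, and the out buffer are illustrative assumptions; only the lora_ids / adapter_enabled lookup mirrors the real _fused_moe_lora_kernel.

    import torch
    import triton
    import triton.language as tl


    @triton.jit
    def _guard_order_demo_kernel(lora_ids_ptr, adapter_enabled_ptr, out_ptr):
        # One program per slot; the axis layout is simplified compared to the
        # real _fused_moe_lora_kernel.
        lora_idx = tl.program_id(axis=0)
        lora_id = tl.load(lora_ids_ptr + lora_idx)

        # Check the -1 sentinel before using lora_id as an index. Loading
        # adapter_enabled_ptr + (-1) first (the pre-fix order) reads outside
        # the adapter_enabled buffer, which is the illegal memory access.
        if lora_id == -1:
            return
        moe_enabled = tl.load(adapter_enabled_ptr + lora_id)
        if moe_enabled == 0:
            return

        tl.store(out_ptr + lora_idx, lora_id)


    def _run_demo():
        lora_ids = torch.tensor([0, -1, 1], dtype=torch.int32, device="cuda")
        adapter_enabled = torch.tensor([1, 0], dtype=torch.int32, device="cuda")
        out = torch.full((3,), -2, dtype=torch.int32, device="cuda")
        _guard_order_demo_kernel[(lora_ids.numel(),)](lora_ids, adapter_enabled, out)
        # Only slot 0 has a valid, enabled adapter; the other slots leave
        # out untouched because the kernel returned early.
        print(out)

The sketch keeps the two early returns separate, as the patch does, so the adapter_enabled load is only ever reached with a valid lora_id.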