From 288d67d054283d1b1f7346d7bcb74496da21fa04 Mon Sep 17 00:00:00 2001
From: Yu Gong
Date: Wed, 17 Dec 2025 20:02:06 +0000
Subject: [PATCH] Using active-loras in grid in fused_moe_lora kernel

Signed-off-by: Yu Gong
---
 vllm/lora/ops/triton_ops/fused_moe_lora_op.py | 15 ++++++++++++---
 1 file changed, 12 insertions(+), 3 deletions(-)

diff --git a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py
index f04936221eea6..e1ea092fcd428 100644
--- a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py
+++ b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py
@@ -81,6 +81,7 @@ def _fused_moe_lora_kernel(
     # Meta-parameters
     num_slice_a: tl.constexpr,
     num_slice_c: tl.constexpr,
+    max_loras: tl.constexpr,
     top_k: tl.constexpr,
     MUL_ROUTED_WEIGHT: tl.constexpr,
     BLOCK_SIZE_M: tl.constexpr,
@@ -104,7 +105,7 @@ def _fused_moe_lora_kernel(
     if moe_enabled == 0:
         # Early exit for the no moe lora case.
         return
-    max_loras = tl.num_programs(axis=2)
+    # max_loras = tl.num_programs(axis=2)
     grid_k = tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)
 
     # calculate pid_m,pid_n
@@ -228,6 +229,7 @@ def _fused_moe_lora_shrink(
     num_warps: int,
     num_stages: int,
     split_k: int,
+    num_active_loras: int,
     mul_routed_weight: bool = False,
 ) -> None:
     w1_lora_a_stacked = lora_a_stacked[0]
@@ -251,7 +253,7 @@ def _fused_moe_lora_shrink(
         * triton.cdiv(EM, META["BLOCK_SIZE_M"])
         * triton.cdiv(N, META["BLOCK_SIZE_N"]),
         len(lora_a_stacked),
-        lora_a_stacked[0].shape[0],
+        num_active_loras,
     )
     _fused_moe_lora_kernel[grid](
         qcurr_hidden_states,
@@ -280,6 +282,7 @@ def _fused_moe_lora_shrink(
         expert_ids.stride(0),
         slice_a_size=qcurr_hidden_states.numel(),
         slice_c_size=a_intermediate_cache1.numel() // num_slices,
+        max_loras=lora_a_stacked[0].shape[0],
         num_slice_a=1,
         num_slice_c=num_slices,
         top_k=1 if mul_routed_weight else top_k_num,
@@ -322,6 +325,7 @@ def _fused_moe_lora_expand(
     num_warps: int,
     num_stages: int,
     split_k: int,
+    num_active_loras: int,
     mul_routed_weight: bool = False,
     offset: int = 0,
 ) -> None:
@@ -351,7 +355,7 @@ def _fused_moe_lora_expand(
     grid = lambda META: (
         triton.cdiv(EM, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
         len(lora_b_stacked),
-        lora_b_stacked[0].shape[0],
+        num_active_loras,
     )
     _fused_moe_lora_kernel[grid](
         a_intermediate_cache1,
@@ -382,6 +386,7 @@ def _fused_moe_lora_expand(
         slice_c_size=b_intermediate_cache1.numel() // num_slices,
         num_slice_a=num_slices,
         num_slice_c=num_slices,
+        max_loras=lora_b_stacked[0].shape[0],
         top_k=1,
         MUL_ROUTED_WEIGHT=mul_routed_weight,
         IS_PRIMARY=False,
@@ -492,6 +497,7 @@ def _fused_moe_lora(
         shrink_num_warps,
         shrink_num_stages,
         shrink_split_k,
+        num_active_loras,
         mul_routed_weight,
     )
 
@@ -538,6 +544,7 @@ def _fused_moe_lora(
         expand_num_warps,
         expand_num_stages,
         expand_split_k,
+        num_active_loras,
         mul_routed_weight,
         offset,
     )
@@ -601,6 +608,7 @@ def _fused_moe_lora_shrink_fake(
     num_warps: int,
     num_stages: int,
     split_k: int,
+    num_active_loras: int,
     mul_routed_weight: bool = False,
 ) -> None:
     return
@@ -634,6 +642,7 @@ def _fused_moe_lora_expand_fake(
     num_warps: int,
     num_stages: int,
     split_k: int,
+    num_active_loras: int,
     mul_routed_weight: bool = False,
 ) -> None:
     return
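
Note (not part of the patch): a minimal sketch of what the grid change does, assuming made-up values. Before the patch, the third launch-grid axis spanned lora_a_stacked[0].shape[0] (the full stacked LoRA capacity) and the kernel recovered max_loras from tl.num_programs(axis=2); after the patch, that axis spans only num_active_loras and max_loras is passed in as an explicit constexpr argument. The values for EM, N, BLOCK_SIZE_M, BLOCK_SIZE_N, num_slices, max_loras, and num_active_loras below are hypothetical and chosen only to show the grid shrinking.

def cdiv(a: int, b: int) -> int:
    # Ceiling division, mirroring triton.cdiv.
    return -(-a // b)

# Hypothetical sizes, for illustration only.
EM, N = 4096, 256                   # padded token/expert rows, LoRA dim
BLOCK_SIZE_M, BLOCK_SIZE_N = 64, 32
num_slices = 2                      # len(lora_a_stacked)
max_loras = 8                       # lora_a_stacked[0].shape[0], stacked capacity
num_active_loras = 2                # LoRAs actually referenced by this batch

mn_blocks = cdiv(EM, BLOCK_SIZE_M) * cdiv(N, BLOCK_SIZE_N)

# Before: one program per (m/n block, slice, stacked LoRA slot).
grid_before = (mn_blocks, num_slices, max_loras)
# After: the z-axis covers only the active LoRAs; max_loras is handed to the
# kernel explicitly instead of being read from tl.num_programs(axis=2).
grid_after = (mn_blocks, num_slices, num_active_loras)

print(grid_before)  # (512, 2, 8)
print(grid_after)   # (512, 2, 2)

The shrink and expand wrappers (and their _fake registrations) simply thread num_active_loras through to this grid computation.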