From 29de3cdee4dd7f805931b459398b15c3b5f7057c Mon Sep 17 00:00:00 2001 From: yugong333 Date: Fri, 31 Oct 2025 21:55:46 -0700 Subject: [PATCH] Adding SplitK in fused_moe_lora kernel (#27818) Signed-off-by: Yu Gong Co-authored-by: Jee Jee Li --- vllm/lora/ops/triton_ops/fused_moe_lora_op.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py index e681f3882908..15031f5e2f9e 100644 --- a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py +++ b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py @@ -88,14 +88,17 @@ def _fused_moe_lora_kernel( grid_k = tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K) # calculate pid_m,pid_n + pid_sk = pid % SPLIT_K + pid_m_n = pid // SPLIT_K num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = pid // num_pid_in_group + group_id = pid_m_n // num_pid_in_group first_pid_m = group_id * GROUP_SIZE_M group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) - pid_n = (pid % num_pid_in_group) // group_size_m + pid_m = first_pid_m + ((pid_m_n % num_pid_in_group) % group_size_m) + pid_n = (pid_m_n % num_pid_in_group) // group_size_m num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr + lora_idx) if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded: @@ -113,7 +116,7 @@ def _fused_moe_lora_kernel( cur_c_ptr = c_ptr + (slice_id % num_slice_c) * slice_c_size offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N - offs_k = tl.arange(0, BLOCK_SIZE_K) + offs_k = pid_sk * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64) token_ind = stride_tl * lora_idx + offs_token_id @@ -131,7 +134,8 @@ def _fused_moe_lora_kernel( cur_b_ptr + lora_idx * stride_bl + expert_id * stride_be - + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + offs_k[:, None] * stride_bk + + offs_bn[None, :] * stride_bn ) # accumulator