From 5a3adf581e372c60d6135a535561f4d491c4d046 Mon Sep 17 00:00:00 2001
From: gnovack
Date: Wed, 17 Dec 2025 19:55:00 -0800
Subject: [PATCH] fused_moe_lora PDL improvements (#30716)

Signed-off-by: gnovack
Co-authored-by: Cyrus Leung
Co-authored-by: Jee Jee Li
---
 vllm/lora/ops/triton_ops/fused_moe_lora_op.py | 30 +++++++++++--------
 1 file changed, 18 insertions(+), 12 deletions(-)

diff --git a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py
index 34383cdf1767c..f04936221eea6 100644
--- a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py
+++ b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py
@@ -156,16 +156,22 @@ def _fused_moe_lora_kernel(
         + offs_bn[None, :] * stride_bn
     )
 
+    if USE_GDC and IS_PRIMARY:
+        # GDC launch dependents hints the runtime system to launch dependent kernels.
+        tl.extra.cuda.gdc_launch_dependents()
+
     # accumulator
     accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+
+    # GDC wait waits for ALL programs in the prior kernel to complete
+    # before continuing.
+    if USE_GDC and not IS_PRIMARY:
+        tl.extra.cuda.gdc_wait()
+
     for k in range(0, grid_k):
         k_remaining = K - k * (BLOCK_SIZE_K * SPLIT_K)
         # pre-fetch lora weight
         b = tl.load(b_ptrs, mask=offs_k[:, None] < k_remaining, other=0.0)
-        # GDC wait waits for ALL programs in the prior kernel to complete
-        # before continuing.
-        if USE_GDC and not IS_PRIMARY:
-            tl.extra.cuda.gdc_wait()
         a = tl.load(
             a_ptrs,
             mask=token_mask[:, None] & (offs_k[None, :] < k_remaining),
@@ -179,9 +185,6 @@ def _fused_moe_lora_kernel(
     if MUL_ROUTED_WEIGHT:
         moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
         accumulator = accumulator * moe_weight[:, None]
-    if USE_GDC and IS_PRIMARY:
-        # GDC launch dependents hints the runtime system to launch dependent kernels.
-        tl.extra.cuda.gdc_launch_dependents()
     accumulator = accumulator.to(c_ptr.dtype.element_ty)
     # Write back the block of the output
     offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
@@ -290,6 +293,7 @@ def _fused_moe_lora_shrink(
 def _fused_moe_lora_expand(
     output: torch.Tensor,  # (num_tokens, top_k_num, N*len(lora_a_stacked),)
     a_intermediate_cache1: torch.Tensor,  # (num_slices, M, top_k_num, max_lora_rank)
+    b_intermediate_cache1: torch.Tensor,  # (num_slices, M, top_k_num, output_dim_size)
     lora_b_stacked: list[
         torch.Tensor
     ],  # [(max_loras, num_experts, max_lora_rank, K,),...]
@@ -331,11 +335,6 @@ def _fused_moe_lora_expand(
         -1, a_intermediate_cache1.shape[3]
     )
 
-    b_intermediate_cache1 = torch.zeros(
-        (num_slices, M, top_k_num, w1_output_dim_size),
-        dtype=output.dtype,
-        device=device,
-    )
     use_gdc = supports_pdl(a_intermediate_cache1.device)
     expand_config = {
         "BLOCK_SIZE_M": block_size_m,
@@ -460,6 +459,12 @@ def _fused_moe_lora(
         device=device,
     )
 
+    b_intermediate_cache1 = torch.zeros(
+        (num_slices, M, top_k_num, w1_output_dim_size),
+        dtype=output.dtype,
+        device=device,
+    )
+
     _fused_moe_lora_shrink(
         a_intermediate_cache1,
         qcurr_hidden_states,
@@ -506,6 +511,7 @@ def _fused_moe_lora(
     _fused_moe_lora_expand(
         output,
         a_intermediate_cache1,
+        b_intermediate_cache1,
         lora_b_stacked,
         topk_weights,
         sorted_token_ids,
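
Note (not part of the patch, added for context): the change tightens the programmatic
dependent launch (PDL) pattern in _fused_moe_lora_kernel. The IS_PRIMARY path now issues
tl.extra.cuda.gdc_launch_dependents() before the main K loop instead of after it, and the
non-primary path issues tl.extra.cuda.gdc_wait() once before the loop instead of on every
iteration. b_intermediate_cache1 is also allocated up front in _fused_moe_lora and passed
into _fused_moe_lora_expand, presumably so that no allocation sits between the two
dependent kernel launches.

The standalone sketch below only illustrates the producer/consumer GDC pattern; it is not
the vLLM kernel. The kernel names and shapes are made up, and it assumes a Triton, driver,
and GPU combination where the tl.extra.cuda GDC intrinsics are available and the two
kernels are launched back-to-back on the same stream with PDL enabled at launch time.

import triton
import triton.language as tl


@triton.jit
def _producer_kernel(x_ptr, y_ptr, n, BLOCK: tl.constexpr, USE_GDC: tl.constexpr):
    # Stands in for the IS_PRIMARY path: produces y = 2 * x.
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask, other=0.0)
    if USE_GDC:
        # Issue the launch hint early (before the remaining math), mirroring the
        # patch moving gdc_launch_dependents() ahead of the K loop. Correctness
        # is still enforced by the consumer's gdc_wait().
        tl.extra.cuda.gdc_launch_dependents()
    tl.store(y_ptr + offs, x * 2.0, mask=mask)


@triton.jit
def _consumer_kernel(y_ptr, z_ptr, n, BLOCK: tl.constexpr, USE_GDC: tl.constexpr):
    # Stands in for the non-primary path: reads y produced by the prior kernel.
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    acc = tl.zeros((BLOCK,), dtype=tl.float32)  # setup that does not touch y
    if USE_GDC:
        # Wait exactly once, right before the first load of producer output,
        # for ALL programs of the prior kernel to complete (not inside a loop).
        tl.extra.cuda.gdc_wait()
    y = tl.load(y_ptr + offs, mask=mask, other=0.0)
    out = acc + y + 1.0
    tl.store(z_ptr + offs, out.to(z_ptr.dtype.element_ty), mask=mask)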