fused_moe_lora PDL improvements (#30716)
Signed-off-by: gnovack <gnovack@amazon.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
commit 5a3adf581e
parent 6fe5887652
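For readers unfamiliar with PDL: programmatic dependent launch lets the next kernel on a stream begin launching before the current one finishes, with ordering enforced inside the kernels through grid dependency control (GDC). The sketch below is not the vLLM kernel; it is a minimal producer/consumer illustration of the tl.extra.cuda.gdc_launch_dependents() / tl.extra.cuda.gdc_wait() pairing this commit rearranges. The kernel names and the launch_pdl=True launch option are assumptions for illustration, not part of the change.

import torch
import triton
import triton.language as tl


@triton.jit
def _producer_kernel(x_ptr, y_ptr, n, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    y = tl.load(x_ptr + offs, mask=mask) * 2.0
    tl.store(y_ptr + offs, y, mask=mask)
    # Output is written: hint the runtime to start launching dependent kernels.
    tl.extra.cuda.gdc_launch_dependents()


@triton.jit
def _consumer_kernel(y_ptr, z_ptr, n, BLOCK: tl.constexpr):
    # Wait for ALL programs of the prior kernel before reading its output.
    tl.extra.cuda.gdc_wait()
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    tl.store(z_ptr + offs, tl.load(y_ptr + offs, mask=mask) + 1.0, mask=mask)


def run(x: torch.Tensor) -> torch.Tensor:
    n = x.numel()
    y = torch.empty_like(x)
    z = torch.empty_like(x)
    grid = (triton.cdiv(n, 128),)
    # launch_pdl is an assumption: on recent Triton with a Hopper-class GPU it
    # enables programmatic dependent launch for back-to-back kernels on a stream.
    _producer_kernel[grid](x, y, n, BLOCK=128, launch_pdl=True)
    _consumer_kernel[grid](y, z, n, BLOCK=128, launch_pdl=True)
    return z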
@@ -156,16 +156,22 @@ def _fused_moe_lora_kernel(
        + offs_bn[None, :] * stride_bn
    )

    if USE_GDC and IS_PRIMARY:
        # GDC launch dependents hints the runtime system to launch dependent kernels.
        tl.extra.cuda.gdc_launch_dependents()

    # accumulator
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    # GDC wait waits for ALL programs in the prior kernel to complete
    # before continuing.
    if USE_GDC and not IS_PRIMARY:
        tl.extra.cuda.gdc_wait()

    for k in range(0, grid_k):
        k_remaining = K - k * (BLOCK_SIZE_K * SPLIT_K)
        # pre-fetch lora weight
        b = tl.load(b_ptrs, mask=offs_k[:, None] < k_remaining, other=0.0)
        # GDC wait waits for ALL programs in the prior kernel to complete
        # before continuing.
        if USE_GDC and not IS_PRIMARY:
            tl.extra.cuda.gdc_wait()
        a = tl.load(
            a_ptrs,
            mask=token_mask[:, None] & (offs_k[None, :] < k_remaining),
@@ -179,9 +185,6 @@ def _fused_moe_lora_kernel(
    if MUL_ROUTED_WEIGHT:
        moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
        accumulator = accumulator * moe_weight[:, None]
    if USE_GDC and IS_PRIMARY:
        # GDC launch dependents hints the runtime system to launch dependent kernels.
        tl.extra.cuda.gdc_launch_dependents()
    accumulator = accumulator.to(c_ptr.dtype.element_ty)
    # Write back the block of the output
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
@@ -290,6 +293,7 @@ def _fused_moe_lora_shrink(
def _fused_moe_lora_expand(
    output: torch.Tensor,  # (num_tokens, top_k_num, N*len(lora_a_stacked),)
    a_intermediate_cache1: torch.Tensor,  # (num_slices, M, top_k_num, max_lora_rank)
    b_intermediate_cache1: torch.Tensor,  # (num_slices, M, top_k_num, output_dim_size)
    lora_b_stacked: list[
        torch.Tensor
    ],  # [(max_loras, num_experts, max_lora_rank, K,),...]
@@ -331,11 +335,6 @@ def _fused_moe_lora_expand(
        -1, a_intermediate_cache1.shape[3]
    )

    b_intermediate_cache1 = torch.zeros(
        (num_slices, M, top_k_num, w1_output_dim_size),
        dtype=output.dtype,
        device=device,
    )
    use_gdc = supports_pdl(a_intermediate_cache1.device)
    expand_config = {
        "BLOCK_SIZE_M": block_size_m,
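The expand path gates GDC on supports_pdl(...) (see the hunk above). As a rough sketch only (the actual vLLM helper may check more than this), a PDL capability test typically reduces to requiring an NVIDIA GPU of compute capability 9.0 (Hopper) or newer:

import torch


def supports_pdl_sketch(device: torch.device) -> bool:
    # Hypothetical stand-in for vLLM's supports_pdl(): programmatic dependent
    # launch needs a CUDA device of compute capability >= 9.0 (Hopper).
    if device.type != "cuda":
        return False
    major, _ = torch.cuda.get_device_capability(device)
    return major >= 9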
@@ -460,6 +459,12 @@ def _fused_moe_lora(
        device=device,
    )

    b_intermediate_cache1 = torch.zeros(
        (num_slices, M, top_k_num, w1_output_dim_size),
        dtype=output.dtype,
        device=device,
    )

    _fused_moe_lora_shrink(
        a_intermediate_cache1,
        qcurr_hidden_states,
@@ -506,6 +511,7 @@ def _fused_moe_lora(
    _fused_moe_lora_expand(
        output,
        a_intermediate_cache1,
        b_intermediate_cache1,
        lora_b_stacked,
        topk_weights,
        sorted_token_ids,
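The last two hunks hoist the b_intermediate_cache1 allocation out of _fused_moe_lora_expand and into _fused_moe_lora, passing it down as an argument. A plausible reading is that this keeps the shrink and expand kernel launches back-to-back, with no allocator work between them, which is the launch pattern PDL overlap relies on. The outline below is hypothetical (placeholder shapes and stand-in launch functions) and shows only that structure:

import torch


def _launch_shrink(a_cache: torch.Tensor, hidden: torch.Tensor) -> None:
    # Stand-in for the _fused_moe_lora_shrink kernel launch.
    a_cache.copy_(torch.ones_like(a_cache))


def _launch_expand(out: torch.Tensor, a_cache: torch.Tensor) -> None:
    # Stand-in for the _fused_moe_lora_expand kernel launch.
    out.add_(a_cache.sum())


def fused_moe_lora_outline(hidden, num_slices, top_k_num, max_lora_rank, w1_output_dim_size):
    # Allocate BOTH intermediate caches up front, then issue shrink and expand
    # with nothing in between, mirroring the structure after this commit.
    M, device, dtype = hidden.shape[0], hidden.device, hidden.dtype
    a_cache = torch.zeros((num_slices, M, top_k_num, max_lora_rank), dtype=dtype, device=device)
    b_cache = torch.zeros((num_slices, M, top_k_num, w1_output_dim_size), dtype=dtype, device=device)
    _launch_shrink(a_cache, hidden)
    _launch_expand(b_cache, a_cache)
    return b_cache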