mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-15 04:15:01 +08:00
Adding SplitK in fused_moe_lora kernel (#27818)
Signed-off-by: Yu Gong <yu3.gong@gmail.com> Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
parent
7e2729b57e
commit
29de3cdee4
@ -88,14 +88,17 @@ def _fused_moe_lora_kernel(
|
||||
grid_k = tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)
|
||||
|
||||
# calculate pid_m,pid_n
|
||||
pid_sk = pid % SPLIT_K
|
||||
pid_m_n = pid // SPLIT_K
|
||||
num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
|
||||
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
|
||||
|
||||
num_pid_in_group = GROUP_SIZE_M * num_pid_n
|
||||
group_id = pid // num_pid_in_group
|
||||
group_id = pid_m_n // num_pid_in_group
|
||||
first_pid_m = group_id * GROUP_SIZE_M
|
||||
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
|
||||
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
|
||||
pid_n = (pid % num_pid_in_group) // group_size_m
|
||||
pid_m = first_pid_m + ((pid_m_n % num_pid_in_group) % group_size_m)
|
||||
pid_n = (pid_m_n % num_pid_in_group) // group_size_m
|
||||
|
||||
num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr + lora_idx)
|
||||
if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
|
||||
@ -113,7 +116,7 @@ def _fused_moe_lora_kernel(
|
||||
cur_c_ptr = c_ptr + (slice_id % num_slice_c) * slice_c_size
|
||||
|
||||
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
|
||||
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
||||
offs_k = pid_sk * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
|
||||
|
||||
offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
|
||||
token_ind = stride_tl * lora_idx + offs_token_id
|
||||
@ -131,7 +134,8 @@ def _fused_moe_lora_kernel(
|
||||
cur_b_ptr
|
||||
+ lora_idx * stride_bl
|
||||
+ expert_id * stride_be
|
||||
+ (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
|
||||
+ offs_k[:, None] * stride_bk
|
||||
+ offs_bn[None, :] * stride_bn
|
||||
)
|
||||
|
||||
# accumulator
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user