mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 05:45:00 +08:00
[Kernel] Adding split_K implementation for fused_moe_lora (#27291)
Signed-off-by: Danielle Robinson <dmmaddix@amazon.com> Signed-off-by: Danielle Robinson <dcmaddix@gmail.com> Co-authored-by: Danielle Robinson <dmmaddix@amazon.com> Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
parent
2d631d28c6
commit
9932ed6a83
@ -154,6 +154,7 @@ def use_fused_moe_lora_kernel(
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"SPLIT_K": 1,
|
||||
}
|
||||
|
||||
mul_routed_weight = False
|
||||
@ -175,6 +176,7 @@ def use_fused_moe_lora_kernel(
|
||||
config["BLOCK_SIZE_N"],
|
||||
config["BLOCK_SIZE_K"],
|
||||
config["GROUP_SIZE_M"],
|
||||
config["SPLIT_K"],
|
||||
mul_routed_weight,
|
||||
)
|
||||
|
||||
|
||||
@ -80,11 +80,13 @@ def _fused_moe_lora_kernel(
|
||||
BLOCK_SIZE_N: tl.constexpr,
|
||||
BLOCK_SIZE_K: tl.constexpr,
|
||||
GROUP_SIZE_M: tl.constexpr,
|
||||
SPLIT_K: tl.constexpr,
|
||||
):
|
||||
pid = tl.program_id(axis=0)
|
||||
slice_id = tl.program_id(axis=1)
|
||||
lora_idx = tl.program_id(axis=2)
|
||||
max_loras = tl.num_programs(axis=2)
|
||||
grid_k = tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)
|
||||
|
||||
# calculate pid_m,pid_n
|
||||
num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
|
||||
@ -102,7 +104,7 @@ def _fused_moe_lora_kernel(
|
||||
|
||||
# get the expert_id to process curr shard
|
||||
ind = lora_idx * stride_el + pid_m
|
||||
expert_id = tl.load(expert_ids_ptr + ind)
|
||||
expert_id = tl.load(expert_ids_ptr + ind, ind < max_loras * stride_el, -1)
|
||||
if expert_id == -1:
|
||||
return
|
||||
|
||||
@ -117,7 +119,7 @@ def _fused_moe_lora_kernel(
|
||||
offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
|
||||
token_ind = stride_tl * lora_idx + offs_token_id
|
||||
offs_token = tl.load(
|
||||
sorted_token_ids_ptr + token_ind, token_ind < max_loras * stride_tl, 0.0
|
||||
sorted_token_ids_ptr + token_ind, token_ind < max_loras * stride_tl, 0
|
||||
)
|
||||
token_mask = offs_token < num_valid_tokens
|
||||
|
||||
@ -135,17 +137,18 @@ def _fused_moe_lora_kernel(
|
||||
|
||||
# accumulator
|
||||
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
|
||||
for k in range(0, grid_k):
|
||||
k_remaining = K - k * (BLOCK_SIZE_K * SPLIT_K)
|
||||
a = tl.load(
|
||||
a_ptrs,
|
||||
mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
|
||||
mask=token_mask[:, None] & (offs_k[None, :] < k_remaining),
|
||||
other=0.0,
|
||||
)
|
||||
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
|
||||
b = tl.load(b_ptrs, mask=offs_k[:, None] < k_remaining, other=0.0)
|
||||
accumulator += tl.dot(a, b)
|
||||
# Advance the ptrs to the next K block.
|
||||
a_ptrs += BLOCK_SIZE_K * stride_ak
|
||||
b_ptrs += BLOCK_SIZE_K * stride_bk
|
||||
a_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_ak
|
||||
b_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_bk
|
||||
|
||||
if MUL_ROUTED_WEIGHT:
|
||||
moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
|
||||
@ -156,7 +159,10 @@ def _fused_moe_lora_kernel(
|
||||
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
c_ptrs = cur_c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
|
||||
c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
|
||||
tl.store(c_ptrs, accumulator, mask=c_mask)
|
||||
if SPLIT_K == 1:
|
||||
tl.store(c_ptrs, accumulator, mask=c_mask)
|
||||
else:
|
||||
tl.atomic_add(c_ptrs, accumulator, mask=c_mask, sem="relaxed")
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
@ -179,6 +185,7 @@ def _fused_moe_lora(
|
||||
block_size_n: int,
|
||||
block_size_k: int,
|
||||
group_size_m: int,
|
||||
split_k: int,
|
||||
mul_routed_weight: bool = False,
|
||||
) -> None:
|
||||
assert len(lora_a_stacked) == len(lora_b_stacked) > 0
|
||||
@ -206,6 +213,7 @@ def _fused_moe_lora(
|
||||
"BLOCK_SIZE_N": block_size_n,
|
||||
"BLOCK_SIZE_K": block_size_k,
|
||||
"GROUP_SIZE_M": group_size_m,
|
||||
"SPLIT_K": split_k,
|
||||
}
|
||||
|
||||
w1_lora_a_stacked = lora_a_stacked[0]
|
||||
@ -237,7 +245,9 @@ def _fused_moe_lora(
|
||||
b_ptr = _get_ptr(lora_a_stacked, device)
|
||||
|
||||
grid = lambda META: (
|
||||
triton.cdiv(EM, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
|
||||
split_k
|
||||
* triton.cdiv(EM, META["BLOCK_SIZE_M"])
|
||||
* triton.cdiv(N, META["BLOCK_SIZE_N"]),
|
||||
len(lora_a_stacked),
|
||||
lora_a_stacked[0].shape[0],
|
||||
)
|
||||
@ -286,6 +296,8 @@ def _fused_moe_lora(
|
||||
-1, a_intermediate_cache1.shape[3]
|
||||
)
|
||||
|
||||
# Set split_k = 1 for expand calls
|
||||
config["SPLIT_K"] = 1
|
||||
grid = lambda META: (
|
||||
triton.cdiv(EM, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
|
||||
len(lora_b_stacked),
|
||||
|
||||
@ -385,5 +385,6 @@ class PunicaWrapperGPU(PunicaWrapperBase):
|
||||
config["BLOCK_SIZE_N"],
|
||||
config["BLOCK_SIZE_K"],
|
||||
config["GROUP_SIZE_M"],
|
||||
config.get("SPLIT_K", 1),
|
||||
mul_routed_weight,
|
||||
)
|
||||
|
||||
@ -121,6 +121,7 @@ def fused_moe_kernel_gptq_awq(
|
||||
BLOCK_SIZE_N: tl.constexpr,
|
||||
BLOCK_SIZE_K: tl.constexpr,
|
||||
GROUP_SIZE_M: tl.constexpr,
|
||||
SPLIT_K: tl.constexpr,
|
||||
MUL_ROUTED_WEIGHT: tl.constexpr,
|
||||
top_k: tl.constexpr,
|
||||
compute_type: tl.constexpr,
|
||||
@ -356,6 +357,7 @@ def fused_moe_kernel(
|
||||
BLOCK_SIZE_N: tl.constexpr,
|
||||
BLOCK_SIZE_K: tl.constexpr,
|
||||
GROUP_SIZE_M: tl.constexpr,
|
||||
SPLIT_K: tl.constexpr,
|
||||
MUL_ROUTED_WEIGHT: tl.constexpr,
|
||||
top_k: tl.constexpr,
|
||||
compute_type: tl.constexpr,
|
||||
@ -646,7 +648,6 @@ def invoke_fused_moe_kernel(
|
||||
bit,
|
||||
)
|
||||
return
|
||||
|
||||
fused_moe_kernel_gptq_awq[grid](
|
||||
A,
|
||||
B,
|
||||
@ -686,6 +687,7 @@ def invoke_fused_moe_kernel(
|
||||
)
|
||||
else:
|
||||
config = config.copy()
|
||||
config["SPLIT_K"] = 1
|
||||
BLOCK_SIZE_K = config.pop("BLOCK_SIZE_K")
|
||||
if block_shape is not None:
|
||||
BLOCK_SIZE_K = min(BLOCK_SIZE_K, min(block_shape[0], block_shape[1]))
|
||||
@ -983,6 +985,7 @@ def get_default_config(
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"SPLIT_K": 1,
|
||||
}
|
||||
return config
|
||||
|
||||
@ -996,6 +999,7 @@ def get_default_config(
|
||||
"BLOCK_SIZE_N": block_shape[0],
|
||||
"BLOCK_SIZE_K": block_shape[1],
|
||||
"GROUP_SIZE_M": 32,
|
||||
"SPLIT_K": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3 if not current_platform.is_rocm() else 2,
|
||||
}
|
||||
@ -1006,19 +1010,20 @@ def get_default_config(
|
||||
bit = 4 if dtype == "int4_w4a16" else 8
|
||||
use_moe_wna16_cuda = should_moe_wna16_use_cuda(M * topk, block_shape[1], E, bit)
|
||||
if use_moe_wna16_cuda:
|
||||
config = {"BLOCK_SIZE_M": min(16, M)}
|
||||
config = {"BLOCK_SIZE_M": min(16, M), "SPLIT_K": 1}
|
||||
elif M <= 20:
|
||||
config = {"BLOCK_SIZE_M": 16, "GROUP_SIZE_M": 1}
|
||||
config = {"BLOCK_SIZE_M": 16, "GROUP_SIZE_M": 1, "SPLIT_K": 1}
|
||||
elif M <= 40:
|
||||
config = {"BLOCK_SIZE_M": 32, "GROUP_SIZE_M": 1}
|
||||
config = {"BLOCK_SIZE_M": 32, "GROUP_SIZE_M": 1, "SPLIT_K": 1}
|
||||
else:
|
||||
config = {"BLOCK_SIZE_M": 64, "GROUP_SIZE_M": 1}
|
||||
config = {"BLOCK_SIZE_M": 64, "GROUP_SIZE_M": 1, "SPLIT_K": 1}
|
||||
elif M <= E:
|
||||
config = {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"SPLIT_K": 1,
|
||||
}
|
||||
else:
|
||||
config = {
|
||||
@ -1026,6 +1031,7 @@ def get_default_config(
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"SPLIT_K": 1,
|
||||
}
|
||||
return config
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user