Mirror of https://git.datalinker.icu/vllm-project/vllm.git
[Kernel] Adding split_K implementation for fused_moe_lora (#27291)
Signed-off-by: Danielle Robinson <dmmaddix@amazon.com>
Signed-off-by: Danielle Robinson <dcmaddix@gmail.com>
Co-authored-by: Danielle Robinson <dmmaddix@amazon.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
commit 9932ed6a83
parent 2d631d28c6
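For context: split-K decomposes a GEMM over its reduction dimension, so that SPLIT_K programs each compute a partial product over a disjoint slice of K and the partial results are summed afterwards (in the kernel below, via tl.atomic_add). A minimal PyTorch sketch of the identity this relies on, with made-up shapes:

import torch

# Split-K as a math identity: a GEMM over K equals the sum of SPLIT_K partial
# GEMMs over disjoint (here interleaved) slices of K. The kernel below does the
# same with BLOCK_SIZE_K-sized blocks and combines the partials with atomics.
torch.manual_seed(0)
M, K, N, SPLIT_K = 8, 256, 16, 4
A, B = torch.randn(M, K), torch.randn(K, N)

partials = [A[:, s::SPLIT_K] @ B[s::SPLIT_K, :] for s in range(SPLIT_K)]
assert torch.allclose(sum(partials), A @ B, atol=1e-4)

The launch grid therefore gains a factor of SPLIT_K on its first axis, as shown in the grid change further down.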
@@ -154,6 +154,7 @@ def use_fused_moe_lora_kernel(
         "BLOCK_SIZE_N": 32,
         "BLOCK_SIZE_K": 64,
         "GROUP_SIZE_M": 1,
+        "SPLIT_K": 1,
     }

     mul_routed_weight = False
@@ -175,6 +176,7 @@ def use_fused_moe_lora_kernel(
         config["BLOCK_SIZE_N"],
         config["BLOCK_SIZE_K"],
         config["GROUP_SIZE_M"],
+        config["SPLIT_K"],
         mul_routed_weight,
     )

@@ -80,11 +80,13 @@ def _fused_moe_lora_kernel(
     BLOCK_SIZE_N: tl.constexpr,
     BLOCK_SIZE_K: tl.constexpr,
     GROUP_SIZE_M: tl.constexpr,
+    SPLIT_K: tl.constexpr,
 ):
     pid = tl.program_id(axis=0)
     slice_id = tl.program_id(axis=1)
     lora_idx = tl.program_id(axis=2)
     max_loras = tl.num_programs(axis=2)
+    grid_k = tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)

     # calculate pid_m,pid_n
     num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
@@ -102,7 +104,7 @@ def _fused_moe_lora_kernel(

     # get the expert_id to process curr shard
     ind = lora_idx * stride_el + pid_m
-    expert_id = tl.load(expert_ids_ptr + ind)
+    expert_id = tl.load(expert_ids_ptr + ind, ind < max_loras * stride_el, -1)
     if expert_id == -1:
         return

@@ -117,7 +119,7 @@ def _fused_moe_lora_kernel(
     offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
     token_ind = stride_tl * lora_idx + offs_token_id
     offs_token = tl.load(
-        sorted_token_ids_ptr + token_ind, token_ind < max_loras * stride_tl, 0.0
+        sorted_token_ids_ptr + token_ind, token_ind < max_loras * stride_tl, 0
     )
     token_mask = offs_token < num_valid_tokens

@@ -135,17 +137,18 @@ def _fused_moe_lora_kernel(

     # accumulator
     accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
-    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
+    for k in range(0, grid_k):
+        k_remaining = K - k * (BLOCK_SIZE_K * SPLIT_K)
         a = tl.load(
             a_ptrs,
-            mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
+            mask=token_mask[:, None] & (offs_k[None, :] < k_remaining),
             other=0.0,
         )
-        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
+        b = tl.load(b_ptrs, mask=offs_k[:, None] < k_remaining, other=0.0)
         accumulator += tl.dot(a, b)
         # Advance the ptrs to the next K block.
-        a_ptrs += BLOCK_SIZE_K * stride_ak
-        b_ptrs += BLOCK_SIZE_K * stride_bk
+        a_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_ak
+        b_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_bk

     if MUL_ROUTED_WEIGHT:
         moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
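The loop above implements a strided split-K schedule: grid_k = cdiv(K, BLOCK_SIZE_K * SPLIT_K) iterations per program, pointers advanced by BLOCK_SIZE_K * SPLIT_K each step, and the K tail masked through k_remaining. A NumPy sketch of that schedule (not the kernel itself; the pid_sk * BLOCK_SIZE_K starting offset is an assumption, since the diff does not show how offs_k is initialized):

import numpy as np

# Illustrative shapes only.
K, BLOCK_SIZE_K, SPLIT_K = 512, 64, 4
M, N = 16, 32
rng = np.random.default_rng(0)
A = rng.standard_normal((M, K), dtype=np.float32)
B = rng.standard_normal((K, N), dtype=np.float32)

out = np.zeros((M, N), dtype=np.float32)
grid_k = -(-K // (BLOCK_SIZE_K * SPLIT_K))          # tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)
for pid_sk in range(SPLIT_K):                       # one split per program
    acc = np.zeros((M, N), dtype=np.float32)
    for k in range(grid_k):
        # Assumed start offset: split pid_sk begins at block pid_sk, then both
        # pointers stride by BLOCK_SIZE_K * SPLIT_K, as in the kernel loop.
        start = k * BLOCK_SIZE_K * SPLIT_K + pid_sk * BLOCK_SIZE_K
        stop = min(start + BLOCK_SIZE_K, K)         # the `offs_k < k_remaining` mask
        if start < K:
            acc += A[:, start:stop] @ B[start:stop, :]
    out += acc                                      # stands in for tl.atomic_add

assert np.allclose(out, A @ B, atol=1e-3)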
@@ -156,7 +159,10 @@ def _fused_moe_lora_kernel(
     offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
     c_ptrs = cur_c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
     c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
-    tl.store(c_ptrs, accumulator, mask=c_mask)
+    if SPLIT_K == 1:
+        tl.store(c_ptrs, accumulator, mask=c_mask)
+    else:
+        tl.atomic_add(c_ptrs, accumulator, mask=c_mask, sem="relaxed")


 @torch.inference_mode()
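With SPLIT_K > 1, several programs now hold partial tiles for the same output block, so the epilogue switches from a plain store to tl.atomic_add; that only gives the right answer if the output buffer starts zeroed (how the caller guarantees that is not visible in this hunk). A toy serial stand-in:

import torch

# Toy stand-in for the SPLIT_K > 1 epilogue: each split holds a partial tile of
# the same output block, so results must be accumulated, not overwritten.
M, N, SPLIT_K = 4, 8, 3
partials = [torch.randn(M, N) for _ in range(SPLIT_K)]

out = torch.zeros(M, N)      # must start at zero for the accumulation to be correct
for p in partials:           # tl.atomic_add performs this sum across programs;
    out += p                 # done serially here, so no atomics are needed

assert torch.allclose(out, torch.stack(partials).sum(dim=0))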
@@ -179,6 +185,7 @@ def _fused_moe_lora(
     block_size_n: int,
     block_size_k: int,
     group_size_m: int,
+    split_k: int,
     mul_routed_weight: bool = False,
 ) -> None:
     assert len(lora_a_stacked) == len(lora_b_stacked) > 0
@@ -206,6 +213,7 @@ def _fused_moe_lora(
         "BLOCK_SIZE_N": block_size_n,
         "BLOCK_SIZE_K": block_size_k,
         "GROUP_SIZE_M": group_size_m,
+        "SPLIT_K": split_k,
     }

     w1_lora_a_stacked = lora_a_stacked[0]
@@ -237,7 +245,9 @@ def _fused_moe_lora(
     b_ptr = _get_ptr(lora_a_stacked, device)

     grid = lambda META: (
-        triton.cdiv(EM, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
+        split_k
+        * triton.cdiv(EM, META["BLOCK_SIZE_M"])
+        * triton.cdiv(N, META["BLOCK_SIZE_N"]),
         len(lora_a_stacked),
         lora_a_stacked[0].shape[0],
     )
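A quick sketch of how the launch grid changes (numbers are illustrative, not from the PR): the first axis is multiplied by split_k, while the slice and LoRA axes stay the same.

import math

# Illustrative numbers only; real values come from the MoE/LoRA shapes at runtime.
EM, N = 4096, 1024
BLOCK_SIZE_M, BLOCK_SIZE_N = 64, 32
split_k = 4
num_slices, max_loras = 2, 8      # len(lora_a_stacked), lora_a_stacked[0].shape[0]

tiles_m = math.ceil(EM / BLOCK_SIZE_M)   # triton.cdiv(EM, BLOCK_SIZE_M)
tiles_n = math.ceil(N / BLOCK_SIZE_N)    # triton.cdiv(N, BLOCK_SIZE_N)

grid = (split_k * tiles_m * tiles_n, num_slices, max_loras)
print(grid)                              # (8192, 2, 8) with the numbers above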
@@ -286,6 +296,8 @@ def _fused_moe_lora(
         -1, a_intermediate_cache1.shape[3]
     )

+    # Set split_k = 1 for expand calls
+    config["SPLIT_K"] = 1
     grid = lambda META: (
         triton.cdiv(EM, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
         len(lora_b_stacked),
@@ -385,5 +385,6 @@ class PunicaWrapperGPU(PunicaWrapperBase):
             config["BLOCK_SIZE_N"],
             config["BLOCK_SIZE_K"],
             config["GROUP_SIZE_M"],
+            config.get("SPLIT_K", 1),
             mul_routed_weight,
         )
@@ -121,6 +121,7 @@ def fused_moe_kernel_gptq_awq(
     BLOCK_SIZE_N: tl.constexpr,
     BLOCK_SIZE_K: tl.constexpr,
     GROUP_SIZE_M: tl.constexpr,
+    SPLIT_K: tl.constexpr,
     MUL_ROUTED_WEIGHT: tl.constexpr,
     top_k: tl.constexpr,
     compute_type: tl.constexpr,
@@ -356,6 +357,7 @@ def fused_moe_kernel(
     BLOCK_SIZE_N: tl.constexpr,
     BLOCK_SIZE_K: tl.constexpr,
     GROUP_SIZE_M: tl.constexpr,
+    SPLIT_K: tl.constexpr,
     MUL_ROUTED_WEIGHT: tl.constexpr,
     top_k: tl.constexpr,
     compute_type: tl.constexpr,
@@ -646,7 +648,6 @@ def invoke_fused_moe_kernel(
                 bit,
             )
             return
-
         fused_moe_kernel_gptq_awq[grid](
             A,
             B,
@@ -686,6 +687,7 @@ def invoke_fused_moe_kernel(
         )
     else:
         config = config.copy()
+        config["SPLIT_K"] = 1
         BLOCK_SIZE_K = config.pop("BLOCK_SIZE_K")
         if block_shape is not None:
             BLOCK_SIZE_K = min(BLOCK_SIZE_K, min(block_shape[0], block_shape[1]))
@@ -983,6 +985,7 @@ def get_default_config(
             "BLOCK_SIZE_N": 64,
             "BLOCK_SIZE_K": 32,
             "GROUP_SIZE_M": 8,
+            "SPLIT_K": 1,
         }
         return config

@@ -996,6 +999,7 @@ def get_default_config(
             "BLOCK_SIZE_N": block_shape[0],
             "BLOCK_SIZE_K": block_shape[1],
             "GROUP_SIZE_M": 32,
+            "SPLIT_K": 1,
             "num_warps": 4,
             "num_stages": 3 if not current_platform.is_rocm() else 2,
         }
@@ -1006,19 +1010,20 @@ def get_default_config(
         bit = 4 if dtype == "int4_w4a16" else 8
         use_moe_wna16_cuda = should_moe_wna16_use_cuda(M * topk, block_shape[1], E, bit)
         if use_moe_wna16_cuda:
-            config = {"BLOCK_SIZE_M": min(16, M)}
+            config = {"BLOCK_SIZE_M": min(16, M), "SPLIT_K": 1}
         elif M <= 20:
-            config = {"BLOCK_SIZE_M": 16, "GROUP_SIZE_M": 1}
+            config = {"BLOCK_SIZE_M": 16, "GROUP_SIZE_M": 1, "SPLIT_K": 1}
         elif M <= 40:
-            config = {"BLOCK_SIZE_M": 32, "GROUP_SIZE_M": 1}
+            config = {"BLOCK_SIZE_M": 32, "GROUP_SIZE_M": 1, "SPLIT_K": 1}
         else:
-            config = {"BLOCK_SIZE_M": 64, "GROUP_SIZE_M": 1}
+            config = {"BLOCK_SIZE_M": 64, "GROUP_SIZE_M": 1, "SPLIT_K": 1}
     elif M <= E:
         config = {
             "BLOCK_SIZE_M": 16,
             "BLOCK_SIZE_N": 32,
             "BLOCK_SIZE_K": 64,
             "GROUP_SIZE_M": 1,
+            "SPLIT_K": 1,
         }
     else:
         config = {
@@ -1026,6 +1031,7 @@ def get_default_config(
             "BLOCK_SIZE_N": 64,
             "BLOCK_SIZE_K": 32,
             "GROUP_SIZE_M": 8,
+            "SPLIT_K": 1,
         }
     return config

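All of the defaults set SPLIT_K to 1, so the existing fused_moe paths are unchanged; splitting only pays off when there are too few M x N output tiles to occupy the GPU and K is long. A hedged heuristic sketch of that trade-off (not part of this PR; pick_split_k and num_sms are made-up names):

# Hedged heuristic, not part of this PR: split-K tends to pay off when the
# number of M x N output tiles is too small to fill the GPU and K is long.
def pick_split_k(num_tiles_mn: int, k: int, num_sms: int = 108,
                 block_size_k: int = 64) -> int:
    """Illustrative only: grow split_k until the grid roughly covers the SMs,
    without splitting finer than one K-block per split."""
    split_k = 1
    while (num_tiles_mn * split_k < num_sms
           and k // (block_size_k * split_k * 2) >= 1):
        split_k *= 2
    return split_k

print(pick_split_k(num_tiles_mn=16, k=4096))   # -> 8 with the defaults above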