[Kernel] Adding split_K implementation for fused_moe_lora (#27291)

Signed-off-by: Danielle Robinson <dmmaddix@amazon.com>
Signed-off-by: Danielle Robinson <dcmaddix@gmail.com>
Co-authored-by: Danielle Robinson <dmmaddix@amazon.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
Danielle Robinson authored 2025-10-27 02:05:24 -07:00, committed by GitHub
parent 2d631d28c6
commit 9932ed6a83
4 changed files with 35 additions and 14 deletions

View File

@@ -154,6 +154,7 @@ def use_fused_moe_lora_kernel(
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"SPLIT_K": 1,
}
mul_routed_weight = False
@@ -175,6 +176,7 @@ def use_fused_moe_lora_kernel(
config["BLOCK_SIZE_N"],
config["BLOCK_SIZE_K"],
config["GROUP_SIZE_M"],
config["SPLIT_K"],
mul_routed_weight,
)
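
Aside for readers following the plumbing: the kernel settings travel as a plain dict, so split-K support amounts to carrying one extra key from the config through to the kernel launch. The sketch below is illustrative only (the helper and the standalone script are not vLLM code); it just shows the shape of the config and a backward-compatible way to read the new key, mirroring the `config.get("SPLIT_K", 1)` fallback used later in this commit.

```python
# Illustrative sketch, not vLLM code: unpack a fused-MoE-LoRA kernel config,
# defaulting SPLIT_K to 1 so tuned configs that predate this commit still work.
def unpack_moe_lora_config(config: dict) -> tuple[int, int, int, int, int]:
    return (
        config["BLOCK_SIZE_M"],
        config["BLOCK_SIZE_N"],
        config["BLOCK_SIZE_K"],
        config["GROUP_SIZE_M"],
        config.get("SPLIT_K", 1),  # new key; 1 means "no split-K"
    )

config = {
    "BLOCK_SIZE_M": 16,
    "BLOCK_SIZE_N": 32,
    "BLOCK_SIZE_K": 64,
    "GROUP_SIZE_M": 1,
    "SPLIT_K": 1,
}
print(unpack_moe_lora_config(config))  # (16, 32, 64, 1, 1)
```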

View File

@@ -80,11 +80,13 @@ def _fused_moe_lora_kernel(
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
SPLIT_K: tl.constexpr,
):
pid = tl.program_id(axis=0)
slice_id = tl.program_id(axis=1)
lora_idx = tl.program_id(axis=2)
max_loras = tl.num_programs(axis=2)
grid_k = tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)
# calculate pid_m,pid_n
num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
@@ -102,7 +104,7 @@ def _fused_moe_lora_kernel(
# get the expert_id to process curr shard
ind = lora_idx * stride_el + pid_m
expert_id = tl.load(expert_ids_ptr + ind)
expert_id = tl.load(expert_ids_ptr + ind, ind < max_loras * stride_el, -1)
if expert_id == -1:
return
@@ -117,7 +119,7 @@ def _fused_moe_lora_kernel(
offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
token_ind = stride_tl * lora_idx + offs_token_id
offs_token = tl.load(
sorted_token_ids_ptr + token_ind, token_ind < max_loras * stride_tl, 0.0
sorted_token_ids_ptr + token_ind, token_ind < max_loras * stride_tl, 0
)
token_mask = offs_token < num_valid_tokens
@@ -135,17 +137,18 @@ def _fused_moe_lora_kernel(
# accumulator
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
for k in range(0, grid_k):
k_remaining = K - k * (BLOCK_SIZE_K * SPLIT_K)
a = tl.load(
a_ptrs,
mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
mask=token_mask[:, None] & (offs_k[None, :] < k_remaining),
other=0.0,
)
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
b = tl.load(b_ptrs, mask=offs_k[:, None] < k_remaining, other=0.0)
accumulator += tl.dot(a, b)
# Advance the ptrs to the next K block.
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
a_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_ak
b_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_bk
if MUL_ROUTED_WEIGHT:
moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
@@ -156,7 +159,10 @@ def _fused_moe_lora_kernel(
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = cur_c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
tl.store(c_ptrs, accumulator, mask=c_mask)
if SPLIT_K == 1:
tl.store(c_ptrs, accumulator, mask=c_mask)
else:
tl.atomic_add(c_ptrs, accumulator, mask=c_mask, sem="relaxed")
@torch.inference_mode()
@@ -179,6 +185,7 @@ def _fused_moe_lora(
block_size_n: int,
block_size_k: int,
group_size_m: int,
split_k: int,
mul_routed_weight: bool = False,
) -> None:
assert len(lora_a_stacked) == len(lora_b_stacked) > 0
@@ -206,6 +213,7 @@ def _fused_moe_lora(
"BLOCK_SIZE_N": block_size_n,
"BLOCK_SIZE_K": block_size_k,
"GROUP_SIZE_M": group_size_m,
"SPLIT_K": split_k,
}
w1_lora_a_stacked = lora_a_stacked[0]
@@ -237,7 +245,9 @@ def _fused_moe_lora(
b_ptr = _get_ptr(lora_a_stacked, device)
grid = lambda META: (
triton.cdiv(EM, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
split_k
* triton.cdiv(EM, META["BLOCK_SIZE_M"])
* triton.cdiv(N, META["BLOCK_SIZE_N"]),
len(lora_a_stacked),
lora_a_stacked[0].shape[0],
)
@@ -286,6 +296,8 @@ def _fused_moe_lora(
-1, a_intermediate_cache1.shape[3]
)
# Set split_k = 1 for expand calls
config["SPLIT_K"] = 1
grid = lambda META: (
triton.cdiv(EM, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
len(lora_b_stacked),
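
To make the kernel change concrete: with SPLIT_K > 1, each output tile is computed by SPLIT_K program instances, each reducing only every SPLIT_K-th BLOCK_SIZE_K slab of the K dimension and accumulating its partial product into C via `tl.atomic_add` (a plain `tl.store` would let the instances overwrite one another, and the accumulation also requires the output buffer to start at zero). The reference below is a CPU-side sketch of that decomposition, not the Triton kernel; shapes and block sizes are illustrative.

```python
import torch

# CPU reference sketch of the split-K decomposition used by the kernel above.
# Each "split" reduces every SPLIT_K-th K block (stride BLOCK_SIZE_K * SPLIT_K)
# and adds its partial product into the output, which is what tl.atomic_add
# does on the GPU side when SPLIT_K > 1.
M, N, K = 8, 16, 256
BLOCK_SIZE_K, SPLIT_K = 64, 2
A = torch.randn(M, K)
B = torch.randn(K, N)

out = torch.zeros(M, N)  # must start at zero, since partials are added into it
for split in range(SPLIT_K):
    partial = torch.zeros(M, N)
    for k0 in range(split * BLOCK_SIZE_K, K, BLOCK_SIZE_K * SPLIT_K):
        k1 = min(k0 + BLOCK_SIZE_K, K)
        partial += A[:, k0:k1] @ B[k0:k1, :]
    out += partial  # kernel: tl.atomic_add(c_ptrs, accumulator, ...) for SPLIT_K > 1

torch.testing.assert_close(out, A @ B, rtol=1e-4, atol=1e-4)
```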

View File

@@ -385,5 +385,6 @@ class PunicaWrapperGPU(PunicaWrapperBase):
config["BLOCK_SIZE_N"],
config["BLOCK_SIZE_K"],
config["GROUP_SIZE_M"],
config.get("SPLIT_K", 1),
mul_routed_weight,
)
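
With the wrapper forwarding SPLIT_K into `_fused_moe_lora`, the shrink launch grid shown in the previous file grows by a factor of split_k on its first axis, one extra program per K partition of each output tile. A quick sanity check of that arithmetic, with made-up sizes (EM, N and the block sizes here are illustrative, not vLLM defaults):

```python
import math

# Illustrative sizes only: EM is the padded count of sorted tokens, N the
# output width of this LoRA slice.
EM, N = 1024, 16
BLOCK_SIZE_M, BLOCK_SIZE_N, SPLIT_K = 64, 32, 4

# First grid axis from the diff:
#   split_k * cdiv(EM, BLOCK_SIZE_M) * cdiv(N, BLOCK_SIZE_N)
axis0 = SPLIT_K * math.ceil(EM / BLOCK_SIZE_M) * math.ceil(N / BLOCK_SIZE_N)
print(axis0)  # 4 * 16 * 1 = 64 programs, versus 16 without split-K
```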

View File

@@ -121,6 +121,7 @@ def fused_moe_kernel_gptq_awq(
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
SPLIT_K: tl.constexpr,
MUL_ROUTED_WEIGHT: tl.constexpr,
top_k: tl.constexpr,
compute_type: tl.constexpr,
@@ -356,6 +357,7 @@ def fused_moe_kernel(
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
SPLIT_K: tl.constexpr,
MUL_ROUTED_WEIGHT: tl.constexpr,
top_k: tl.constexpr,
compute_type: tl.constexpr,
@@ -646,7 +648,6 @@ def invoke_fused_moe_kernel(
bit,
)
return
fused_moe_kernel_gptq_awq[grid](
A,
B,
@@ -686,6 +687,7 @@ def invoke_fused_moe_kernel(
)
else:
config = config.copy()
config["SPLIT_K"] = 1
BLOCK_SIZE_K = config.pop("BLOCK_SIZE_K")
if block_shape is not None:
BLOCK_SIZE_K = min(BLOCK_SIZE_K, min(block_shape[0], block_shape[1]))
@@ -983,6 +985,7 @@ def get_default_config(
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
"SPLIT_K": 1,
}
return config
@@ -996,6 +999,7 @@ def get_default_config(
"BLOCK_SIZE_N": block_shape[0],
"BLOCK_SIZE_K": block_shape[1],
"GROUP_SIZE_M": 32,
"SPLIT_K": 1,
"num_warps": 4,
"num_stages": 3 if not current_platform.is_rocm() else 2,
}
@@ -1006,19 +1010,20 @@ def get_default_config(
bit = 4 if dtype == "int4_w4a16" else 8
use_moe_wna16_cuda = should_moe_wna16_use_cuda(M * topk, block_shape[1], E, bit)
if use_moe_wna16_cuda:
config = {"BLOCK_SIZE_M": min(16, M)}
config = {"BLOCK_SIZE_M": min(16, M), "SPLIT_K": 1}
elif M <= 20:
config = {"BLOCK_SIZE_M": 16, "GROUP_SIZE_M": 1}
config = {"BLOCK_SIZE_M": 16, "GROUP_SIZE_M": 1, "SPLIT_K": 1}
elif M <= 40:
config = {"BLOCK_SIZE_M": 32, "GROUP_SIZE_M": 1}
config = {"BLOCK_SIZE_M": 32, "GROUP_SIZE_M": 1, "SPLIT_K": 1}
else:
config = {"BLOCK_SIZE_M": 64, "GROUP_SIZE_M": 1}
config = {"BLOCK_SIZE_M": 64, "GROUP_SIZE_M": 1, "SPLIT_K": 1}
elif M <= E:
config = {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"SPLIT_K": 1,
}
else:
config = {
@@ -1026,6 +1031,7 @@ def get_default_config(
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
"SPLIT_K": 1,
}
return config
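
Finally, the non-LoRA MoE kernels share the same config dicts, so they now declare SPLIT_K as a constexpr too, even though this path always runs with a single K partition. The sketch below (hypothetical helper and stand-in kernel, not vLLM's API) shows why the key has to exist and be pinned to 1 before the config is splatted into the launch, mirroring the `config["SPLIT_K"] = 1` and `config.pop("BLOCK_SIZE_K")` pattern above.

```python
# Hypothetical sketch: config dicts are forwarded to the Triton kernels as
# keyword arguments, so every kernel that can receive them must accept SPLIT_K.
def launch(kernel, config: dict, **tensors):
    cfg = dict(config)
    cfg["SPLIT_K"] = 1                       # non-split path: single K partition
    block_size_k = cfg.pop("BLOCK_SIZE_K")   # handled separately, as in the diff
    return kernel(**tensors, BLOCK_SIZE_K=block_size_k, **cfg)

def fake_kernel(*, x, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K,
                GROUP_SIZE_M, SPLIT_K):
    # Stand-in for fused_moe_kernel[grid](...); just echoes its launch params.
    return BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, GROUP_SIZE_M, SPLIT_K

cfg = {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32,
       "GROUP_SIZE_M": 8}
print(launch(fake_kernel, cfg, x=None))  # (64, 64, 32, 8, 1)
```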