Use the number of active LoRAs for the launch grid in the fused_moe_lora kernel

Signed-off-by: Yu Gong <yu3.gong@gmail.com>
Author: Yu Gong
Date:   2025-12-17 20:02:06 +00:00
parent c016c95b45
commit 288d67d054


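This commit swaps the third launch-grid dimension from the stacked LoRA slot
count (lora_a_stacked[0].shape[0]) to the number of currently active adapters,
and instead passes the slot count into the kernel as a new max_loras constexpr.
A minimal sketch of the resulting grid shape, assuming hypothetical names
(launch_grid is not a vLLM function; triton.cdiv is modeled with math.ceil):

    import math

    def launch_grid(EM, N, num_slices, num_active_loras,
                    block_m, block_n, split_k):
        return (
            split_k * math.ceil(EM / block_m) * math.ceil(N / block_n),
            num_slices,        # axis=1: one program per stacked LoRA slice
            num_active_loras,  # axis=2: was lora_a_stacked[0].shape[0]
        )

    # With 8 allocated adapter slots but only 2 active, axis=2 drops from 8
    # to 2, skipping programs that previously just hit the early-exit path.
    print(launch_grid(EM=512, N=1024, num_slices=2, num_active_loras=2,
                      block_m=64, block_n=64, split_k=4))
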
@@ -81,6 +81,7 @@ def _fused_moe_lora_kernel(
     # Meta-parameters
     num_slice_a: tl.constexpr,
     num_slice_c: tl.constexpr,
+    max_loras: tl.constexpr,
     top_k: tl.constexpr,
     MUL_ROUTED_WEIGHT: tl.constexpr,
     BLOCK_SIZE_M: tl.constexpr,
@@ -104,7 +105,7 @@ def _fused_moe_lora_kernel(
     if moe_enabled == 0:
         # Early exit for the no moe lora case.
         return
-    max_loras = tl.num_programs(axis=2)
+    # max_loras = tl.num_programs(axis=2)
     grid_k = tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)
     # calculate pid_m,pid_n
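
Because grid axis=2 is now sized to the active-adapter count, the kernel can no
longer recover the slot count via tl.num_programs(axis=2); that value now
equals num_active_loras, hence the new max_loras constexpr argument. A toy
illustration of the mismatch (plain Python, made-up numbers):

    num_lora_slots = 8    # lora_a_stacked[0].shape[0], passed in as max_loras
    num_active_loras = 2  # new extent of grid axis=2
    num_programs_axis2 = num_active_loras  # what tl.num_programs(axis=2) now returns
    assert num_programs_axis2 != num_lora_slots  # deriving max_loras from the grid is no longer safe
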
@@ -228,6 +229,7 @@ def _fused_moe_lora_shrink(
     num_warps: int,
     num_stages: int,
     split_k: int,
+    num_active_loras: int,
     mul_routed_weight: bool = False,
 ) -> None:
     w1_lora_a_stacked = lora_a_stacked[0]
@@ -251,7 +253,7 @@ def _fused_moe_lora_shrink(
         * triton.cdiv(EM, META["BLOCK_SIZE_M"])
         * triton.cdiv(N, META["BLOCK_SIZE_N"]),
         len(lora_a_stacked),
-        lora_a_stacked[0].shape[0],
+        num_active_loras,
     )
     _fused_moe_lora_kernel[grid](
         qcurr_hidden_states,
@@ -280,6 +282,7 @@ def _fused_moe_lora_shrink(
         expert_ids.stride(0),
         slice_a_size=qcurr_hidden_states.numel(),
         slice_c_size=a_intermediate_cache1.numel() // num_slices,
+        max_loras=lora_a_stacked[0].shape[0],
         num_slice_a=1,
         num_slice_c=num_slices,
         top_k=1 if mul_routed_weight else top_k_num,
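
The shrink launch now pairs the two values explicitly: num_active_loras sizes
grid axis=2, while max_loras describes how the stacked weights are laid out. A
hypothetical sketch of why addressing still needs the slot count (the kernel's
real stride math may differ; lora_block_offset is illustrative only):

    # If each LoRA slice spans max_loras slots behind one base pointer,
    # the per-(slice, adapter) offset depends on the slot count, not on
    # how many adapters happen to be active this step.
    def lora_block_offset(slice_id, lora_id, max_loras, slice_stride):
        return (slice_id * max_loras + lora_id) * slice_stride

    # Adapter 1 in slice 1 of a stack with 8 slots, 4096 elements per slot:
    print(lora_block_offset(slice_id=1, lora_id=1, max_loras=8, slice_stride=4096))
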
@@ -322,6 +325,7 @@ def _fused_moe_lora_expand(
     num_warps: int,
     num_stages: int,
     split_k: int,
+    num_active_loras: int,
     mul_routed_weight: bool = False,
     offset: int = 0,
 ) -> None:
@@ -351,7 +355,7 @@ def _fused_moe_lora_expand(
     grid = lambda META: (
         triton.cdiv(EM, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
         len(lora_b_stacked),
-        lora_b_stacked[0].shape[0],
+        num_active_loras,
     )
     _fused_moe_lora_kernel[grid](
         a_intermediate_cache1,
@@ -382,6 +386,7 @@ def _fused_moe_lora_expand(
         slice_c_size=b_intermediate_cache1.numel() // num_slices,
         num_slice_a=num_slices,
         num_slice_c=num_slices,
+        max_loras=lora_b_stacked[0].shape[0],
         top_k=1,
         MUL_ROUTED_WEIGHT=mul_routed_weight,
         IS_PRIMARY=False,
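
The expand path applies the same substitution as the shrink path: grid axis=2
shrinks to num_active_loras, and lora_b_stacked[0].shape[0] travels into the
kernel as max_loras, matching the sketch after the first hunk.
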
@@ -492,6 +497,7 @@ def _fused_moe_lora(
         shrink_num_warps,
         shrink_num_stages,
         shrink_split_k,
+        num_active_loras,
         mul_routed_weight,
     )
@@ -538,6 +544,7 @@ def _fused_moe_lora(
         expand_num_warps,
         expand_num_stages,
         expand_split_k,
+        num_active_loras,
         mul_routed_weight,
         offset,
     )
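
The diff does not show where _fused_moe_lora's caller obtains num_active_loras.
One plausible derivation, purely as an assumption (count_active_loras is a
hypothetical helper, not vLLM code), counts the adapter slots in use, treating
-1 as an empty slot:

    import torch

    def count_active_loras(lora_ids: torch.Tensor) -> int:
        # -1 marks an unused slot in the per-batch adapter-id metadata
        # (an assumption here, mirroring common punica conventions).
        return int((torch.unique(lora_ids) != -1).sum().item())

    print(count_active_loras(torch.tensor([0, 2, 2, -1, 0])))  # -> 2
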
@@ -601,6 +608,7 @@ def _fused_moe_lora_shrink_fake(
     num_warps: int,
     num_stages: int,
     split_k: int,
+    num_active_loras: int,
     mul_routed_weight: bool = False,
 ) -> None:
     return
@@ -634,6 +642,7 @@ def _fused_moe_lora_expand_fake(
     num_warps: int,
     num_stages: int,
     split_k: int,
+    num_active_loras: int,
     mul_routed_weight: bool = False,
 ) -> None:
     return
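
The *_fake variants are presumably the fake/meta implementations registered
for these custom ops, so their signatures must stay in sync with the real
kernels; threading num_active_loras through them is signature bookkeeping, not
logic.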