Use num_active_loras in the launch grid of the fused_moe_lora kernel
Signed-off-by: Yu Gong <yu3.gong@gmail.com>
parent c016c95b45
commit 288d67d054
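In short: the third axis of the kernel launch grid used to span the full capacity of the stacked LoRA weights, so one program was launched per slot even when the slot held no active adapter. This commit sizes that axis by num_active_loras instead, and, since the kernel can then no longer recover the stacked capacity from tl.num_programs(axis=2), passes max_loras into the kernel as a tl.constexpr. A back-of-the-envelope illustration of the launch-count impact (numbers invented for the example):

# Illustrative only: launch-count impact of the grid change.
# Suppose the M/N tiling yields 64 tiles, with 2 weight slices,
# max_loras = 32 stacked slots, but only 2 adapters active.
tiles, num_slices, max_loras, num_active_loras = 64, 2, 32, 2
before = tiles * num_slices * max_loras          # 4096 programs
after = tiles * num_slices * num_active_loras    # 256 programs
print(before, after)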
@@ -81,6 +81,7 @@ def _fused_moe_lora_kernel(
     # Meta-parameters
     num_slice_a: tl.constexpr,
     num_slice_c: tl.constexpr,
+    max_loras: tl.constexpr,
     top_k: tl.constexpr,
     MUL_ROUTED_WEIGHT: tl.constexpr,
     BLOCK_SIZE_M: tl.constexpr,
@@ -104,7 +105,7 @@ def _fused_moe_lora_kernel(
     if moe_enabled == 0:
         # Early exit for the no moe lora case.
         return
-    max_loras = tl.num_programs(axis=2)
+    # max_loras = tl.num_programs(axis=2)
     grid_k = tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)
 
     # calculate pid_m,pid_n
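The commented-out assignment is the kernel-side half of the change: once axis 2 is sized by num_active_loras, tl.num_programs(axis=2) no longer equals the stacked capacity, so max_loras must arrive as the new constexpr parameter. A minimal standalone sketch of the pattern (a toy kernel, not the vLLM one; the output pointer is made up):

import torch
import triton
import triton.language as tl

@triton.jit
def _axis2_sketch(out_ptr, max_loras: tl.constexpr):
    # Axis 2 now enumerates active adapters only, so this pid indexes
    # the active set; tl.num_programs(axis=2) == num_active_loras.
    lora_idx = tl.program_id(axis=2)
    # max_loras (full stacked capacity) is supplied explicitly and can
    # still bound accesses into the stacked weight tensors.
    mask = lora_idx < max_loras
    tl.store(out_ptr + lora_idx, lora_idx, mask=mask)

out = torch.zeros(8, dtype=torch.int32, device="cuda")
_axis2_sketch[(1, 1, 2)](out, max_loras=8)  # 2 active adapters, 8 slots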
@@ -228,6 +229,7 @@ def _fused_moe_lora_shrink(
     num_warps: int,
     num_stages: int,
     split_k: int,
+    num_active_loras: int,
     mul_routed_weight: bool = False,
 ) -> None:
     w1_lora_a_stacked = lora_a_stacked[0]
@@ -251,7 +253,7 @@ def _fused_moe_lora_shrink(
         * triton.cdiv(EM, META["BLOCK_SIZE_M"])
         * triton.cdiv(N, META["BLOCK_SIZE_N"]),
         len(lora_a_stacked),
-        lora_a_stacked[0].shape[0],
+        num_active_loras,
     )
     _fused_moe_lora_kernel[grid](
         qcurr_hidden_states,
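The wrapper now needs to be told how many adapters are live; the diff does not show where the value comes from. One plausible derivation (a hypothetical helper, not part of this commit) counts distinct adapter ids referenced by the batch's token-to-LoRA mapping:

import torch

def count_active_loras(token_lora_mapping: torch.Tensor) -> int:
    # Hypothetical helper: number of distinct adapter slots used by the
    # batch, skipping the common "no adapter" sentinel of -1. The value
    # vLLM actually passes may be computed elsewhere.
    ids = token_lora_mapping[token_lora_mapping >= 0]
    return int(torch.unique(ids).numel())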
@@ -280,6 +282,7 @@ def _fused_moe_lora_shrink(
         expert_ids.stride(0),
         slice_a_size=qcurr_hidden_states.numel(),
         slice_c_size=a_intermediate_cache1.numel() // num_slices,
+        max_loras=lora_a_stacked[0].shape[0],
         num_slice_a=1,
         num_slice_c=num_slices,
         top_k=1 if mul_routed_weight else top_k_num,
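Note that the call site keeps deriving max_loras from the stacked weights' leading dimension (lora_a_stacked[0].shape[0]), so the kernel's offset arithmetic over the stacked tensors is unchanged; only the number of launched programs shrinks.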
@@ -322,6 +325,7 @@ def _fused_moe_lora_expand(
     num_warps: int,
     num_stages: int,
     split_k: int,
+    num_active_loras: int,
     mul_routed_weight: bool = False,
     offset: int = 0,
 ) -> None:
@@ -351,7 +355,7 @@ def _fused_moe_lora_expand(
     grid = lambda META: (
         triton.cdiv(EM, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
         len(lora_b_stacked),
-        lora_b_stacked[0].shape[0],
+        num_active_loras,
     )
     _fused_moe_lora_kernel[grid](
         a_intermediate_cache1,
@@ -382,6 +386,7 @@ def _fused_moe_lora_expand(
         slice_c_size=b_intermediate_cache1.numel() // num_slices,
         num_slice_a=num_slices,
         num_slice_c=num_slices,
+        max_loras=lora_b_stacked[0].shape[0],
         top_k=1,
         MUL_ROUTED_WEIGHT=mul_routed_weight,
         IS_PRIMARY=False,
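The expand path mirrors the shrink path exactly: num_active_loras replaces lora_b_stacked[0].shape[0] on grid axis 2, and the stacked capacity is handed to the kernel via the same new max_loras= keyword.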
@@ -492,6 +497,7 @@ def _fused_moe_lora(
         shrink_num_warps,
         shrink_num_stages,
         shrink_split_k,
+        num_active_loras,
         mul_routed_weight,
     )
 
@@ -538,6 +544,7 @@ def _fused_moe_lora(
         expand_num_warps,
         expand_num_stages,
         expand_split_k,
+        num_active_loras,
         mul_routed_weight,
         offset,
     )
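_fused_moe_lora itself just threads the new argument through to both _fused_moe_lora_shrink and _fused_moe_lora_expand. Because it is passed positionally, it must sit between *_split_k and mul_routed_weight in both signatures, which matches the parameter hunks above.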
@@ -601,6 +608,7 @@ def _fused_moe_lora_shrink_fake(
     num_warps: int,
     num_stages: int,
     split_k: int,
+    num_active_loras: int,
     mul_routed_weight: bool = False,
 ) -> None:
     return
@@ -634,6 +642,7 @@ def _fused_moe_lora_expand_fake(
     num_warps: int,
     num_stages: int,
     split_k: int,
+    num_active_loras: int,
     mul_routed_weight: bool = False,
 ) -> None:
     return
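The _fake variants are the meta implementations used when these functions are registered as custom ops (so torch.compile can trace them without launching Triton); their signatures must stay in lockstep with the real ops, which is why num_active_loras: int is added here as well. A hedged sketch of the registration pattern this implies, assuming vLLM's direct_register_custom_op utility (exact arguments may differ):

from vllm.utils import direct_register_custom_op

# Sketch only: the fake impl's signature must match the real op's,
# so both gain num_active_loras in this commit.
direct_register_custom_op(
    op_name="fused_moe_lora_shrink",
    op_func=_fused_moe_lora_shrink,
    mutates_args=["a_intermediate_cache1"],  # assumed output buffer name
    fake_impl=_fused_moe_lora_shrink_fake,
)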