mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 04:34:57 +08:00
[Bugfix] Fix adapter_enabled IMA (illegal memory access) (#29977)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
parent
5f91cdda75
commit
dd38ba3a26
@ -96,10 +96,14 @@ def _fused_moe_lora_kernel(
|
|||||||
slice_id = tl.program_id(axis=1)
|
slice_id = tl.program_id(axis=1)
|
||||||
lora_idx = tl.program_id(axis=2)
|
lora_idx = tl.program_id(axis=2)
|
||||||
lora_id = tl.load(lora_ids + lora_idx)
|
lora_id = tl.load(lora_ids + lora_idx)
|
||||||
moe_enabled = tl.load(adapter_enabled + lora_id)
|
|
||||||
if lora_id == -1 or moe_enabled == 0:
|
if lora_id == -1:
|
||||||
# Early exit for the no-lora case.
|
# Early exit for the no-lora case.
|
||||||
return
|
return
|
||||||
|
moe_enabled = tl.load(adapter_enabled + lora_id)
|
||||||
|
if moe_enabled == 0:
|
||||||
|
# Early exit for the no moe lora case.
|
||||||
|
return
|
||||||
max_loras = tl.num_programs(axis=2)
|
max_loras = tl.num_programs(axis=2)
|
||||||
grid_k = tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)
|
grid_k = tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user