Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-13 12:55:02 +08:00)
[Performance][LoRA] add context varying params to 'do_not_specialize' in fused moe lora (#27445)
Signed-off-by: gnovack <gnovack@amazon.com>
This commit is contained in:
parent 181bf5bbde
commit a806c14cc7
@@ -31,7 +31,16 @@ def _get_ptr(lora_weights: list[torch.Tensor], device: torch.device):
|
|||||||
return _LORA_PTR_DICT.get(key)
|
return _LORA_PTR_DICT.get(key)
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit(
|
||||||
|
do_not_specialize=[
|
||||||
|
"num_valid_tokens",
|
||||||
|
"EM",
|
||||||
|
"stride_tl",
|
||||||
|
"stride_el",
|
||||||
|
"slice_a_size",
|
||||||
|
"slice_c_size",
|
||||||
|
]
|
||||||
|
)
|
||||||
def _fused_moe_lora_kernel(
|
def _fused_moe_lora_kernel(
|
||||||
a_ptr,
|
a_ptr,
|
||||||
b_ptr,
|
b_ptr,
|
||||||
@@ -60,11 +69,11 @@ def _fused_moe_lora_kernel(
|
|||||||
stride_cn,
|
stride_cn,
|
||||||
stride_tl,
|
stride_tl,
|
||||||
stride_el,
|
stride_el,
|
||||||
|
slice_a_size,
|
||||||
|
slice_c_size,
|
||||||
# Meta-parameters
|
# Meta-parameters
|
||||||
num_slice_a: tl.constexpr,
|
num_slice_a: tl.constexpr,
|
||||||
num_slice_c: tl.constexpr,
|
num_slice_c: tl.constexpr,
|
||||||
slice_a_size: tl.constexpr,
|
|
||||||
slice_c_size: tl.constexpr,
|
|
||||||
top_k: tl.constexpr,
|
top_k: tl.constexpr,
|
||||||
MUL_ROUTED_WEIGHT: tl.constexpr,
|
MUL_ROUTED_WEIGHT: tl.constexpr,
|
||||||
BLOCK_SIZE_M: tl.constexpr,
|
BLOCK_SIZE_M: tl.constexpr,
|
||||||
@@ -256,10 +265,10 @@ def _fused_moe_lora(
|
|||||||
a_intermediate_cache1.stride(3),
|
a_intermediate_cache1.stride(3),
|
||||||
sorted_token_ids.stride(0),
|
sorted_token_ids.stride(0),
|
||||||
expert_ids.stride(0),
|
expert_ids.stride(0),
|
||||||
num_slice_a=1,
|
|
||||||
num_slice_c=num_slices,
|
|
||||||
slice_a_size=qcurr_hidden_states.numel(),
|
slice_a_size=qcurr_hidden_states.numel(),
|
||||||
slice_c_size=a_intermediate_cache1.numel() // num_slices,
|
slice_c_size=a_intermediate_cache1.numel() // num_slices,
|
||||||
|
num_slice_a=1,
|
||||||
|
num_slice_c=num_slices,
|
||||||
top_k=1 if mul_routed_weight else top_k_num,
|
top_k=1 if mul_routed_weight else top_k_num,
|
||||||
MUL_ROUTED_WEIGHT=False,
|
MUL_ROUTED_WEIGHT=False,
|
||||||
**config,
|
**config,
|
||||||
@@ -305,10 +314,10 @@ def _fused_moe_lora(
|
|||||||
b_intermediate_cache1.stride(3),
|
b_intermediate_cache1.stride(3),
|
||||||
sorted_token_ids.stride(0),
|
sorted_token_ids.stride(0),
|
||||||
expert_ids.stride(0),
|
expert_ids.stride(0),
|
||||||
num_slice_a=num_slices,
|
|
||||||
num_slice_c=num_slices,
|
|
||||||
slice_a_size=a_intermediate_cache1.numel() // num_slices,
|
slice_a_size=a_intermediate_cache1.numel() // num_slices,
|
||||||
slice_c_size=b_intermediate_cache1.numel() // num_slices,
|
slice_c_size=b_intermediate_cache1.numel() // num_slices,
|
||||||
|
num_slice_a=num_slices,
|
||||||
|
num_slice_c=num_slices,
|
||||||
top_k=1,
|
top_k=1,
|
||||||
MUL_ROUTED_WEIGHT=mul_routed_weight,
|
MUL_ROUTED_WEIGHT=mul_routed_weight,
|
||||||
**config,
|
**config,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user