mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 04:15:01 +08:00
[Performance][LoRA] add context varying params to 'do_not_specialize' in fused moe lora (#27445)
Signed-off-by: gnovack <gnovack@amazon.com>
This commit is contained in:
parent
181bf5bbde
commit
a806c14cc7
@@ -31,7 +31,16 @@ def _get_ptr(lora_weights: list[torch.Tensor], device: torch.device):
|
||||
return _LORA_PTR_DICT.get(key)
|
||||
|
||||
|
||||
@triton.jit
|
||||
@triton.jit(
|
||||
do_not_specialize=[
|
||||
"num_valid_tokens",
|
||||
"EM",
|
||||
"stride_tl",
|
||||
"stride_el",
|
||||
"slice_a_size",
|
||||
"slice_c_size",
|
||||
]
|
||||
)
|
||||
def _fused_moe_lora_kernel(
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
@@ -60,11 +69,11 @@ def _fused_moe_lora_kernel(
|
||||
stride_cn,
|
||||
stride_tl,
|
||||
stride_el,
|
||||
slice_a_size,
|
||||
slice_c_size,
|
||||
# Meta-parameters
|
||||
num_slice_a: tl.constexpr,
|
||||
num_slice_c: tl.constexpr,
|
||||
slice_a_size: tl.constexpr,
|
||||
slice_c_size: tl.constexpr,
|
||||
top_k: tl.constexpr,
|
||||
MUL_ROUTED_WEIGHT: tl.constexpr,
|
||||
BLOCK_SIZE_M: tl.constexpr,
|
||||
@@ -256,10 +265,10 @@ def _fused_moe_lora(
|
||||
a_intermediate_cache1.stride(3),
|
||||
sorted_token_ids.stride(0),
|
||||
expert_ids.stride(0),
|
||||
num_slice_a=1,
|
||||
num_slice_c=num_slices,
|
||||
slice_a_size=qcurr_hidden_states.numel(),
|
||||
slice_c_size=a_intermediate_cache1.numel() // num_slices,
|
||||
num_slice_a=1,
|
||||
num_slice_c=num_slices,
|
||||
top_k=1 if mul_routed_weight else top_k_num,
|
||||
MUL_ROUTED_WEIGHT=False,
|
||||
**config,
|
||||
@@ -305,10 +314,10 @@ def _fused_moe_lora(
|
||||
b_intermediate_cache1.stride(3),
|
||||
sorted_token_ids.stride(0),
|
||||
expert_ids.stride(0),
|
||||
num_slice_a=num_slices,
|
||||
num_slice_c=num_slices,
|
||||
slice_a_size=a_intermediate_cache1.numel() // num_slices,
|
||||
slice_c_size=b_intermediate_cache1.numel() // num_slices,
|
||||
num_slice_a=num_slices,
|
||||
num_slice_c=num_slices,
|
||||
top_k=1,
|
||||
MUL_ROUTED_WEIGHT=mul_routed_weight,
|
||||
**config,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user