[Performance][LoRA] add context-varying params to 'do_not_specialize' in fused moe lora (#27445)

Signed-off-by: gnovack <gnovack@amazon.com>
Author: gnovack
Date: 2025-10-26 23:31:55 -07:00 (committed by GitHub)
parent 181bf5bbde
commit a806c14cc7

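Background on the change: Triton's JIT specializes a compiled kernel on properties of its integer arguments (for example, whether a value equals 1 or is divisible by 16). An argument whose value varies from call to call, such as a per-step token count, therefore triggers a fresh compilation whenever its value class changes. Listing such arguments in do_not_specialize keeps a single compiled variant. Below is a minimal sketch of the mechanism; the kernel and names are illustrative, not the vLLM kernel.

import triton
import triton.language as tl


# n_elements changes on every call, so it is excluded from specialization
# to avoid one compile per distinct value class.
@triton.jit(do_not_specialize=["n_elements"])
def _scale_kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements  # runtime bound; not baked into the binary
    x = tl.load(x_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x * 2.0, mask=mask)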

@@ -31,7 +31,16 @@ def _get_ptr(lora_weights: list[torch.Tensor], device: torch.device):
     return _LORA_PTR_DICT.get(key)
 
 
-@triton.jit
+@triton.jit(
+    do_not_specialize=[
+        "num_valid_tokens",
+        "EM",
+        "stride_tl",
+        "stride_el",
+        "slice_a_size",
+        "slice_c_size",
+    ]
+)
 def _fused_moe_lora_kernel(
     a_ptr,
     b_ptr,
@@ -60,11 +69,11 @@ def _fused_moe_lora_kernel(
     stride_cn,
     stride_tl,
     stride_el,
+    slice_a_size,
+    slice_c_size,
     # Meta-parameters
     num_slice_a: tl.constexpr,
     num_slice_c: tl.constexpr,
-    slice_a_size: tl.constexpr,
-    slice_c_size: tl.constexpr,
     top_k: tl.constexpr,
     MUL_ROUTED_WEIGHT: tl.constexpr,
     BLOCK_SIZE_M: tl.constexpr,
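
Note on the signature change above: slice_a_size and slice_c_size move out of the tl.constexpr meta-parameter block and become ordinary runtime arguments. A tl.constexpr value is baked into the compiled binary, so each distinct slice size would compile a separate kernel; as plain integer arguments listed in do_not_specialize, they are read at run time instead. A sketch of the same pattern, with assumed names and a standalone kernel (the real kernel's pointer arithmetic is more involved):

import triton
import triton.language as tl


@triton.jit(do_not_specialize=["slice_size", "n_elements"])
def _sliced_copy_kernel(src_ptr, dst_ptr, slice_size, n_elements,
                        BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    slice_id = tl.program_id(axis=1)  # one program row per slice
    base = slice_id * slice_size      # runtime offset, analogous to slice_a_size
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    vals = tl.load(src_ptr + base + offsets, mask=mask)
    tl.store(dst_ptr + base + offsets, vals, mask=mask)
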
@@ -256,10 +265,10 @@ def _fused_moe_lora(
         a_intermediate_cache1.stride(3),
         sorted_token_ids.stride(0),
         expert_ids.stride(0),
-        num_slice_a=1,
-        num_slice_c=num_slices,
         slice_a_size=qcurr_hidden_states.numel(),
         slice_c_size=a_intermediate_cache1.numel() // num_slices,
+        num_slice_a=1,
+        num_slice_c=num_slices,
         top_k=1 if mul_routed_weight else top_k_num,
         MUL_ROUTED_WEIGHT=False,
         **config,
@@ -305,10 +314,10 @@ def _fused_moe_lora(
         b_intermediate_cache1.stride(3),
         sorted_token_ids.stride(0),
         expert_ids.stride(0),
-        num_slice_a=num_slices,
-        num_slice_c=num_slices,
         slice_a_size=a_intermediate_cache1.numel() // num_slices,
         slice_c_size=b_intermediate_cache1.numel() // num_slices,
+        num_slice_a=num_slices,
+        num_slice_c=num_slices,
         top_k=1,
         MUL_ROUTED_WEIGHT=mul_routed_weight,
         **config,
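
At the call sites above, only the keyword order changes (the slice sizes now precede the meta-parameters, matching the new signature); the values passed, such as a_intermediate_cache1.numel() // num_slices, are unchanged. A usage sketch of the intended effect, reusing the illustrative _scale_kernel from the top of this commit (hypothetical driver code, CUDA device assumed): because the varying size is in do_not_specialize, all three launches share one compiled binary, where previously sizes like 1 or a multiple of 16 would each have produced a distinct specialization.

import torch
import triton


def run(n: int) -> torch.Tensor:
    x = torch.randn(n, device="cuda")
    out = torch.empty_like(x)
    grid = (triton.cdiv(n, 1024),)
    # n is listed in do_not_specialize, so these launches share one binary.
    _scale_kernel[grid](x, out, n, BLOCK_SIZE=1024)
    return out


for n in (1, 16, 1000):  # context-varying sizes, like per-step token counts
    run(n)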