diff --git a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py
index 94935d8dfe86..d8746ebc8e75 100644
--- a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py
+++ b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py
@@ -31,7 +31,16 @@ def _get_ptr(lora_weights: list[torch.Tensor], device: torch.device):
     return _LORA_PTR_DICT.get(key)
 
 
-@triton.jit
+@triton.jit(
+    do_not_specialize=[
+        "num_valid_tokens",
+        "EM",
+        "stride_tl",
+        "stride_el",
+        "slice_a_size",
+        "slice_c_size",
+    ]
+)
 def _fused_moe_lora_kernel(
     a_ptr,
     b_ptr,
@@ -60,11 +69,11 @@ def _fused_moe_lora_kernel(
     stride_cn,
     stride_tl,
     stride_el,
+    slice_a_size,
+    slice_c_size,
     # Meta-parameters
     num_slice_a: tl.constexpr,
     num_slice_c: tl.constexpr,
-    slice_a_size: tl.constexpr,
-    slice_c_size: tl.constexpr,
     top_k: tl.constexpr,
     MUL_ROUTED_WEIGHT: tl.constexpr,
     BLOCK_SIZE_M: tl.constexpr,
@@ -256,10 +265,10 @@ def _fused_moe_lora(
         a_intermediate_cache1.stride(3),
         sorted_token_ids.stride(0),
         expert_ids.stride(0),
-        num_slice_a=1,
-        num_slice_c=num_slices,
         slice_a_size=qcurr_hidden_states.numel(),
         slice_c_size=a_intermediate_cache1.numel() // num_slices,
+        num_slice_a=1,
+        num_slice_c=num_slices,
         top_k=1 if mul_routed_weight else top_k_num,
         MUL_ROUTED_WEIGHT=False,
         **config,
@@ -305,10 +314,10 @@ def _fused_moe_lora(
         b_intermediate_cache1.stride(3),
         sorted_token_ids.stride(0),
         expert_ids.stride(0),
-        num_slice_a=num_slices,
-        num_slice_c=num_slices,
         slice_a_size=a_intermediate_cache1.numel() // num_slices,
         slice_c_size=b_intermediate_cache1.numel() // num_slices,
+        num_slice_a=num_slices,
+        num_slice_c=num_slices,
         top_k=1,
         MUL_ROUTED_WEIGHT=mul_routed_weight,
         **config,
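
Note (not part of the patch): below is a minimal, self-contained sketch of the technique the diff applies, using a hypothetical elementwise-add kernel rather than the fused MoE LoRA kernel. Listing a runtime-varying integer argument in triton.jit's do_not_specialize keeps Triton from specializing the compiled binary on that value (for example, on whether it equals 1 or is divisible by 16), so changing the value between calls does not trigger a recompile. That is the same effect the patch seeks by moving slice_a_size and slice_c_size out of the tl.constexpr meta-parameters and into do_not_specialize.

# Illustrative sketch only; kernel and helper names are hypothetical.
import torch
import triton
import triton.language as tl


@triton.jit(do_not_specialize=["n_elements"])
def _add_kernel(
    x_ptr,
    y_ptr,
    out_ptr,
    n_elements,  # runtime scalar; excluded from value specialization
    BLOCK_SIZE: tl.constexpr,  # compile-time constant; still specialized
):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)


def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    n_elements = out.numel()
    grid = (triton.cdiv(n_elements, 1024),)
    # Calls with different n_elements reuse the same compiled kernel because
    # the argument is listed in do_not_specialize above.
    _add_kernel[grid](x, y, out, n_elements, BLOCK_SIZE=1024)
    return out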