diff --git a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py index 6d6de2529de3d..893972144e99a 100644 --- a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py +++ b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py @@ -26,7 +26,7 @@ def _get_ptr(lora_weights: list[torch.Tensor], device: torch.device): tensor_ptrs = [] for lora_weight in lora_weights: tensor_ptrs.append(lora_weight.data_ptr()) - ptr_tensor = torch.tensor(tensor_ptrs, device=device) + ptr_tensor = torch.tensor(tensor_ptrs, device=device, dtype=torch.uint64) _LORA_PTR_DICT[key] = ptr_tensor return _LORA_PTR_DICT.get(key) @@ -85,6 +85,7 @@ def _fused_moe_lora_kernel( GROUP_SIZE_M: tl.constexpr, SPLIT_K: tl.constexpr, USE_GDC: tl.constexpr, + launch_pdl: tl.constexpr, IS_PRIMARY: tl.constexpr, ): pid = tl.program_id(axis=0)