Fix Fused MoE LoRA Triton kernel bug (#28450)

Signed-off-by: chaojun-zhang <chaojun.zhang@intel.com>
Chaojun Zhang 2025-11-11 20:46:47 +08:00 committed by GitHub
parent b30dfa03c5
commit 7dbe6d81d6


@@ -26,7 +26,7 @@ def _get_ptr(lora_weights: list[torch.Tensor], device: torch.device):
     tensor_ptrs = []
     for lora_weight in lora_weights:
         tensor_ptrs.append(lora_weight.data_ptr())
-    ptr_tensor = torch.tensor(tensor_ptrs, device=device)
+    ptr_tensor = torch.tensor(tensor_ptrs, device=device, dtype=torch.uint64)
     _LORA_PTR_DICT[key] = ptr_tensor
     return _LORA_PTR_DICT.get(key)
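
For context, a minimal standalone sketch of the pattern this hunk touches (illustrative only, not the vLLM source; build_ptr_table is a made-up name): torch.tensor() infers torch.int64 for a list of Python ints, so the fix pins the pointer table to torch.uint64 so each address returned by data_ptr() reaches the Triton kernel as an unsigned 64-bit value.

import torch

def build_ptr_table(weights: list[torch.Tensor], device: torch.device) -> torch.Tensor:
    # Collect the raw device address of every per-LoRA weight tensor.
    ptrs = [w.data_ptr() for w in weights]
    # Without an explicit dtype, torch.tensor() would create an int64 tensor;
    # uint64 keeps the addresses bit-exact for the kernel's pointer arithmetic.
    return torch.tensor(ptrs, device=device, dtype=torch.uint64)
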
@@ -85,6 +85,7 @@ def _fused_moe_lora_kernel(
     GROUP_SIZE_M: tl.constexpr,
     SPLIT_K: tl.constexpr,
     USE_GDC: tl.constexpr,
+    launch_pdl: tl.constexpr,
     IS_PRIMARY: tl.constexpr,
 ):
     pid = tl.program_id(axis=0)
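
For illustration only, here is a minimal, self-contained Triton sketch (not the vLLM kernel; _scale_kernel, scale, and BLOCK_SIZE are made-up names) of how a compile-time flag declared as launch_pdl: tl.constexpr is threaded through a launch: the flag must appear in the kernel signature for any call site that passes launch_pdl= to compile, which is the kind of signature/call-site mismatch that adding such a parameter typically resolves.

import torch
import triton
import triton.language as tl


@triton.jit
def _scale_kernel(
    x_ptr,
    out_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
    launch_pdl: tl.constexpr,  # compile-time flag, analogous to the parameter added above
):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    # The constexpr is resolved at compile time, so this branch is specialized
    # away; the real kernel would gate PDL-related behavior on it instead.
    if launch_pdl:
        x = x * 2.0
    tl.store(out_ptr + offsets, x, mask=mask)


def scale(x: torch.Tensor, launch_pdl: bool = False) -> torch.Tensor:
    out = torch.empty_like(x)
    n = x.numel()
    grid = (triton.cdiv(n, 1024),)
    # Passing launch_pdl= here only works because the kernel declares it.
    _scale_kernel[grid](x, out, n, BLOCK_SIZE=1024, launch_pdl=launch_pdl)
    return out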