diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_utils.py b/vllm/model_executor/layers/fused_moe/deep_gemm_utils.py
index 8cc5a747c6731..c8469501af5db 100644
--- a/vllm/model_executor/layers/fused_moe/deep_gemm_utils.py
+++ b/vllm/model_executor/layers/fused_moe/deep_gemm_utils.py
@@ -52,7 +52,7 @@ def compute_aligned_M(M: int, num_topk: int, local_num_experts: int,
 @triton.jit
 def apply_expert_map(expert_id, expert_map):
     if expert_id != -1:
-        expert_id = tl.load(expert_map + expert_id).to(tl.int64)
+        expert_id = tl.load(expert_map + expert_id).to(expert_id.dtype)
     return expert_id
 
 