[Bugfix] DeepGemm utils : Fix hardcoded type-cast (#21517)

Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
This commit is contained in:
Varun Sundar Rabindranath 2025-07-25 08:47:29 +05:30 committed by GitHub
parent ce3a9b1378
commit 2212cd6cfb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -52,7 +52,7 @@ def compute_aligned_M(M: int, num_topk: int, local_num_experts: int,
@triton.jit
def apply_expert_map(expert_id, expert_map):
if expert_id != -1:
expert_id = tl.load(expert_map + expert_id).to(tl.int64)
expert_id = tl.load(expert_map + expert_id).to(expert_id.dtype)
return expert_id