[BugFix] fix group_topk (#8430)

This commit is contained in:
Dipika Sikka 2024-09-12 21:23:42 -04:00 committed by GitHub
parent 360ddbd37e
commit 8f44a92d85
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -410,6 +410,7 @@ def fused_topk(
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
return topk_weights, topk_ids
@ -443,7 +444,8 @@ def grouped_topk(hidden_states: torch.Tensor,
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
return topk_weights, topk_ids
return topk_weights, topk_ids.to(torch.int32)
def get_config_dtype_str(dtype: torch.dtype,