mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 12:25:32 +08:00
[BugFix] fix group_topk (#8430)
This commit is contained in:
parent
360ddbd37e
commit
8f44a92d85
@ -410,6 +410,7 @@ def fused_topk(
|
||||
|
||||
if renormalize:
|
||||
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
|
||||
|
||||
return topk_weights, topk_ids
|
||||
|
||||
|
||||
@ -443,7 +444,8 @@ def grouped_topk(hidden_states: torch.Tensor,
|
||||
|
||||
if renormalize:
|
||||
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
|
||||
return topk_weights, topk_ids
|
||||
|
||||
return topk_weights, topk_ids.to(torch.int32)
|
||||
|
||||
|
||||
def get_config_dtype_str(dtype: torch.dtype,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user