mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 23:45:39 +08:00
[BugFix] fix group_topk (#8430)
This commit is contained in:
parent
360ddbd37e
commit
8f44a92d85
@ -410,6 +410,7 @@ def fused_topk(
|
|||||||
|
|
||||||
if renormalize:
|
if renormalize:
|
||||||
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
|
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
|
||||||
|
|
||||||
return topk_weights, topk_ids
|
return topk_weights, topk_ids
|
||||||
|
|
||||||
|
|
||||||
@ -443,7 +444,8 @@ def grouped_topk(hidden_states: torch.Tensor,
|
|||||||
|
|
||||||
if renormalize:
|
if renormalize:
|
||||||
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
|
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
|
||||||
return topk_weights, topk_ids
|
|
||||||
|
return topk_weights, topk_ids.to(torch.int32)
|
||||||
|
|
||||||
|
|
||||||
def get_config_dtype_str(dtype: torch.dtype,
|
def get_config_dtype_str(dtype: torch.dtype,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user