[Core] simplify logits resort in _apply_top_k_top_p (#8619)

This commit is contained in:
盏一 2024-09-20 02:28:25 +08:00 committed by GitHub
parent 9cc373f390
commit e42c634acb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -433,12 +433,9 @@ def _apply_top_k_top_p(
logits_sort.masked_fill_(top_p_mask, -float("inf"))
# Re-sort the probabilities.
src = torch.arange(logits_idx.shape[-1],
device=logits_idx.device).expand_as(logits_idx)
logits_idx_inv = torch.empty_like(logits_idx).scatter_(dim=-1,
index=logits_idx,
src=src)
logits = torch.gather(logits_sort, dim=-1, index=logits_idx_inv)
logits = torch.empty_like(logits_sort).scatter_(dim=-1,
index=logits_idx,
src=logits_sort)
return logits