mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-14 20:45:49 +08:00
fix long dtype in topk sampling (#15049)
This commit is contained in:
parent
72a8639b68
commit
027827cc1d
@ -151,7 +151,7 @@ class Sampler(nn.Module):
|
|||||||
dim=-1)
|
dim=-1)
|
||||||
|
|
||||||
# Get with the logprob of the prompt or sampled token.
|
# Get with the logprob of the prompt or sampled token.
|
||||||
token_ids = token_ids.unsqueeze(-1)
|
token_ids = token_ids.unsqueeze(-1).to(torch.long)
|
||||||
token_logprobs = logprobs.gather(-1, token_ids)
|
token_logprobs = logprobs.gather(-1, token_ids)
|
||||||
|
|
||||||
# Compute the ranks of the actual token.
|
# Compute the ranks of the actual token.
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user