diff --git a/vllm/v1/sample/sampler.py b/vllm/v1/sample/sampler.py index d91c057083f3..739b09fa2a25 100644 --- a/vllm/v1/sample/sampler.py +++ b/vllm/v1/sample/sampler.py @@ -151,7 +151,7 @@ class Sampler(nn.Module): dim=-1) # Get with the logprob of the prompt or sampled token. - token_ids = token_ids.unsqueeze(-1) + token_ids = token_ids.unsqueeze(-1).to(torch.long) token_logprobs = logprobs.gather(-1, token_ids) # Compute the ranks of the actual token.