From 027827cc1d8d30d6e0108045ac9f9209a89dc59d Mon Sep 17 00:00:00 2001
From: Chujie Zheng
Date: Wed, 19 Mar 2025 06:57:31 +0800
Subject: [PATCH] fix long dtype in topk sampling (#15049)

---
 vllm/v1/sample/sampler.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm/v1/sample/sampler.py b/vllm/v1/sample/sampler.py
index d91c057083f3..739b09fa2a25 100644
--- a/vllm/v1/sample/sampler.py
+++ b/vllm/v1/sample/sampler.py
@@ -151,7 +151,7 @@ class Sampler(nn.Module):
                                                  dim=-1)
 
         # Get with the logprob of the prompt or sampled token.
-        token_ids = token_ids.unsqueeze(-1)
+        token_ids = token_ids.unsqueeze(-1).to(torch.long)
         token_logprobs = logprobs.gather(-1, token_ids)
 
         # Compute the ranks of the actual token.
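
Note (not part of the patch): torch.gather requires its index tensor to have dtype torch.int64 ("long"); if token_ids arrives as a 32-bit integer tensor, the gather call raises a RuntimeError. The snippet below is a minimal standalone sketch of the failure mode the cast addresses; the shapes and values are illustrative and are not taken from vLLM.

import torch

# Illustrative log-probabilities: 2 sequences over a 4-token vocabulary.
logprobs = torch.log_softmax(torch.randn(2, 4), dim=-1)

# Token ids may arrive as int32 (e.g. converted from an int32 numpy array).
token_ids = torch.tensor([1, 3], dtype=torch.int32)

# torch.gather expects int64 indices, so cast before gathering,
# mirroring the one-line change in the patch above.
token_ids = token_ids.unsqueeze(-1).to(torch.long)
token_logprobs = logprobs.gather(-1, token_ids)

print(token_logprobs.shape)  # torch.Size([2, 1])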