From 05ccd0aa35581605017734c4ef36e058ddc58381 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Tue, 18 Mar 2025 23:52:19 -0700 Subject: [PATCH] [V1] Ensure using int64 for sampled token ids (#15065) Signed-off-by: Woosuk Kwon --- vllm/v1/sample/sampler.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/vllm/v1/sample/sampler.py b/vllm/v1/sample/sampler.py index 739b09fa2a258..abff7c1c2652b 100644 --- a/vllm/v1/sample/sampler.py +++ b/vllm/v1/sample/sampler.py @@ -47,6 +47,11 @@ class Sampler(nn.Module): logits = self.apply_penalties(logits, sampling_metadata) # Sample the next token. sampled = self.sample(logits, sampling_metadata) + # Convert sampled token ids to int64 (long) type to ensure compatibility + # with subsequent operations that may use these values as indices. + # This conversion is necessary because FlashInfer sampling operations + # return int32 (while PyTorch argmax and topk return int64). + sampled = sampled.long() # Gather the logprobs of the topk and sampled token (if requested). # Get logprobs and rank tensors (if requested) @@ -139,19 +144,21 @@ class Sampler(nn.Module): or sampled tokens (if sampled logprobs); 1D token ID tensor with (num tokens) elements + Must be int64. Returns: Top-k int indices tensor, (num tokens) x (num_logprobs + 1) Top-k float logprobs tensor, (num tokens) x (num_logprobs + 1) Sampled token rank tensor, (num tokens) """ + assert token_ids.dtype == torch.int64 # Find the topK values. topk_logprobs, topk_indices = torch.topk(logprobs, num_logprobs, dim=-1) # Get with the logprob of the prompt or sampled token. - token_ids = token_ids.unsqueeze(-1).to(torch.long) + token_ids = token_ids.unsqueeze(-1) token_logprobs = logprobs.gather(-1, token_ids) # Compute the ranks of the actual token.