[V1] Ensure using int64 for sampled token ids (#15065)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
Woosuk Kwon 2025-03-18 23:52:19 -07:00 committed by GitHub
parent f690372b68
commit 05ccd0aa35
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -47,6 +47,11 @@ class Sampler(nn.Module):
logits = self.apply_penalties(logits, sampling_metadata)
# Sample the next token.
sampled = self.sample(logits, sampling_metadata)
# Convert sampled token ids to int64 (long) type to ensure compatibility
# with subsequent operations that may use these values as indices.
# This conversion is necessary because FlashInfer sampling operations
# return int32 (while PyTorch argmax and topk return int64).
sampled = sampled.long()
# Gather the logprobs of the topk and sampled token (if requested).
# Get logprobs and rank tensors (if requested)
@ -139,19 +144,21 @@ class Sampler(nn.Module):
or sampled tokens (if sampled
logprobs); 1D token ID tensor
with (num tokens) elements
Must be int64.
Returns:
Top-k int indices tensor, (num tokens) x (num_logprobs + 1)
Top-k float logprobs tensor, (num tokens) x (num_logprobs + 1)
Sampled token rank tensor, (num tokens)
"""
assert token_ids.dtype == torch.int64
# Find the topK values.
topk_logprobs, topk_indices = torch.topk(logprobs,
num_logprobs,
dim=-1)
# Get the logprob of the prompt or sampled token.
token_ids = token_ids.unsqueeze(-1).to(torch.long)
token_ids = token_ids.unsqueeze(-1)
token_logprobs = logprobs.gather(-1, token_ids)
# Compute the ranks of the actual token.