mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-05 03:27:02 +08:00
[V1] Ensure using int64 for sampled token ids (#15065)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
parent
f690372b68
commit
05ccd0aa35
@ -47,6 +47,11 @@ class Sampler(nn.Module):
|
||||
logits = self.apply_penalties(logits, sampling_metadata)
|
||||
# Sample the next token.
|
||||
sampled = self.sample(logits, sampling_metadata)
|
||||
# Convert sampled token ids to int64 (long) type to ensure compatibility
|
||||
# with subsequent operations that may use these values as indices.
|
||||
# This conversion is necessary because FlashInfer sampling operations
|
||||
# return int32 (while PyTorch argmax and topk return int64).
|
||||
sampled = sampled.long()
|
||||
|
||||
# Gather the logprobs of the topk and sampled token (if requested).
|
||||
# Get logprobs and rank tensors (if requested)
|
||||
@ -139,19 +144,21 @@ class Sampler(nn.Module):
|
||||
or sampled tokens (if sampled
|
||||
logprobs); 1D token ID tensor
|
||||
with (num tokens) elements
|
||||
Must be int64.
|
||||
|
||||
Returns:
|
||||
Top-k int indices tensor, (num tokens) x (num_logprobs + 1)
|
||||
Top-k float logprobs tensor, (num tokens) x (num_logprobs + 1)
|
||||
Sampled token rank tensor, (num tokens)
|
||||
"""
|
||||
assert token_ids.dtype == torch.int64
|
||||
# Find the topK values.
|
||||
topk_logprobs, topk_indices = torch.topk(logprobs,
|
||||
num_logprobs,
|
||||
dim=-1)
|
||||
|
||||
# Get with the logprob of the prompt or sampled token.
|
||||
token_ids = token_ids.unsqueeze(-1).to(torch.long)
|
||||
token_ids = token_ids.unsqueeze(-1)
|
||||
token_logprobs = logprobs.gather(-1, token_ids)
|
||||
|
||||
# Compute the ranks of the actual token.
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user