Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
Woosuk Kwon 2025-09-18 16:32:00 -07:00
parent efda08481b
commit 86dade710d

View File

@@ -246,16 +246,15 @@ def compute_logprobs(
logprob_token_ids = torch.cat(
(sampled_token_ids.unsqueeze(-1), topk_indices), dim=1)
# NOTE(woosuk): Here, to save GPU memory, we do not materialize the full
# logprobs tensor. Instead, we only compute and return the logprobs of
# the topk + 1 tokens.
logprobs = torch.empty(
batch_size,
num_logprobs + 1,
dtype=torch.float32,
device=logits.device,
)
# NOTE(woosuk): Here, to save GPU memory, we do not materialize the full
# logprobs tensor. Instead, we only compute and return the logprobs of
# the topk + 1 tokens.
BLOCK_SIZE = 1024
_topk_logprobs_kernel[(batch_size, )](
logprobs,