diff --git a/vllm/v1/worker/gpu/sampler.py b/vllm/v1/worker/gpu/sampler.py
index d9f74b6f09024..eaf43d4335843 100644
--- a/vllm/v1/worker/gpu/sampler.py
+++ b/vllm/v1/worker/gpu/sampler.py
@@ -246,16 +246,15 @@ def compute_logprobs(
 
     logprob_token_ids = torch.cat(
         (sampled_token_ids.unsqueeze(-1), topk_indices), dim=1)
+    # NOTE(woosuk): Here, to save GPU memory, we do not materialize the full
+    # logprobs tensor. Instead, we only compute and return the logprobs of
+    # the topk + 1 tokens.
     logprobs = torch.empty(
         batch_size,
         num_logprobs + 1,
         dtype=torch.float32,
         device=logits.device,
     )
-
-    # NOTE(woosuk): Here, to save GPU memory, we do not materialize the full
-    # logprobs tensor. Instead, we only compute and return the logprobs of
-    # the topk + 1 tokens.
     BLOCK_SIZE = 1024
     _topk_logprobs_kernel[(batch_size, )](
         logprobs,