From 866eef50cae7f9a5f10dbbad8cdf34d07c943f1b Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Wed, 24 Sep 2025 15:29:27 +0000
Subject: [PATCH] minor

Signed-off-by: Woosuk Kwon
---
Reviewer notes (placed below the "---" marker, so they are not part of the
commit message):

* Adds compute_topk_logprobs(logits, topk_ids). It allocates a float32
  (batch_size, topk) tensor on logits.device, launches
  _topk_log_softmax_kernel on a (batch_size,) grid to fill it, and returns
  that tensor. batch_size and vocab_size come from logits.shape; topk comes
  from topk_ids.shape[1].
* compute_logprobs() now calls the helper instead of allocating the output
  and launching the kernel inline.
* The column count used to be num_logprobs + 1 and is now taken from
  logprob_token_ids.shape[1]. The two agree only if logprob_token_ids has
  exactly num_logprobs + 1 columns -- presumably it does, but the code that
  builds it is outside this patch; please confirm.
* NOTE(review): this copy was recovered from text whose line breaks and
  indentation had been collapsed. The indentation below was inferred from
  the Python syntax and must be checked against the tree before applying.
  The author e-mail address is also absent from the From: and
  Signed-off-by: lines; `git am` may need it restored.

 vllm/v1/worker/gpu/sampler.py | 39 +++++++++++++++++++++++------------
 1 file changed, 26 insertions(+), 13 deletions(-)

diff --git a/vllm/v1/worker/gpu/sampler.py b/vllm/v1/worker/gpu/sampler.py
index aa5f8446602fa..65aadf9654e92 100644
--- a/vllm/v1/worker/gpu/sampler.py
+++ b/vllm/v1/worker/gpu/sampler.py
@@ -231,6 +231,31 @@ def _topk_log_softmax_kernel(
     tl.store(output_ptr + req_idx * topk + k_offset, o, mask=k_mask)


+def compute_topk_logprobs(
+    logits: torch.Tensor,
+    topk_ids: torch.Tensor,
+) -> torch.Tensor:
+    batch_size, vocab_size = logits.shape
+    topk = topk_ids.shape[1]
+    output = torch.empty(
+        batch_size,
+        topk,
+        dtype=torch.float32,
+        device=logits.device,
+    )
+    _topk_log_softmax_kernel[(batch_size, )](
+        output,
+        logits,
+        logits.stride(0),
+        topk_ids,
+        topk,
+        vocab_size,
+        BLOCK_SIZE=1024,
+        PADDED_TOPK=triton.next_power_of_2(topk),
+    )
+    return output
+
+
 @triton.jit
 def _ranks_kernel(
     output_ptr,
@@ -273,21 +298,9 @@ def compute_logprobs(
     # NOTE(woosuk): Here, to save GPU memory, we do not materialize the full
     # logprobs tensor. Instead, we only compute and return the logprobs of
     # the topk + 1 tokens.
-    logprobs = torch.empty(
-        batch_size,
-        num_logprobs + 1,
-        dtype=torch.float32,
-        device=logits.device,
-    )
-    _topk_log_softmax_kernel[(batch_size, )](
-        logprobs,
+    logprobs = compute_topk_logprobs(
         logits,
-        logits.stride(0),
         logprob_token_ids,
-        num_logprobs + 1,
-        vocab_size,
-        BLOCK_SIZE=1024,
-        PADDED_TOPK=triton.next_power_of_2(num_logprobs + 1),
     )

     token_ranks = torch.empty(