From 82da219ff9630b5521af0443e138d43c23cec7c6 Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Thu, 18 Sep 2025 16:29:38 -0700
Subject: [PATCH] Implement topk_logprobs

Signed-off-by: Woosuk Kwon
---
 vllm/v1/worker/gpu/sampler.py | 120 ++++++++++++++++++++++++----------
 1 file changed, 84 insertions(+), 36 deletions(-)

diff --git a/vllm/v1/worker/gpu/sampler.py b/vllm/v1/worker/gpu/sampler.py
index 095aa233a4d2c..cee666df22a52 100644
--- a/vllm/v1/worker/gpu/sampler.py
+++ b/vllm/v1/worker/gpu/sampler.py
@@ -47,16 +47,13 @@ class Sampler(nn.Module):
             sampling_metadata.seeds,
             sampling_metadata.pos,
         )
+        sampled = sampled.unsqueeze(-1)
 
         logprobs_tensors = None
         num_logprobs = sampling_metadata.max_num_logprobs
         if num_logprobs is not None:
-            assert num_logprobs >= 0
-            # Compute the logprobs.
-            logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float32)
-            # Gather the logprobs of the topk and sampled token.
-            logprobs_tensors = self.gather_logprobs(
-                logprobs,
+            logprobs_tensors = compute_logprobs(
+                logits,
                 num_logprobs,
                 sampled,
             )
@@ -71,36 +68,6 @@ class Sampler(nn.Module):
         )
         return sampler_output
 
-    def gather_logprobs(
-        self,
-        logprobs: torch.Tensor,
-        num_logprobs: int,
-        sampled: torch.Tensor,
-    ) -> LogprobsTensors:
-        sampled = sampled.unsqueeze(-1)
-        sampled_logprobs = logprobs.gather(-1, sampled)
-        sampled_ranks = (logprobs > sampled_logprobs).sum(-1)
-        if num_logprobs == 0:
-            # Return the logprobs of the sampled token.
-            logprobs_tensors = LogprobsTensors(
-                sampled,
-                sampled_logprobs,
-                sampled_ranks,
-            )
-        else:
-            # Return (num_logprobs + 1) logprobs.
-            topk_logprobs, topk_indices = torch.topk(
-                logprobs,
-                num_logprobs,
-                dim=-1,
-            )
-            logprobs_tensors = LogprobsTensors(
-                torch.cat((sampled, topk_indices), dim=1),
-                torch.cat((sampled_logprobs, topk_logprobs), dim=1),
-                sampled_ranks,
-            )
-        return logprobs_tensors
-
 
 @triton.jit
 def _apply_temp_kernel(
@@ -224,3 +191,84 @@ def gumbel_sample(
     )
     # Sample the next token.
     return probs.argmax(dim=-1).view(-1)
+
+
+@triton.jit
+def _topk_logprobs_kernel(
+    output_ptr,
+    logits_ptr,
+    logits_stride,
+    topk_ids_ptr,
+    k,
+    vocab_size,
+    BLOCK_SIZE: tl.constexpr,
+    PADDED_TOPK: tl.constexpr,
+):
+    req_idx = tl.program_id(0)
+    row_ptr = logits_ptr + req_idx * logits_stride
+
+    max_val = -float("inf")
+    for i in range(0, vocab_size, BLOCK_SIZE):
+        block = i + tl.arange(0, BLOCK_SIZE)
+        l = tl.load(row_ptr + block,
+                    mask=block < vocab_size,
+                    other=-float("inf"))
+        max_val = tl.max(tl.maximum(l, max_val))
+
+    se = 0.0
+    for i in range(0, vocab_size, BLOCK_SIZE):
+        block = i + tl.arange(0, BLOCK_SIZE)
+        l = tl.load(row_ptr + block, mask=block < vocab_size, other=0.0)
+        e = tl.exp(l - max_val)
+        se += tl.sum(tl.where(block < vocab_size, e, 0.0))
+    lse = tl.log(se)
+
+    k_offset = tl.arange(0, PADDED_TOPK)
+    k_mask = k_offset < k
+    topk_ids = tl.load(topk_ids_ptr + req_idx * k + k_offset, mask=k_mask)
+
+    l = tl.load(row_ptr + topk_ids, mask=k_mask)
+    o = l - max_val - lse
+    tl.store(output_ptr + req_idx * k + k_offset, o, mask=k_mask)
+
+
+def compute_logprobs(
+    logits: torch.Tensor,
+    num_logprobs: int,
+    sampled_token_ids: torch.Tensor,
+) -> LogprobsTensors:
+    assert num_logprobs >= 0
+    batch_size, vocab_size = logits.shape
+    if num_logprobs == 0:
+        logprob_token_ids = sampled_token_ids
+    else:
+        topk_indices = torch.topk(logits, num_logprobs, dim=-1).indices
+        logprob_token_ids = torch.cat(
+            (sampled_token_ids, topk_indices), dim=1)
+
+    logprobs = torch.empty(
+        batch_size,
+        num_logprobs + 1,
+        dtype=torch.float32,
+        device=logits.device,
+    )
+
+    # NOTE(woosuk): Here, to save GPU memory, we do not materialize the full
+    # logprobs tensor. Instead, we only compute the logprobs of the topk + 1
+    # tokens.
+    BLOCK_SIZE = 1024
+    _topk_logprobs_kernel[(batch_size, )](
+        logprobs,
+        logits,
+        logits.stride(0),
+        logprob_token_ids,
+        num_logprobs + 1,
+        vocab_size,
+        BLOCK_SIZE=BLOCK_SIZE,
+        PADDED_TOPK=triton.next_power_of_2(num_logprobs + 1),
+    )
+    return LogprobsTensors(
+        logprob_token_ids=logprob_token_ids,
+        logprobs=logprobs,
+        selected_token_ranks=None,  # TODO
+    )
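
A note on the new kernel: it computes log-softmax values only for the
k = num_logprobs + 1 requested token ids per row, using a two-pass
streaming log-sum-exp over the vocabulary (one pass for the running max,
one for the sum of exponentials), so the full (batch_size, vocab_size)
float32 logprobs tensor is never allocated. It can be sanity-checked
against the naive full-softmax path this patch removes. A minimal sketch,
assuming contiguous logits of shape (batch_size, vocab_size) and token ids
of shape (batch_size, k); the name ref_compute_logprobs is illustrative,
not part of the patch:

    import torch

    def ref_compute_logprobs(logits: torch.Tensor,
                             token_ids: torch.Tensor) -> torch.Tensor:
        # Materialize the full (batch_size, vocab_size) logprobs tensor,
        # which is exactly the allocation the Triton kernel avoids, then
        # gather the entries for the requested token ids.
        logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float32)
        return logprobs.gather(-1, token_ids)

    torch.manual_seed(0)
    logits = torch.randn(4, 32000)
    token_ids = torch.topk(logits, 9, dim=-1).indices  # k = num_logprobs + 1
    expected = ref_compute_logprobs(logits, token_ids)
    # On a CUDA device, compare with the kernel output:
    #   torch.testing.assert_close(kernel_out, expected)

The same reference also shows where the still-TODO selected_token_ranks
could come from: the removed gather_logprobs computed them as
(logprobs > sampled_logprobs).sum(-1).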