Implement topk_logprobs

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Author: Woosuk Kwon
Date:   2025-09-18 16:29:38 -07:00
parent 323a05b3c5
commit 82da219ff9


@@ -47,16 +47,13 @@ class Sampler(nn.Module):
             sampling_metadata.seeds,
             sampling_metadata.pos,
         )
         sampled = sampled.unsqueeze(-1)
         logprobs_tensors = None
         num_logprobs = sampling_metadata.max_num_logprobs
         if num_logprobs is not None:
             assert num_logprobs >= 0
-            # Compute the logprobs.
-            logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float32)
-            # Gather the logprobs of the topk and sampled token.
-            logprobs_tensors = self.gather_logprobs(
-                logprobs,
+            logprobs_tensors = compute_logprobs(
+                logits,
                 num_logprobs,
                 sampled,
             )
@@ -71,36 +68,6 @@ class Sampler(nn.Module):
         )
         return sampler_output
 
-    def gather_logprobs(
-        self,
-        logprobs: torch.Tensor,
-        num_logprobs: int,
-        sampled: torch.Tensor,
-    ) -> LogprobsTensors:
-        sampled = sampled.unsqueeze(-1)
-        sampled_logprobs = logprobs.gather(-1, sampled)
-        sampled_ranks = (logprobs > sampled_logprobs).sum(-1)
-        if num_logprobs == 0:
-            # Return the logprobs of the sampled token.
-            logprobs_tensors = LogprobsTensors(
-                sampled,
-                sampled_logprobs,
-                sampled_ranks,
-            )
-        else:
-            # Return (num_logprobs + 1) logprobs.
-            topk_logprobs, topk_indices = torch.topk(
-                logprobs,
-                num_logprobs,
-                dim=-1,
-            )
-            logprobs_tensors = LogprobsTensors(
-                torch.cat((sampled, topk_indices), dim=1),
-                torch.cat((sampled_logprobs, topk_logprobs), dim=1),
-                sampled_ranks,
-            )
-        return logprobs_tensors
 
 
 @triton.jit
 def _apply_temp_kernel(
@@ -224,3 +191,84 @@ def gumbel_sample(
     )
     # Sample the next token.
     return probs.argmax(dim=-1).view(-1)
+
+
+@triton.jit
+def _topk_logprobs_kernel(
+    output_ptr,
+    logits_ptr,
+    logits_stride,
+    topk_ids_ptr,
+    k,
+    vocab_size,
+    BLOCK_SIZE: tl.constexpr,
+    PADDED_TOPK: tl.constexpr,
+):
+    # One program per request (one row of the logits matrix).
+    req_idx = tl.program_id(0)
+    row_ptr = logits_ptr + req_idx * logits_stride
+
+    # Pass 1 over the vocab: running row max, for numerical stability.
+    max_val = -float("inf")
+    for i in range(0, vocab_size, BLOCK_SIZE):
+        block = i + tl.arange(0, BLOCK_SIZE)
+        l = tl.load(row_ptr + block,
+                    mask=block < vocab_size,
+                    other=-float("inf"))
+        max_val = tl.max(tl.maximum(l, max_val))
+
+    # Pass 2: accumulate sum(exp(logit - max)); masked lanes contribute 0.
+    se = 0.0
+    for i in range(0, vocab_size, BLOCK_SIZE):
+        block = i + tl.arange(0, BLOCK_SIZE)
+        l = tl.load(row_ptr + block, mask=block < vocab_size, other=0.0)
+        e = tl.exp(l - max_val)
+        se += tl.sum(tl.where(block < vocab_size, e, 0.0))
+    lse = tl.log(se)
+
+    # Gather the k selected logits and convert each to a logprob:
+    # logprob = logit - max - log(sum(exp(logit - max))).
+    k_offset = tl.arange(0, PADDED_TOPK)
+    k_mask = k_offset < k
+    topk_ids = tl.load(topk_ids_ptr + req_idx * k + k_offset, mask=k_mask)
+    l = tl.load(row_ptr + topk_ids, mask=k_mask)
+    o = l - max_val - lse
+    tl.store(output_ptr + req_idx * k + k_offset, o, mask=k_mask)
+
+
+def compute_logprobs(
+    logits: torch.Tensor,
+    num_logprobs: int,
+    sampled_token_ids: torch.Tensor,
+) -> LogprobsTensors:
+    assert num_logprobs >= 0
+    batch_size, vocab_size = logits.shape
+    # Column 0 holds the sampled token; columns 1..num_logprobs hold the topk.
+    if num_logprobs == 0:
+        logprob_token_ids = sampled_token_ids.unsqueeze(-1)
+    else:
+        topk_indices = torch.topk(logits, num_logprobs, dim=-1).indices
+        logprob_token_ids = torch.cat(
+            (sampled_token_ids.unsqueeze(-1), topk_indices), dim=1)
+    logprobs = torch.empty(
+        batch_size,
+        num_logprobs + 1,
+        dtype=torch.float32,
+        device=logits.device,
+    )
+    # NOTE(woosuk): Here, to save GPU memory, we do not materialize the full
+    # logprobs tensor. Instead, we only compute the logprobs of the topk + 1
+    # tokens.
+    BLOCK_SIZE = 1024
+    _topk_logprobs_kernel[(batch_size, )](
+        logprobs,
+        logits,
+        logits.stride(0),
+        logprob_token_ids,
+        num_logprobs + 1,
+        vocab_size,
+        BLOCK_SIZE=BLOCK_SIZE,
+        PADDED_TOPK=triton.next_power_of_2(num_logprobs + 1),
+    )
+    return LogprobsTensors(
+        logprob_token_ids=logprob_token_ids,
+        logprobs=logprobs,
+        selected_token_ranks=None,  # TODO
+    )
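
A note on the math: the kernel's two vocab passes compute a per-row maximum and a log-sum-exp, so each stored value equals log_softmax(logits)[row, id] without ever allocating the [batch, vocab] float32 logprobs tensor (roughly 0.5 GB at batch 1024 and a 128K vocabulary). A minimal PyTorch-only sketch that checks the streaming formulation against the full-materialization baseline; shapes and names here are illustrative, not from this repository:

import torch

torch.manual_seed(0)
batch_size, vocab_size, k = 4, 32000, 5
logits = torch.randn(batch_size, vocab_size)

# Baseline: materialize the full [batch, vocab] logprobs tensor.
full = torch.log_softmax(logits.float(), dim=-1)
topk_ids = torch.topk(logits, k, dim=-1).indices
baseline = full.gather(-1, topk_ids)

# Streaming equivalent of the kernel: row max, then log-sum-exp,
# then logprob = logit - max - lse, evaluated only at the k ids.
max_val = logits.float().max(dim=-1, keepdim=True).values
lse = torch.log(torch.exp(logits.float() - max_val).sum(-1, keepdim=True))
streamed = logits.float().gather(-1, topk_ids) - max_val - lse

torch.testing.assert_close(baseline, streamed)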
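
On the remaining TODO: the deleted gather_logprobs computed the ranks as (logprobs > sampled_logprobs).sum(-1). Because log-softmax is a monotonic per-row transform, the same ranks can be recovered from the raw logits, again without materializing the logprobs. A hedged sketch of one way the TODO might be filled; the helper name is hypothetical, not part of this commit:

import torch

def compute_token_ranks(logits: torch.Tensor,
                        sampled_token_ids: torch.Tensor) -> torch.Tensor:
    # Rank = number of tokens with a strictly higher logit; 0 means the
    # sampled token was the argmax. Comparing raw logits gives the same
    # result as comparing logprobs, since log_softmax preserves order
    # within each row, so no float32 [batch, vocab] tensor is needed.
    sampled_logits = logits.gather(-1, sampled_token_ids.unsqueeze(-1))
    return (logits > sampled_logits).sum(dim=-1)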