Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
Woosuk Kwon 2025-09-18 16:38:56 -07:00
parent d2be62378b
commit 31619ff412

View File

@ -47,7 +47,6 @@ class Sampler(nn.Module):
sampling_metadata.seeds,
sampling_metadata.pos,
)
sampled = sampled.unsqueeze(-1)
logprobs_tensors = None
num_logprobs = sampling_metadata.max_num_logprobs
@ -63,7 +62,7 @@ class Sampler(nn.Module):
# The sampled tokens are expanded to 2D tensor with shape
# [num_requests, 1], where each row represents one generated
# token per request.
sampled_token_ids=sampled.unsqueeze(-1),
sampled_token_ids=sampled.view(-1, 1),
logprobs_tensors=logprobs_tensors,
)
return sampler_output
@ -220,7 +219,8 @@ def _topk_logprobs_kernel(
block = i + tl.arange(0, BLOCK_SIZE)
l = tl.load(row_ptr + block, mask=block < vocab_size, other=0.0)
e = tl.exp(l - max_val)
se += tl.sum(tl.where(block < vocab_size, e, 0.0))
e = tl.where(block < vocab_size, e, 0.0)
se += tl.sum(e)
lse = tl.log(se)
k_offset = tl.arange(0, PADDED_TOPK)