Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
Woosuk Kwon 2025-09-18 16:33:18 -07:00
parent 86dade710d
commit d2be62378b

View File

@ -199,7 +199,7 @@ def _topk_logprobs_kernel(
logits_ptr, logits_ptr,
logits_stride, logits_stride,
topk_ids_ptr, topk_ids_ptr,
k, topk,
vocab_size, vocab_size,
BLOCK_SIZE: tl.constexpr, BLOCK_SIZE: tl.constexpr,
PADDED_TOPK: tl.constexpr, PADDED_TOPK: tl.constexpr,
@ -224,12 +224,12 @@ def _topk_logprobs_kernel(
lse = tl.log(se) lse = tl.log(se)
k_offset = tl.arange(0, PADDED_TOPK) k_offset = tl.arange(0, PADDED_TOPK)
k_mask = k_offset < k k_mask = k_offset < topk
topk_ids = tl.load(topk_ids_ptr + req_idx * k + k_offset, mask=k_mask) topk_ids = tl.load(topk_ids_ptr + req_idx * topk + k_offset, mask=k_mask)
l = tl.load(row_ptr + topk_ids, mask=k_mask) l = tl.load(row_ptr + topk_ids, mask=k_mask)
o = l - max_val - lse o = l - max_val - lse
tl.store(output_ptr + req_idx * k + k_offset, o, mask=k_mask) tl.store(output_ptr + req_idx * topk + k_offset, o, mask=k_mask)
def compute_logprobs( def compute_logprobs(