Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
Woosuk Kwon 2025-09-18 16:33:18 -07:00
parent 86dade710d
commit d2be62378b

View File

@ -199,7 +199,7 @@ def _topk_logprobs_kernel(
logits_ptr,
logits_stride,
topk_ids_ptr,
k,
topk,
vocab_size,
BLOCK_SIZE: tl.constexpr,
PADDED_TOPK: tl.constexpr,
@ -224,12 +224,12 @@ def _topk_logprobs_kernel(
lse = tl.log(se)
k_offset = tl.arange(0, PADDED_TOPK)
k_mask = k_offset < k
topk_ids = tl.load(topk_ids_ptr + req_idx * k + k_offset, mask=k_mask)
k_mask = k_offset < topk
topk_ids = tl.load(topk_ids_ptr + req_idx * topk + k_offset, mask=k_mask)
l = tl.load(row_ptr + topk_ids, mask=k_mask)
o = l - max_val - lse
tl.store(output_ptr + req_idx * k + k_offset, o, mask=k_mask)
tl.store(output_ptr + req_idx * topk + k_offset, o, mask=k_mask)
def compute_logprobs(