mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-08 14:17:09 +08:00
fix
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
parent
86dade710d
commit
d2be62378b
@ -199,7 +199,7 @@ def _topk_logprobs_kernel(
|
||||
logits_ptr,
|
||||
logits_stride,
|
||||
topk_ids_ptr,
|
||||
k,
|
||||
topk,
|
||||
vocab_size,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
PADDED_TOPK: tl.constexpr,
|
||||
@ -224,12 +224,12 @@ def _topk_logprobs_kernel(
|
||||
lse = tl.log(se)
|
||||
|
||||
k_offset = tl.arange(0, PADDED_TOPK)
|
||||
k_mask = k_offset < k
|
||||
topk_ids = tl.load(topk_ids_ptr + req_idx * k + k_offset, mask=k_mask)
|
||||
k_mask = k_offset < topk
|
||||
topk_ids = tl.load(topk_ids_ptr + req_idx * topk + k_offset, mask=k_mask)
|
||||
|
||||
l = tl.load(row_ptr + topk_ids, mask=k_mask)
|
||||
o = l - max_val - lse
|
||||
tl.store(output_ptr + req_idx * k + k_offset, o, mask=k_mask)
|
||||
tl.store(output_ptr + req_idx * topk + k_offset, o, mask=k_mask)
|
||||
|
||||
|
||||
def compute_logprobs(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user