mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-24 00:04:32 +08:00
fix
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
parent
86dade710d
commit
d2be62378b
@ -199,7 +199,7 @@ def _topk_logprobs_kernel(
|
|||||||
logits_ptr,
|
logits_ptr,
|
||||||
logits_stride,
|
logits_stride,
|
||||||
topk_ids_ptr,
|
topk_ids_ptr,
|
||||||
k,
|
topk,
|
||||||
vocab_size,
|
vocab_size,
|
||||||
BLOCK_SIZE: tl.constexpr,
|
BLOCK_SIZE: tl.constexpr,
|
||||||
PADDED_TOPK: tl.constexpr,
|
PADDED_TOPK: tl.constexpr,
|
||||||
@ -224,12 +224,12 @@ def _topk_logprobs_kernel(
|
|||||||
lse = tl.log(se)
|
lse = tl.log(se)
|
||||||
|
|
||||||
k_offset = tl.arange(0, PADDED_TOPK)
|
k_offset = tl.arange(0, PADDED_TOPK)
|
||||||
k_mask = k_offset < k
|
k_mask = k_offset < topk
|
||||||
topk_ids = tl.load(topk_ids_ptr + req_idx * k + k_offset, mask=k_mask)
|
topk_ids = tl.load(topk_ids_ptr + req_idx * topk + k_offset, mask=k_mask)
|
||||||
|
|
||||||
l = tl.load(row_ptr + topk_ids, mask=k_mask)
|
l = tl.load(row_ptr + topk_ids, mask=k_mask)
|
||||||
o = l - max_val - lse
|
o = l - max_val - lse
|
||||||
tl.store(output_ptr + req_idx * k + k_offset, o, mask=k_mask)
|
tl.store(output_ptr + req_idx * topk + k_offset, o, mask=k_mask)
|
||||||
|
|
||||||
|
|
||||||
def compute_logprobs(
|
def compute_logprobs(
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user