mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-08 14:07:04 +08:00
fix
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
parent
d2be62378b
commit
31619ff412
@ -47,7 +47,6 @@ class Sampler(nn.Module):
|
||||
sampling_metadata.seeds,
|
||||
sampling_metadata.pos,
|
||||
)
|
||||
sampled = sampled.unsqueeze(-1)
|
||||
|
||||
logprobs_tensors = None
|
||||
num_logprobs = sampling_metadata.max_num_logprobs
|
||||
@ -63,7 +62,7 @@ class Sampler(nn.Module):
|
||||
# The sampled tokens are expanded to 2D tensor with shape
|
||||
# [num_requests, 1], where each row represents one generated
|
||||
# token per request.
|
||||
sampled_token_ids=sampled.unsqueeze(-1),
|
||||
sampled_token_ids=sampled.view(-1, 1),
|
||||
logprobs_tensors=logprobs_tensors,
|
||||
)
|
||||
return sampler_output
|
||||
@ -220,7 +219,8 @@ def _topk_logprobs_kernel(
|
||||
block = i + tl.arange(0, BLOCK_SIZE)
|
||||
l = tl.load(row_ptr + block, mask=block < vocab_size, other=0.0)
|
||||
e = tl.exp(l - max_val)
|
||||
se += tl.sum(tl.where(block < vocab_size, e, 0.0))
|
||||
e = tl.where(block < vocab_size, e, 0.0)
|
||||
se += tl.sum(e)
|
||||
lse = tl.log(se)
|
||||
|
||||
k_offset = tl.arange(0, PADDED_TOPK)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user