mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-13 00:50:11 +08:00
fix
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
parent
efda08481b
commit
86dade710d
@ -246,16 +246,15 @@ def compute_logprobs(
|
|||||||
logprob_token_ids = torch.cat(
|
logprob_token_ids = torch.cat(
|
||||||
(sampled_token_ids.unsqueeze(-1), topk_indices), dim=1)
|
(sampled_token_ids.unsqueeze(-1), topk_indices), dim=1)
|
||||||
|
|
||||||
|
# NOTE(woosuk): Here, to save GPU memory, we do not materialize the full
|
||||||
|
# logprobs tensor. Instead, we only compute and return the logprobs of
|
||||||
|
# the topk + 1 tokens.
|
||||||
logprobs = torch.empty(
|
logprobs = torch.empty(
|
||||||
batch_size,
|
batch_size,
|
||||||
num_logprobs + 1,
|
num_logprobs + 1,
|
||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
device=logits.device,
|
device=logits.device,
|
||||||
)
|
)
|
||||||
|
|
||||||
# NOTE(woosuk): Here, to save GPU memory, we do not materialize the full
|
|
||||||
# logprobs tensor. Instead, we only compute and return the logprobs of
|
|
||||||
# the topk + 1 tokens.
|
|
||||||
BLOCK_SIZE = 1024
|
BLOCK_SIZE = 1024
|
||||||
_topk_logprobs_kernel[(batch_size, )](
|
_topk_logprobs_kernel[(batch_size, )](
|
||||||
logprobs,
|
logprobs,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user