mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-06 04:22:13 +08:00
minor
Signed-off-by: Woosuk Kwon <woosuk@thinkingmachines.ai>
This commit is contained in:
parent
ad2cf805ad
commit
866eef50ca
@ -231,6 +231,31 @@ def _topk_log_softmax_kernel(
|
|||||||
tl.store(output_ptr + req_idx * topk + k_offset, o, mask=k_mask)
|
tl.store(output_ptr + req_idx * topk + k_offset, o, mask=k_mask)
|
||||||
|
|
||||||
|
|
||||||
|
def compute_topk_logprobs(
|
||||||
|
logits: torch.Tensor,
|
||||||
|
topk_ids: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
batch_size, vocab_size = logits.shape
|
||||||
|
topk = topk_ids.shape[1]
|
||||||
|
output = torch.empty(
|
||||||
|
batch_size,
|
||||||
|
topk,
|
||||||
|
dtype=torch.float32,
|
||||||
|
device=logits.device,
|
||||||
|
)
|
||||||
|
_topk_log_softmax_kernel[(batch_size, )](
|
||||||
|
output,
|
||||||
|
logits,
|
||||||
|
logits.stride(0),
|
||||||
|
topk_ids,
|
||||||
|
topk,
|
||||||
|
vocab_size,
|
||||||
|
BLOCK_SIZE=1024,
|
||||||
|
PADDED_TOPK=triton.next_power_of_2(topk),
|
||||||
|
)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def _ranks_kernel(
|
def _ranks_kernel(
|
||||||
output_ptr,
|
output_ptr,
|
||||||
@ -273,21 +298,9 @@ def compute_logprobs(
|
|||||||
# NOTE(woosuk): Here, to save GPU memory, we do not materialize the full
|
# NOTE(woosuk): Here, to save GPU memory, we do not materialize the full
|
||||||
# logprobs tensor. Instead, we only compute and return the logprobs of
|
# logprobs tensor. Instead, we only compute and return the logprobs of
|
||||||
# the topk + 1 tokens.
|
# the topk + 1 tokens.
|
||||||
logprobs = torch.empty(
|
logprobs = compute_topk_logprobs(
|
||||||
batch_size,
|
|
||||||
num_logprobs + 1,
|
|
||||||
dtype=torch.float32,
|
|
||||||
device=logits.device,
|
|
||||||
)
|
|
||||||
_topk_log_softmax_kernel[(batch_size, )](
|
|
||||||
logprobs,
|
|
||||||
logits,
|
logits,
|
||||||
logits.stride(0),
|
|
||||||
logprob_token_ids,
|
logprob_token_ids,
|
||||||
num_logprobs + 1,
|
|
||||||
vocab_size,
|
|
||||||
BLOCK_SIZE=1024,
|
|
||||||
PADDED_TOPK=triton.next_power_of_2(num_logprobs + 1),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
token_ranks = torch.empty(
|
token_ranks = torch.empty(
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user