Implement topk_logprobs
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
parent 323a05b3c5
commit 82da219ff9
@@ -47,16 +47,13 @@ class Sampler(nn.Module):
             sampling_metadata.seeds,
             sampling_metadata.pos,
         )
         sampled = sampled.unsqueeze(-1)

         logprobs_tensors = None
         num_logprobs = sampling_metadata.max_num_logprobs
         if num_logprobs is not None:
             assert num_logprobs >= 0
-            # Compute the logprobs.
-            logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float32)
-            # Gather the logprobs of the topk and sampled token.
-            logprobs_tensors = self.gather_logprobs(
-                logprobs,
+            logprobs_tensors = compute_logprobs(
+                logits,
                 num_logprobs,
                 sampled,
             )
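Note: the removed path materializes the full [batch_size, vocab_size] log-softmax in float32 before gathering a handful of entries. As a rough, purely illustrative sizing (not a figure from this commit): at batch 256 and a 128K-token vocabulary, that intermediate alone is 256 * 131072 * 4 bytes ≈ 134 MB per decoding step, whereas the fused compute_logprobs path added below writes only [batch_size, num_logprobs + 1] values.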
@@ -71,36 +68,6 @@ class Sampler(nn.Module):
         )
         return sampler_output
-
-    def gather_logprobs(
-        self,
-        logprobs: torch.Tensor,
-        num_logprobs: int,
-        sampled: torch.Tensor,
-    ) -> LogprobsTensors:
-        sampled = sampled.unsqueeze(-1)
-        sampled_logprobs = logprobs.gather(-1, sampled)
-        sampled_ranks = (logprobs > sampled_logprobs).sum(-1)
-        if num_logprobs == 0:
-            # Return the logprobs of the sampled token.
-            logprobs_tensors = LogprobsTensors(
-                sampled,
-                sampled_logprobs,
-                sampled_ranks,
-            )
-        else:
-            # Return (num_logprobs + 1) logprobs.
-            topk_logprobs, topk_indices = torch.topk(
-                logprobs,
-                num_logprobs,
-                dim=-1,
-            )
-            logprobs_tensors = LogprobsTensors(
-                torch.cat((sampled, topk_indices), dim=1),
-                torch.cat((sampled_logprobs, topk_logprobs), dim=1),
-                sampled_ranks,
-            )
-        return logprobs_tensors


 @triton.jit
 def _apply_temp_kernel(
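For reference, a minimal self-contained sketch of the semantics the removed gather_logprobs implemented, and which compute_logprobs must preserve; the toy tensors and printed value here are illustrative assumptions, not part of the commit:

    import torch

    # [batch=1, vocab=4] log-softmax; token 0 has the highest logprob.
    logprobs = torch.log_softmax(
        torch.tensor([[2.0, 0.5, 1.0, -1.0]]), dim=-1)
    sampled = torch.tensor([2])  # token id 2 was sampled

    sampled_col = sampled.unsqueeze(-1)                  # [1, 1]
    sampled_logprobs = logprobs.gather(-1, sampled_col)  # logprob of token 2
    # Rank = count of tokens with a strictly higher logprob, so the argmax
    # token gets rank 0. Token 2 is beaten only by token 0: rank 1.
    sampled_ranks = (logprobs > sampled_logprobs).sum(-1)
    print(sampled_ranks)  # tensor([1])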
@@ -224,3 +191,84 @@ def gumbel_sample(
     )
     # Sample the next token.
     return probs.argmax(dim=-1).view(-1)
+
+
+@triton.jit
+def _topk_logprobs_kernel(
+    output_ptr,
+    logits_ptr,
+    logits_stride,
+    topk_ids_ptr,
+    k,
+    vocab_size,
+    BLOCK_SIZE: tl.constexpr,
+    PADDED_TOPK: tl.constexpr,
+):
+    req_idx = tl.program_id(0)
+    row_ptr = logits_ptr + req_idx * logits_stride
+
+    max_val = -float("inf")
+    for i in range(0, vocab_size, BLOCK_SIZE):
+        block = i + tl.arange(0, BLOCK_SIZE)
+        l = tl.load(row_ptr + block,
+                    mask=block < vocab_size,
+                    other=-float("inf"))
+        max_val = tl.max(tl.maximum(l, max_val))
+
+    se = 0.0
+    for i in range(0, vocab_size, BLOCK_SIZE):
+        block = i + tl.arange(0, BLOCK_SIZE)
+        l = tl.load(row_ptr + block, mask=block < vocab_size, other=0.0)
+        e = tl.exp(l - max_val)
+        se += tl.sum(tl.where(block < vocab_size, e, 0.0))
+    lse = tl.log(se)
+
+    k_offset = tl.arange(0, PADDED_TOPK)
+    k_mask = k_offset < k
+    topk_ids = tl.load(topk_ids_ptr + req_idx * k + k_offset, mask=k_mask)
+
+    l = tl.load(row_ptr + topk_ids, mask=k_mask)
+    o = l - max_val - lse
+    tl.store(output_ptr + req_idx * k + k_offset, o, mask=k_mask)
+
+
+def compute_logprobs(
+    logits: torch.Tensor,
+    num_logprobs: int,
+    sampled_token_ids: torch.Tensor,
+) -> LogprobsTensors:
+    assert num_logprobs >= 0
+    batch_size, vocab_size = logits.shape
+    if num_logprobs == 0:
+        logprob_token_ids = sampled_token_ids.unsqueeze(-1)
+    else:
+        topk_indices = torch.topk(logits, num_logprobs, dim=-1).indices
+        logprob_token_ids = torch.cat(
+            (sampled_token_ids.unsqueeze(-1), topk_indices), dim=1)
+
+    logprobs = torch.empty(
+        batch_size,
+        num_logprobs + 1,
+        dtype=torch.float32,
+        device=logits.device,
+    )
+
+    # NOTE(woosuk): Here, to save GPU memory, we do not materialize the full
+    # logprobs tensor. Instead, we only compute the logprobs of the topk + 1
+    # tokens.
+    BLOCK_SIZE = 1024
+    _topk_logprobs_kernel[(batch_size, )](
+        logprobs,
+        logits,
+        logits.stride(0),
+        logprob_token_ids,
+        num_logprobs + 1,
+        vocab_size,
+        BLOCK_SIZE=BLOCK_SIZE,
+        PADDED_TOPK=triton.next_power_of_2(num_logprobs + 1),
+    )
+    return LogprobsTensors(
+        logprob_token_ids=logprob_token_ids,
+        logprobs=logprobs,
+        selected_token_ranks=None,  # TODO
+    )
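Per request row, the kernel computes log_softmax(logits)[t] = logits[t] - max - log(sum(exp(logits - max))) for only the k requested token ids, streaming over the vocabulary in BLOCK_SIZE chunks so the full [batch_size, vocab_size] log-softmax is never stored. A dense PyTorch reference for the same math (the function name and shapes are assumptions for illustration):

    import torch

    def topk_logprobs_reference(logits: torch.Tensor,
                                token_ids: torch.Tensor) -> torch.Tensor:
        # logits: [batch, vocab] float; token_ids: [batch, k] int64
        max_val = logits.max(dim=-1, keepdim=True).values
        lse = (logits - max_val).exp().sum(dim=-1, keepdim=True).log()
        return (logits - max_val - lse).gather(-1, token_ids)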
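A minimal usage sketch under stated assumptions: a CUDA device (Triton kernels do not run on CPU), illustrative batch/vocab sizes, and a sanity check against the dense log-softmax; the field names follow the LogprobsTensors construction above:

    import torch

    logits = torch.randn(8, 32000, device="cuda")            # [batch, vocab]
    sampled = torch.randint(0, 32000, (8,), device="cuda")   # [batch]

    out = compute_logprobs(logits, num_logprobs=5, sampled_token_ids=sampled)
    # out.logprob_token_ids: [8, 6] -- sampled token id first, then top-5 ids.
    # out.logprobs:          [8, 6] float32, matching the dense computation:
    ref = torch.log_softmax(logits.float(), dim=-1)
    assert torch.allclose(
        out.logprobs, ref.gather(-1, out.logprob_token_ids), atol=1e-5)
    # out.selected_token_ranks is still None in this commit (see TODO above).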