mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-08 12:27:05 +08:00
logprobs
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
parent
31619ff412
commit
b9c74487d2
@ -12,11 +12,11 @@ import torch
|
||||
class LogprobsLists(NamedTuple):
|
||||
|
||||
# [num_reqs, max_num_logprobs + 1]
|
||||
logprob_token_ids: list[list[int]]
|
||||
logprob_token_ids: np.ndarray
|
||||
# [num_reqs, max_num_logprobs + 1]
|
||||
logprobs: list[list[float]]
|
||||
logprobs: np.ndarray
|
||||
# [num_reqs]
|
||||
sampled_token_ranks: list[int]
|
||||
sampled_token_ranks: np.ndarray
|
||||
|
||||
def slice(self, start: int, end: int):
|
||||
return LogprobsLists(
|
||||
@ -37,9 +37,9 @@ class LogprobsTensors(NamedTuple):
|
||||
|
||||
def tolists(self):
|
||||
return LogprobsLists(
|
||||
self.logprob_token_ids.tolist(),
|
||||
self.logprobs.tolist(),
|
||||
self.selected_token_ranks.tolist(),
|
||||
self.logprob_token_ids.cpu().numpy(),
|
||||
self.logprobs.cpu().numpy(),
|
||||
self.selected_token_ranks.cpu().numpy(),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
||||
@ -376,11 +376,14 @@ class GPUModelRunner:
|
||||
sampler_output = self.sample(logits, input_batch)
|
||||
sampled_token_ids_np, num_sampled_tokens = self.postprocess(
|
||||
sampler_output, input_batch)
|
||||
logprobs = None
|
||||
if sampler_output.logprobs_tensors is not None:
|
||||
logprobs = sampler_output.logprobs_tensors.tolists()
|
||||
return ModelRunnerOutput(
|
||||
req_ids=input_batch.req_ids,
|
||||
sampled_token_ids=sampled_token_ids_np,
|
||||
num_sampled_tokens=num_sampled_tokens,
|
||||
logprobs=sampler_output.logprobs_tensors,
|
||||
logprobs=logprobs,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[],
|
||||
kv_connector_output=None,
|
||||
|
||||
@ -193,7 +193,7 @@ def gumbel_sample(
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _topk_logprobs_kernel(
|
||||
def _topk_log_softmax_kernel(
|
||||
output_ptr,
|
||||
logits_ptr,
|
||||
logits_stride,
|
||||
@ -232,6 +232,31 @@ def _topk_logprobs_kernel(
|
||||
tl.store(output_ptr + req_idx * topk + k_offset, o, mask=k_mask)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _ranks_kernel(
|
||||
output_ptr,
|
||||
logits_ptr,
|
||||
logits_stride,
|
||||
token_ids_ptr,
|
||||
vocab_size,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
req_idx = tl.program_id(0)
|
||||
row_ptr = logits_ptr + req_idx * logits_stride
|
||||
|
||||
token_id = tl.load(token_ids_ptr + req_idx)
|
||||
x = tl.load(row_ptr + token_id)
|
||||
|
||||
n = 0
|
||||
for i in range(0, vocab_size, BLOCK_SIZE):
|
||||
block = i + tl.arange(0, BLOCK_SIZE)
|
||||
l = tl.load(row_ptr + block,
|
||||
mask=block < vocab_size,
|
||||
other=float("-inf"))
|
||||
n += tl.sum((l > x).to(tl.int32))
|
||||
tl.store(output_ptr + req_idx, n)
|
||||
|
||||
|
||||
def compute_logprobs(
|
||||
logits: torch.Tensor,
|
||||
num_logprobs: int,
|
||||
@ -255,19 +280,32 @@ def compute_logprobs(
|
||||
dtype=torch.float32,
|
||||
device=logits.device,
|
||||
)
|
||||
BLOCK_SIZE = 1024
|
||||
_topk_logprobs_kernel[(batch_size, )](
|
||||
_topk_log_softmax_kernel[(batch_size, )](
|
||||
logprobs,
|
||||
logits,
|
||||
logits.stride(0),
|
||||
logprob_token_ids,
|
||||
num_logprobs + 1,
|
||||
vocab_size,
|
||||
BLOCK_SIZE=BLOCK_SIZE,
|
||||
BLOCK_SIZE=1024,
|
||||
PADDED_TOPK=triton.next_power_of_2(num_logprobs + 1),
|
||||
)
|
||||
|
||||
token_ranks = torch.empty(
|
||||
batch_size,
|
||||
dtype=torch.int64,
|
||||
device=logits.device,
|
||||
)
|
||||
_ranks_kernel[(batch_size, )](
|
||||
token_ranks,
|
||||
logits,
|
||||
logits.stride(0),
|
||||
sampled_token_ids,
|
||||
vocab_size,
|
||||
BLOCK_SIZE=8192,
|
||||
)
|
||||
return LogprobsTensors(
|
||||
logprob_token_ids=logprob_token_ids,
|
||||
logprobs=logprobs,
|
||||
selected_token_ranks=None, # TODO
|
||||
selected_token_ranks=token_ranks,
|
||||
)
|
||||
|
||||
@ -160,7 +160,7 @@ class RequestState:
|
||||
seeds = self.seeds.copy_np_to_gpu(seeds)
|
||||
|
||||
num_logprobs = self.num_logprobs[idx_mapping]
|
||||
max_num_logprobs = np.max(num_logprobs)
|
||||
max_num_logprobs = int(np.max(num_logprobs))
|
||||
if max_num_logprobs == -1:
|
||||
max_num_logprobs = None
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user