mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-23 22:17:54 +08:00
logprobs
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
parent
31619ff412
commit
b9c74487d2
@ -12,11 +12,11 @@ import torch
|
|||||||
class LogprobsLists(NamedTuple):
|
class LogprobsLists(NamedTuple):
|
||||||
|
|
||||||
# [num_reqs, max_num_logprobs + 1]
|
# [num_reqs, max_num_logprobs + 1]
|
||||||
logprob_token_ids: list[list[int]]
|
logprob_token_ids: np.ndarray
|
||||||
# [num_reqs, max_num_logprobs + 1]
|
# [num_reqs, max_num_logprobs + 1]
|
||||||
logprobs: list[list[float]]
|
logprobs: np.ndarray
|
||||||
# [num_reqs]
|
# [num_reqs]
|
||||||
sampled_token_ranks: list[int]
|
sampled_token_ranks: np.ndarray
|
||||||
|
|
||||||
def slice(self, start: int, end: int):
|
def slice(self, start: int, end: int):
|
||||||
return LogprobsLists(
|
return LogprobsLists(
|
||||||
@ -37,9 +37,9 @@ class LogprobsTensors(NamedTuple):
|
|||||||
|
|
||||||
def tolists(self):
|
def tolists(self):
|
||||||
return LogprobsLists(
|
return LogprobsLists(
|
||||||
self.logprob_token_ids.tolist(),
|
self.logprob_token_ids.cpu().numpy(),
|
||||||
self.logprobs.tolist(),
|
self.logprobs.cpu().numpy(),
|
||||||
self.selected_token_ranks.tolist(),
|
self.selected_token_ranks.cpu().numpy(),
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
@ -376,11 +376,14 @@ class GPUModelRunner:
|
|||||||
sampler_output = self.sample(logits, input_batch)
|
sampler_output = self.sample(logits, input_batch)
|
||||||
sampled_token_ids_np, num_sampled_tokens = self.postprocess(
|
sampled_token_ids_np, num_sampled_tokens = self.postprocess(
|
||||||
sampler_output, input_batch)
|
sampler_output, input_batch)
|
||||||
|
logprobs = None
|
||||||
|
if sampler_output.logprobs_tensors is not None:
|
||||||
|
logprobs = sampler_output.logprobs_tensors.tolists()
|
||||||
return ModelRunnerOutput(
|
return ModelRunnerOutput(
|
||||||
req_ids=input_batch.req_ids,
|
req_ids=input_batch.req_ids,
|
||||||
sampled_token_ids=sampled_token_ids_np,
|
sampled_token_ids=sampled_token_ids_np,
|
||||||
num_sampled_tokens=num_sampled_tokens,
|
num_sampled_tokens=num_sampled_tokens,
|
||||||
logprobs=sampler_output.logprobs_tensors,
|
logprobs=logprobs,
|
||||||
prompt_logprobs_dict={},
|
prompt_logprobs_dict={},
|
||||||
pooler_output=[],
|
pooler_output=[],
|
||||||
kv_connector_output=None,
|
kv_connector_output=None,
|
||||||
|
|||||||
@ -193,7 +193,7 @@ def gumbel_sample(
|
|||||||
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def _topk_logprobs_kernel(
|
def _topk_log_softmax_kernel(
|
||||||
output_ptr,
|
output_ptr,
|
||||||
logits_ptr,
|
logits_ptr,
|
||||||
logits_stride,
|
logits_stride,
|
||||||
@ -232,6 +232,31 @@ def _topk_logprobs_kernel(
|
|||||||
tl.store(output_ptr + req_idx * topk + k_offset, o, mask=k_mask)
|
tl.store(output_ptr + req_idx * topk + k_offset, o, mask=k_mask)
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def _ranks_kernel(
|
||||||
|
output_ptr,
|
||||||
|
logits_ptr,
|
||||||
|
logits_stride,
|
||||||
|
token_ids_ptr,
|
||||||
|
vocab_size,
|
||||||
|
BLOCK_SIZE: tl.constexpr,
|
||||||
|
):
|
||||||
|
req_idx = tl.program_id(0)
|
||||||
|
row_ptr = logits_ptr + req_idx * logits_stride
|
||||||
|
|
||||||
|
token_id = tl.load(token_ids_ptr + req_idx)
|
||||||
|
x = tl.load(row_ptr + token_id)
|
||||||
|
|
||||||
|
n = 0
|
||||||
|
for i in range(0, vocab_size, BLOCK_SIZE):
|
||||||
|
block = i + tl.arange(0, BLOCK_SIZE)
|
||||||
|
l = tl.load(row_ptr + block,
|
||||||
|
mask=block < vocab_size,
|
||||||
|
other=float("-inf"))
|
||||||
|
n += tl.sum((l > x).to(tl.int32))
|
||||||
|
tl.store(output_ptr + req_idx, n)
|
||||||
|
|
||||||
|
|
||||||
def compute_logprobs(
|
def compute_logprobs(
|
||||||
logits: torch.Tensor,
|
logits: torch.Tensor,
|
||||||
num_logprobs: int,
|
num_logprobs: int,
|
||||||
@ -255,19 +280,32 @@ def compute_logprobs(
|
|||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
device=logits.device,
|
device=logits.device,
|
||||||
)
|
)
|
||||||
BLOCK_SIZE = 1024
|
_topk_log_softmax_kernel[(batch_size, )](
|
||||||
_topk_logprobs_kernel[(batch_size, )](
|
|
||||||
logprobs,
|
logprobs,
|
||||||
logits,
|
logits,
|
||||||
logits.stride(0),
|
logits.stride(0),
|
||||||
logprob_token_ids,
|
logprob_token_ids,
|
||||||
num_logprobs + 1,
|
num_logprobs + 1,
|
||||||
vocab_size,
|
vocab_size,
|
||||||
BLOCK_SIZE=BLOCK_SIZE,
|
BLOCK_SIZE=1024,
|
||||||
PADDED_TOPK=triton.next_power_of_2(num_logprobs + 1),
|
PADDED_TOPK=triton.next_power_of_2(num_logprobs + 1),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
token_ranks = torch.empty(
|
||||||
|
batch_size,
|
||||||
|
dtype=torch.int64,
|
||||||
|
device=logits.device,
|
||||||
|
)
|
||||||
|
_ranks_kernel[(batch_size, )](
|
||||||
|
token_ranks,
|
||||||
|
logits,
|
||||||
|
logits.stride(0),
|
||||||
|
sampled_token_ids,
|
||||||
|
vocab_size,
|
||||||
|
BLOCK_SIZE=8192,
|
||||||
|
)
|
||||||
return LogprobsTensors(
|
return LogprobsTensors(
|
||||||
logprob_token_ids=logprob_token_ids,
|
logprob_token_ids=logprob_token_ids,
|
||||||
logprobs=logprobs,
|
logprobs=logprobs,
|
||||||
selected_token_ranks=None, # TODO
|
selected_token_ranks=token_ranks,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -160,7 +160,7 @@ class RequestState:
|
|||||||
seeds = self.seeds.copy_np_to_gpu(seeds)
|
seeds = self.seeds.copy_np_to_gpu(seeds)
|
||||||
|
|
||||||
num_logprobs = self.num_logprobs[idx_mapping]
|
num_logprobs = self.num_logprobs[idx_mapping]
|
||||||
max_num_logprobs = np.max(num_logprobs)
|
max_num_logprobs = int(np.max(num_logprobs))
|
||||||
if max_num_logprobs == -1:
|
if max_num_logprobs == -1:
|
||||||
max_num_logprobs = None
|
max_num_logprobs = None
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user