Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Woosuk Kwon 2025-09-18 17:23:02 -07:00
parent 31619ff412
commit b9c74487d2
4 changed files with 54 additions and 13 deletions

View File

@@ -12,11 +12,11 @@ import torch
 class LogprobsLists(NamedTuple):

     # [num_reqs, max_num_logprobs + 1]
-    logprob_token_ids: list[list[int]]
+    logprob_token_ids: np.ndarray
     # [num_reqs, max_num_logprobs + 1]
-    logprobs: list[list[float]]
+    logprobs: np.ndarray
     # [num_reqs]
-    sampled_token_ranks: list[int]
+    sampled_token_ranks: np.ndarray

     def slice(self, start: int, end: int):
         return LogprobsLists(
@@ -37,9 +37,9 @@ class LogprobsTensors(NamedTuple):

     def tolists(self):
         return LogprobsLists(
-            self.logprob_token_ids.tolist(),
-            self.logprobs.tolist(),
-            self.selected_token_ranks.tolist(),
+            self.logprob_token_ids.cpu().numpy(),
+            self.logprobs.cpu().numpy(),
+            self.selected_token_ranks.cpu().numpy(),
         )

     @staticmethod
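
Why np.ndarray here: Tensor.tolist() boxes every element into a Python object, while Tensor.cpu().numpy() performs a single device-to-host copy into one flat buffer whose rows can be sliced without copying. A minimal sketch of the difference (shapes are illustrative, not from this commit):

import torch

t = torch.randn(64, 6)        # stands in for [num_reqs, max_num_logprobs + 1]
as_lists = t.tolist()         # nested lists: 64 * 6 boxed Python floats
as_array = t.cpu().numpy()    # one contiguous buffer; rows are zero-copy views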

View File

@@ -376,11 +376,14 @@ class GPUModelRunner:
         sampler_output = self.sample(logits, input_batch)
         sampled_token_ids_np, num_sampled_tokens = self.postprocess(
             sampler_output, input_batch)
+        logprobs = None
+        if sampler_output.logprobs_tensors is not None:
+            logprobs = sampler_output.logprobs_tensors.tolists()
         return ModelRunnerOutput(
             req_ids=input_batch.req_ids,
             sampled_token_ids=sampled_token_ids_np,
             num_sampled_tokens=num_sampled_tokens,
-            logprobs=sampler_output.logprobs_tensors,
+            logprobs=logprobs,
             prompt_logprobs_dict={},
             pooler_output=[],
             kv_connector_output=None,
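
With the guard above, the device-to-host transfer inside tolists() runs only for batches that actually requested logprobs. A hypothetical consumer-side sketch (the field names come from the LogprobsLists change in this commit; the helper itself is illustrative):

def read_request_logprobs(output, i):
    # output: ModelRunnerOutput; i: request index within the batch
    if output.logprobs is None:
        return None
    ids = output.logprobs.logprob_token_ids[i]    # np.ndarray, [max_num_logprobs + 1]
    vals = output.logprobs.logprobs[i]
    rank = int(output.logprobs.sampled_token_ranks[i])
    return ids, vals, rank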

View File

@@ -193,7 +193,7 @@ def gumbel_sample(

 @triton.jit
-def _topk_logprobs_kernel(
+def _topk_log_softmax_kernel(
     output_ptr,
     logits_ptr,
     logits_stride,
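
The rename describes what the kernel actually computes: per-row log-softmax values gathered at the requested token ids, that is x[i] - logsumexp(x). A PyTorch reference for the same math (a sketch under assumed shapes: logits [batch, vocab], token_ids [batch, topk] int64):

import torch

def topk_log_softmax_ref(logits: torch.Tensor, token_ids: torch.Tensor) -> torch.Tensor:
    # log_softmax(x)[i] = x[i] - logsumexp(x); gather only the requested ids
    lse = torch.logsumexp(logits.float(), dim=-1, keepdim=True)
    return logits.float().gather(1, token_ids) - lse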
@@ -232,6 +232,31 @@ def _topk_logprobs_kernel(
     tl.store(output_ptr + req_idx * topk + k_offset, o, mask=k_mask)


+@triton.jit
+def _ranks_kernel(
+    output_ptr,
+    logits_ptr,
+    logits_stride,
+    token_ids_ptr,
+    vocab_size,
+    BLOCK_SIZE: tl.constexpr,
+):
+    req_idx = tl.program_id(0)
+    row_ptr = logits_ptr + req_idx * logits_stride
+    token_id = tl.load(token_ids_ptr + req_idx)
+    x = tl.load(row_ptr + token_id)
+    n = 0
+    for i in range(0, vocab_size, BLOCK_SIZE):
+        block = i + tl.arange(0, BLOCK_SIZE)
+        l = tl.load(row_ptr + block,
+                    mask=block < vocab_size,
+                    other=float("-inf"))
+        n += tl.sum((l > x).to(tl.int32))
+    tl.store(output_ptr + req_idx, n)


 def compute_logprobs(
     logits: torch.Tensor,
     num_logprobs: int,
@@ -255,19 +280,32 @@ def compute_logprobs(
         dtype=torch.float32,
         device=logits.device,
     )
-    BLOCK_SIZE = 1024
-    _topk_logprobs_kernel[(batch_size, )](
+    _topk_log_softmax_kernel[(batch_size, )](
         logprobs,
         logits,
         logits.stride(0),
         logprob_token_ids,
         num_logprobs + 1,
         vocab_size,
-        BLOCK_SIZE=BLOCK_SIZE,
+        BLOCK_SIZE=1024,
+        PADDED_TOPK=triton.next_power_of_2(num_logprobs + 1),
     )
+    token_ranks = torch.empty(
+        batch_size,
+        dtype=torch.int64,
+        device=logits.device,
+    )
+    _ranks_kernel[(batch_size, )](
+        token_ranks,
+        logits,
+        logits.stride(0),
+        sampled_token_ids,
+        vocab_size,
+        BLOCK_SIZE=8192,
+    )

     return LogprobsTensors(
         logprob_token_ids=logprob_token_ids,
         logprobs=logprobs,
-        selected_token_ranks=None,  # TODO
+        selected_token_ranks=token_ranks,
     )
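
Two notes on the new code above. PADDED_TOPK goes through triton.next_power_of_2 because tl.arange extents in Triton must be powers of two; the kernel masks off the padding lanes. And _ranks_kernel counts, per request, how many vocabulary entries have a strictly larger logit than the sampled token, so rank 0 means the sampled token was the argmax. An equivalent PyTorch reference (a sketch, not part of the commit):

import torch

def token_ranks_ref(logits: torch.Tensor, sampled_token_ids: torch.Tensor) -> torch.Tensor:
    # logits: [batch, vocab]; sampled_token_ids: [batch], int64
    x = logits.gather(1, sampled_token_ids.view(-1, 1))  # logit of each sampled token
    return (logits > x).sum(dim=-1)                      # strictly-greater count = rank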

View File

@@ -160,7 +160,7 @@ class RequestState:
         seeds = self.seeds.copy_np_to_gpu(seeds)
         num_logprobs = self.num_logprobs[idx_mapping]
-        max_num_logprobs = np.max(num_logprobs)
+        max_num_logprobs = int(np.max(num_logprobs))
         if max_num_logprobs == -1:
             max_num_logprobs = None
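
The int(...) cast matters because np.max returns a NumPy scalar (np.int64 here) rather than a Python int, and a NumPy scalar can leak into downstream code that expects a plain int or None. Illustrative:

import numpy as np

m = np.max(np.array([-1, 3, 7]))
print(type(m))        # <class 'numpy.int64'>
print(type(int(m)))   # <class 'int'>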