diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py
index 008337be4ebf8..a59a1e7e6afcf 100644
--- a/vllm/v1/outputs.py
+++ b/vllm/v1/outputs.py
@@ -12,11 +12,11 @@ import torch
 class LogprobsLists(NamedTuple):

     # [num_reqs, max_num_logprobs + 1]
-    logprob_token_ids: list[list[int]]
+    logprob_token_ids: np.ndarray
     # [num_reqs, max_num_logprobs + 1]
-    logprobs: list[list[float]]
+    logprobs: np.ndarray
     # [num_reqs]
-    sampled_token_ranks: list[int]
+    sampled_token_ranks: np.ndarray

     def slice(self, start: int, end: int):
         return LogprobsLists(
@@ -37,9 +37,9 @@ class LogprobsTensors(NamedTuple):

     def tolists(self):
         return LogprobsLists(
-            self.logprob_token_ids.tolist(),
-            self.logprobs.tolist(),
-            self.selected_token_ranks.tolist(),
+            self.logprob_token_ids.cpu().numpy(),
+            self.logprobs.cpu().numpy(),
+            self.selected_token_ranks.cpu().numpy(),
         )

     @staticmethod
diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py
index d591b3285297e..8c5ad14c40758 100644
--- a/vllm/v1/worker/gpu/model_runner.py
+++ b/vllm/v1/worker/gpu/model_runner.py
@@ -376,11 +376,14 @@ class GPUModelRunner:
         sampler_output = self.sample(logits, input_batch)
         sampled_token_ids_np, num_sampled_tokens = self.postprocess(
             sampler_output, input_batch)
+        logprobs = None
+        if sampler_output.logprobs_tensors is not None:
+            logprobs = sampler_output.logprobs_tensors.tolists()
         return ModelRunnerOutput(
             req_ids=input_batch.req_ids,
             sampled_token_ids=sampled_token_ids_np,
             num_sampled_tokens=num_sampled_tokens,
-            logprobs=sampler_output.logprobs_tensors,
+            logprobs=logprobs,
             prompt_logprobs_dict={},
             pooler_output=[],
             kv_connector_output=None,
diff --git a/vllm/v1/worker/gpu/sampler.py b/vllm/v1/worker/gpu/sampler.py
index c7c7c3c32ad2f..6a2856b1b1255 100644
--- a/vllm/v1/worker/gpu/sampler.py
+++ b/vllm/v1/worker/gpu/sampler.py
@@ -193,7 +193,7 @@ def gumbel_sample(


 @triton.jit
-def _topk_logprobs_kernel(
+def _topk_log_softmax_kernel(
     output_ptr,
     logits_ptr,
     logits_stride,
@@ -232,6 +232,31 @@ def _topk_logprobs_kernel(
     tl.store(output_ptr + req_idx * topk + k_offset, o, mask=k_mask)


+@triton.jit
+def _ranks_kernel(
+    output_ptr,
+    logits_ptr,
+    logits_stride,
+    token_ids_ptr,
+    vocab_size,
+    BLOCK_SIZE: tl.constexpr,
+):
+    req_idx = tl.program_id(0)
+    row_ptr = logits_ptr + req_idx * logits_stride
+
+    token_id = tl.load(token_ids_ptr + req_idx)
+    x = tl.load(row_ptr + token_id)
+
+    n = 0
+    for i in range(0, vocab_size, BLOCK_SIZE):
+        block = i + tl.arange(0, BLOCK_SIZE)
+        l = tl.load(row_ptr + block,
+                    mask=block < vocab_size,
+                    other=float("-inf"))
+        n += tl.sum((l > x).to(tl.int32))
+    tl.store(output_ptr + req_idx, n)
+
+
 def compute_logprobs(
     logits: torch.Tensor,
     num_logprobs: int,
@@ -255,19 +280,32 @@ def compute_logprobs(
         dtype=torch.float32,
         device=logits.device,
     )
-    BLOCK_SIZE = 1024
-    _topk_logprobs_kernel[(batch_size, )](
+    _topk_log_softmax_kernel[(batch_size, )](
         logprobs,
         logits,
         logits.stride(0),
         logprob_token_ids,
         num_logprobs + 1,
         vocab_size,
-        BLOCK_SIZE=BLOCK_SIZE,
+        BLOCK_SIZE=1024,
         PADDED_TOPK=triton.next_power_of_2(num_logprobs + 1),
     )
+
+    token_ranks = torch.empty(
+        batch_size,
+        dtype=torch.int64,
+        device=logits.device,
+    )
+    _ranks_kernel[(batch_size, )](
+        token_ranks,
+        logits,
+        logits.stride(0),
+        sampled_token_ids,
+        vocab_size,
+        BLOCK_SIZE=8192,
+    )
     return LogprobsTensors(
         logprob_token_ids=logprob_token_ids,
         logprobs=logprobs,
-        selected_token_ranks=None,  # TODO
+        selected_token_ranks=token_ranks,
     )
diff --git a/vllm/v1/worker/gpu/states.py b/vllm/v1/worker/gpu/states.py
index 25d4bf808cedf..1d315c9fee205 100644
--- a/vllm/v1/worker/gpu/states.py
+++ b/vllm/v1/worker/gpu/states.py
@@ -160,7 +160,7 @@ class RequestState:
         seeds = self.seeds.copy_np_to_gpu(seeds)

         num_logprobs = self.num_logprobs[idx_mapping]
-        max_num_logprobs = np.max(num_logprobs)
+        max_num_logprobs = int(np.max(num_logprobs))
        if max_num_logprobs == -1:
            max_num_logprobs = None
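
Note on the rank computation: _ranks_kernel defines the rank of a sampled token as the number of vocabulary entries whose logit is strictly greater, so the most likely token gets rank 0 and ties are not counted. A minimal pure-PyTorch sketch for cross-checking the kernel is below; reference_ranks is an illustrative helper written for this note, not a function in the diff.

import torch


def reference_ranks(logits: torch.Tensor,
                    sampled_token_ids: torch.Tensor) -> torch.Tensor:
    # Logit of each sampled token, shape [num_reqs, 1].
    sampled = logits.gather(-1, sampled_token_ids.view(-1, 1))
    # Rank = number of strictly greater logits per row, matching the
    # `l > x` count that _ranks_kernel accumulates over the vocab.
    return (logits > sampled).sum(dim=-1, dtype=torch.int64)


# Cross-check on random data; on GPU the same tensors would be passed
# to _ranks_kernel and the two results compared with torch.equal.
logits = torch.randn(4, 32000)
token_ids = torch.randint(0, 32000, (4,))
ranks = reference_ranks(logits, token_ids)  # 0-based rank per request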