From 3a243095e5e7b655b63ab08fbd5936cb40850415 Mon Sep 17 00:00:00 2001
From: Antoni Baum
Date: Mon, 25 Mar 2024 16:03:02 -0700
Subject: [PATCH] Optimize `_get_ranks` in Sampler (#3623)

---
 vllm/model_executor/layers/sampler.py | 27 +++++++++++++++++----------
 1 file changed, 17 insertions(+), 10 deletions(-)

diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py
index d07527304962d..06135192c192e 100644
--- a/vllm/model_executor/layers/sampler.py
+++ b/vllm/model_executor/layers/sampler.py
@@ -506,22 +506,23 @@ def _sample(
     #     sampling_tensors)
 
 
-def _get_ranks(x: torch.Tensor, indices: List[int]) -> torch.Tensor:
+def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
     """
     This function calculates the ranks of the chosen tokens in a logprob tensor.
 
     Args:
         x (torch.Tensor): 2D logprob tensor of shape (N, M)
                         where N is the no. of tokens and M is the vocab dim.
-        indices (List[int]): List of chosen token indices.
+        indices (torch.Tensor): List of chosen token indices.
 
     Returns:
         torch.Tensor: 1D tensor of shape (N,) where N is the no. of tokens.
                     Each element in the returned tensor represents the rank
                     of the chosen token in the input logprob tensor.
     """
-    vals = x[range(len(x)), indices]
-    return (x > vals[:, None]).long().sum(1) + 1
+    vals = x[torch.arange(0, len(x), device=x.device, dtype=indices.dtype),
+             indices]
+    return (x > vals[:, None]).long().sum(1).add_(1)
 
 
 def _get_logprobs(
@@ -561,12 +562,21 @@ def _get_logprobs(
             sample_idx += num_parent_seqs
     assert sample_idx == logprobs.size(0)
 
+    batched_logprobs_query_seq_indices_gpu = torch.tensor(
+        batched_logprobs_query_seq_indices, device=logprobs.device)
+    batched_logprobs_query_token_indices_gpu = torch.tensor(
+        batched_logprobs_query_token_indices, device=logprobs.device)
+
     # Batched query for logprobs of selected token
     batched_logprobs_query_result = logprobs[[
-        batched_logprobs_query_seq_indices,
-        batched_logprobs_query_token_indices
+        batched_logprobs_query_seq_indices_gpu,
+        batched_logprobs_query_token_indices_gpu
     ]]
 
+    batched_ranks_query_result = _get_ranks(
+        logprobs[batched_logprobs_query_seq_indices_gpu],
+        batched_logprobs_query_token_indices_gpu)
+
     # Batched query for logprobs of topk tokens
     if largest_num_logprobs > 0:
         top_logprobs, top_token_ids = torch.topk(logprobs,
@@ -578,10 +588,7 @@
         top_logprobs, top_token_ids = None, None
 
     batched_logprobs_query_result = batched_logprobs_query_result.cpu()
-
-    batched_ranks_query_result = _get_ranks(
-        logprobs[batched_logprobs_query_seq_indices],
-        batched_logprobs_query_token_indices)
+    batched_ranks_query_result = batched_ranks_query_result.cpu()
 
     # Gather results
     result_prompt_logprobs: List[Optional[PromptLogprobs]] = []
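
For context, the new `_get_ranks` can be exercised standalone. The sketch
below mirrors the function added by this patch; the harness (the tensor
shapes, the toy vocab size, and the `chosen` indices) is illustrative only
and not part of the change.

    import torch

    def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
        # Build the row indices with torch.arange on x's device, so indexing
        # a CUDA tensor no longer rebuilds index tensors from a Python
        # range/list on every call.
        vals = x[torch.arange(0, len(x), device=x.device, dtype=indices.dtype),
                 indices]
        # Rank = 1 + count of vocab entries with strictly greater logprob.
        return (x > vals[:, None]).long().sum(1).add_(1)

    if __name__ == "__main__":
        logprobs = torch.log_softmax(torch.randn(4, 1000), dim=-1)
        chosen = torch.randint(0, 1000, (4,))
        print(_get_ranks(logprobs, chosen))  # four ranks in [1, 1000]

The `_get_logprobs` half of the patch applies the same idea: the Python index
lists are uploaded to the GPU once, the rank query is issued before any
`.cpu()` transfer, and both results are copied back afterwards, so all GPU
work is enqueued before the host blocks on a device-to-host copy.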