Optimize _get_ranks in Sampler (#3623)

parent 64172a976c
commit 3a243095e5
@@ -506,22 +506,23 @@ def _sample(
     # sampling_tensors)
 
 
-def _get_ranks(x: torch.Tensor, indices: List[int]) -> torch.Tensor:
+def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
     """
     This function calculates the ranks of the chosen tokens in a logprob tensor.
 
     Args:
         x (torch.Tensor): 2D logprob tensor of shape (N, M)
                         where N is the no. of tokens and M is the vocab dim.
-        indices (List[int]): List of chosen token indices.
+        indices (torch.Tensor): List of chosen token indices.
 
     Returns:
         torch.Tensor: 1D tensor of shape (N,) where N is the no. of tokens.
                     Each element in the returned tensor represents the rank
                     of the chosen token in the input logprob tensor.
     """
-    vals = x[range(len(x)), indices]
-    return (x > vals[:, None]).long().sum(1) + 1
+    vals = x[torch.arange(0, len(x), device=x.device, dtype=indices.dtype),
+             indices]
+    return (x > vals[:, None]).long().sum(1).add_(1)
 
 
 def _get_logprobs(
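
The hunk above swaps the Python `range`/`List[int]` index pair for device-resident tensors, so the rank query never leaves the GPU, and `.add_(1)` bumps the ranks in place instead of allocating a fresh tensor for the `+ 1`. Below is a minimal, self-contained sketch of the same vectorized rank computation; the toy logprobs and chosen token ids are illustrative, not taken from the diff:

import torch


def get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    # Gather the logprob of the chosen token in each row.
    vals = x[torch.arange(0, len(x), device=x.device, dtype=indices.dtype),
             indices]
    # A token's rank is 1 + the number of strictly larger entries in its row,
    # computed with one comparison and one reduction (no sort needed).
    return (x > vals[:, None]).long().sum(1).add_(1)


logprobs = torch.log_softmax(torch.randn(4, 8), dim=-1)  # N=4 tokens, M=8 vocab
chosen = torch.tensor([3, 0, 7, 2])                      # one token id per row
print(get_ranks(logprobs, chosen))                       # e.g. tensor([5, 2, 8, 1])

Ties share the best rank, and the comparison-plus-sum keeps the cost at one pass over the vocab dimension instead of a full `torch.sort`.
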
@@ -561,12 +562,21 @@ def _get_logprobs(
         sample_idx += num_parent_seqs
     assert sample_idx == logprobs.size(0)
 
+    batched_logprobs_query_seq_indices_gpu = torch.tensor(
+        batched_logprobs_query_seq_indices, device=logprobs.device)
+    batched_logprobs_query_token_indices_gpu = torch.tensor(
+        batched_logprobs_query_token_indices, device=logprobs.device)
+
     # Batched query for logprobs of selected token
     batched_logprobs_query_result = logprobs[[
-        batched_logprobs_query_seq_indices,
-        batched_logprobs_query_token_indices
+        batched_logprobs_query_seq_indices_gpu,
+        batched_logprobs_query_token_indices_gpu
     ]]
 
+    batched_ranks_query_result = _get_ranks(
+        logprobs[batched_logprobs_query_seq_indices_gpu],
+        batched_logprobs_query_token_indices_gpu)
+
     # Batched query for logprobs of topk tokens
    if largest_num_logprobs > 0:
         top_logprobs, top_token_ids = torch.topk(logprobs,
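
This hunk hoists the Python index lists into GPU tensors once: indexing a CUDA tensor with plain lists makes PyTorch materialize (and copy) index tensors on every advanced-indexing call, and building them up front lets the rank query reuse the same device tensors. A rough sketch of the pattern, with hypothetical index lists standing in for the real query indices:

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
logprobs = torch.randn(6, 32000, device=device)  # (num sequences, vocab size)
seq_indices = [0, 1, 1, 4]                       # hypothetical query rows
token_indices = [17, 5, 900, 42]                 # hypothetical chosen tokens

# Build the index tensors on the logprobs' device once...
seq_gpu = torch.tensor(seq_indices, device=logprobs.device)
tok_gpu = torch.tensor(token_indices, device=logprobs.device)

# ...then reuse them for both the logprob gather and the rank query.
selected = logprobs[[seq_gpu, tok_gpu]]          # shape (4,)
rows = logprobs[seq_gpu]                         # shape (4, 32000)
ranks = (rows > selected[:, None]).long().sum(1).add_(1)
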
@@ -578,10 +588,7 @@ def _get_logprobs(
         top_logprobs, top_token_ids = None, None
 
     batched_logprobs_query_result = batched_logprobs_query_result.cpu()
-
-    batched_ranks_query_result = _get_ranks(
-        logprobs[batched_logprobs_query_seq_indices],
-        batched_logprobs_query_token_indices)
+    batched_ranks_query_result = batched_ranks_query_result.cpu()
 
     # Gather results
     result_prompt_logprobs: List[Optional[PromptLogprobs]] = []
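
With the rank query now issued earlier (previous hunk), the only work left here is the device-to-host copies, which happen back to back. The idea, sketched with stand-in tensors: enqueue all GPU kernels first, then pay the synchronizing `.cpu()` transfers at the end rather than interleaving a blocking copy between kernel launches.

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(1024, 32000, device=device)   # stand-in for the logprobs

# Enqueue all GPU work first (nothing here blocks on the device)...
idx = torch.tensor([1, 2, 3, 4], device=device)
gathered = x[torch.arange(4, device=device), idx]
ranks = (x[:4] > gathered[:, None]).long().sum(1).add_(1)

# ...then synchronize once per result at the very end.
gathered_cpu = gathered.cpu()
ranks_cpu = ranks.cpu()
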