mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-06 08:22:15 +08:00
Optimize _get_ranks in Sampler (#3623)
This commit is contained in:
parent
64172a976c
commit
3a243095e5
@ -506,22 +506,23 @@ def _sample(
|
|||||||
# sampling_tensors)
|
# sampling_tensors)
|
||||||
|
|
||||||
|
|
||||||
def _get_ranks(x: torch.Tensor, indices: List[int]) -> torch.Tensor:
|
def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
This function calculates the ranks of the chosen tokens in a logprob tensor.
|
This function calculates the ranks of the chosen tokens in a logprob tensor.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
x (torch.Tensor): 2D logprob tensor of shape (N, M)
|
x (torch.Tensor): 2D logprob tensor of shape (N, M)
|
||||||
where N is the no. of tokens and M is the vocab dim.
|
where N is the no. of tokens and M is the vocab dim.
|
||||||
indices (List[int]): List of chosen token indices.
|
indices (torch.Tensor): List of chosen token indices.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
torch.Tensor: 1D tensor of shape (N,) where N is the no. of tokens.
|
torch.Tensor: 1D tensor of shape (N,) where N is the no. of tokens.
|
||||||
Each element in the returned tensor represents the rank
|
Each element in the returned tensor represents the rank
|
||||||
of the chosen token in the input logprob tensor.
|
of the chosen token in the input logprob tensor.
|
||||||
"""
|
"""
|
||||||
vals = x[range(len(x)), indices]
|
vals = x[torch.arange(0, len(x), device=x.device, dtype=indices.dtype),
|
||||||
return (x > vals[:, None]).long().sum(1) + 1
|
indices]
|
||||||
|
return (x > vals[:, None]).long().sum(1).add_(1)
|
||||||
|
|
||||||
|
|
||||||
def _get_logprobs(
|
def _get_logprobs(
|
||||||
@ -561,12 +562,21 @@ def _get_logprobs(
|
|||||||
sample_idx += num_parent_seqs
|
sample_idx += num_parent_seqs
|
||||||
assert sample_idx == logprobs.size(0)
|
assert sample_idx == logprobs.size(0)
|
||||||
|
|
||||||
|
batched_logprobs_query_seq_indices_gpu = torch.tensor(
|
||||||
|
batched_logprobs_query_seq_indices, device=logprobs.device)
|
||||||
|
batched_logprobs_query_token_indices_gpu = torch.tensor(
|
||||||
|
batched_logprobs_query_token_indices, device=logprobs.device)
|
||||||
|
|
||||||
# Batched query for logprobs of selected token
|
# Batched query for logprobs of selected token
|
||||||
batched_logprobs_query_result = logprobs[[
|
batched_logprobs_query_result = logprobs[[
|
||||||
batched_logprobs_query_seq_indices,
|
batched_logprobs_query_seq_indices_gpu,
|
||||||
batched_logprobs_query_token_indices
|
batched_logprobs_query_token_indices_gpu
|
||||||
]]
|
]]
|
||||||
|
|
||||||
|
batched_ranks_query_result = _get_ranks(
|
||||||
|
logprobs[batched_logprobs_query_seq_indices_gpu],
|
||||||
|
batched_logprobs_query_token_indices_gpu)
|
||||||
|
|
||||||
# Batched query for logprobs of topk tokens
|
# Batched query for logprobs of topk tokens
|
||||||
if largest_num_logprobs > 0:
|
if largest_num_logprobs > 0:
|
||||||
top_logprobs, top_token_ids = torch.topk(logprobs,
|
top_logprobs, top_token_ids = torch.topk(logprobs,
|
||||||
@ -578,10 +588,7 @@ def _get_logprobs(
|
|||||||
top_logprobs, top_token_ids = None, None
|
top_logprobs, top_token_ids = None, None
|
||||||
|
|
||||||
batched_logprobs_query_result = batched_logprobs_query_result.cpu()
|
batched_logprobs_query_result = batched_logprobs_query_result.cpu()
|
||||||
|
batched_ranks_query_result = batched_ranks_query_result.cpu()
|
||||||
batched_ranks_query_result = _get_ranks(
|
|
||||||
logprobs[batched_logprobs_query_seq_indices],
|
|
||||||
batched_logprobs_query_token_indices)
|
|
||||||
|
|
||||||
# Gather results
|
# Gather results
|
||||||
result_prompt_logprobs: List[Optional[PromptLogprobs]] = []
|
result_prompt_logprobs: List[Optional[PromptLogprobs]] = []
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user