Optimize _get_ranks in Sampler (#3623)

This commit is contained in:
Antoni Baum 2024-03-25 16:03:02 -07:00 committed by GitHub
parent 64172a976c
commit 3a243095e5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -506,22 +506,23 @@ def _sample(
# sampling_tensors)
def _get_ranks(x: torch.Tensor, indices: List[int]) -> torch.Tensor:
def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
"""
This function calculates the ranks of the chosen tokens in a logprob tensor.
Args:
x (torch.Tensor): 2D logprob tensor of shape (N, M)
where N is the no. of tokens and M is the vocab dim.
indices (List[int]): List of chosen token indices.
indices (torch.Tensor): List of chosen token indices.
Returns:
torch.Tensor: 1D tensor of shape (N,) where N is the no. of tokens.
Each element in the returned tensor represents the rank
of the chosen token in the input logprob tensor.
"""
vals = x[range(len(x)), indices]
return (x > vals[:, None]).long().sum(1) + 1
vals = x[torch.arange(0, len(x), device=x.device, dtype=indices.dtype),
indices]
return (x > vals[:, None]).long().sum(1).add_(1)
def _get_logprobs(
@ -561,12 +562,21 @@ def _get_logprobs(
sample_idx += num_parent_seqs
assert sample_idx == logprobs.size(0)
batched_logprobs_query_seq_indices_gpu = torch.tensor(
batched_logprobs_query_seq_indices, device=logprobs.device)
batched_logprobs_query_token_indices_gpu = torch.tensor(
batched_logprobs_query_token_indices, device=logprobs.device)
# Batched query for logprobs of selected token
batched_logprobs_query_result = logprobs[[
batched_logprobs_query_seq_indices,
batched_logprobs_query_token_indices
batched_logprobs_query_seq_indices_gpu,
batched_logprobs_query_token_indices_gpu
]]
batched_ranks_query_result = _get_ranks(
logprobs[batched_logprobs_query_seq_indices_gpu],
batched_logprobs_query_token_indices_gpu)
# Batched query for logprobs of topk tokens
if largest_num_logprobs > 0:
top_logprobs, top_token_ids = torch.topk(logprobs,
@ -578,10 +588,7 @@ def _get_logprobs(
top_logprobs, top_token_ids = None, None
batched_logprobs_query_result = batched_logprobs_query_result.cpu()
batched_ranks_query_result = _get_ranks(
logprobs[batched_logprobs_query_seq_indices],
batched_logprobs_query_token_indices)
batched_ranks_query_result = batched_ranks_query_result.cpu()
# Gather results
result_prompt_logprobs: List[Optional[PromptLogprobs]] = []