Optimize _get_ranks in Sampler (#3623)

This commit is contained in:
Antoni Baum 2024-03-25 16:03:02 -07:00 committed by GitHub
parent 64172a976c
commit 3a243095e5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -506,22 +506,23 @@ def _sample(
# sampling_tensors) # sampling_tensors)
def _get_ranks(x: torch.Tensor, indices: List[int]) -> torch.Tensor: def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
""" """
This function calculates the ranks of the chosen tokens in a logprob tensor. This function calculates the ranks of the chosen tokens in a logprob tensor.
Args: Args:
x (torch.Tensor): 2D logprob tensor of shape (N, M) x (torch.Tensor): 2D logprob tensor of shape (N, M)
where N is the no. of tokens and M is the vocab dim. where N is the no. of tokens and M is the vocab dim.
indices (List[int]): List of chosen token indices. indices (torch.Tensor): List of chosen token indices.
Returns: Returns:
torch.Tensor: 1D tensor of shape (N,) where N is the no. of tokens. torch.Tensor: 1D tensor of shape (N,) where N is the no. of tokens.
Each element in the returned tensor represents the rank Each element in the returned tensor represents the rank
of the chosen token in the input logprob tensor. of the chosen token in the input logprob tensor.
""" """
vals = x[range(len(x)), indices] vals = x[torch.arange(0, len(x), device=x.device, dtype=indices.dtype),
return (x > vals[:, None]).long().sum(1) + 1 indices]
return (x > vals[:, None]).long().sum(1).add_(1)
def _get_logprobs( def _get_logprobs(
@@ -561,12 +562,21 @@ def _get_logprobs(
sample_idx += num_parent_seqs sample_idx += num_parent_seqs
assert sample_idx == logprobs.size(0) assert sample_idx == logprobs.size(0)
batched_logprobs_query_seq_indices_gpu = torch.tensor(
batched_logprobs_query_seq_indices, device=logprobs.device)
batched_logprobs_query_token_indices_gpu = torch.tensor(
batched_logprobs_query_token_indices, device=logprobs.device)
# Batched query for logprobs of selected token # Batched query for logprobs of selected token
batched_logprobs_query_result = logprobs[[ batched_logprobs_query_result = logprobs[[
batched_logprobs_query_seq_indices, batched_logprobs_query_seq_indices_gpu,
batched_logprobs_query_token_indices batched_logprobs_query_token_indices_gpu
]] ]]
batched_ranks_query_result = _get_ranks(
logprobs[batched_logprobs_query_seq_indices_gpu],
batched_logprobs_query_token_indices_gpu)
# Batched query for logprobs of topk tokens # Batched query for logprobs of topk tokens
if largest_num_logprobs > 0: if largest_num_logprobs > 0:
top_logprobs, top_token_ids = torch.topk(logprobs, top_logprobs, top_token_ids = torch.topk(logprobs,
@@ -578,10 +588,7 @@ def _get_logprobs(
top_logprobs, top_token_ids = None, None top_logprobs, top_token_ids = None, None
batched_logprobs_query_result = batched_logprobs_query_result.cpu() batched_logprobs_query_result = batched_logprobs_query_result.cpu()
batched_ranks_query_result = batched_ranks_query_result.cpu()
batched_ranks_query_result = _get_ranks(
logprobs[batched_logprobs_query_seq_indices],
batched_logprobs_query_token_indices)
# Gather results # Gather results
result_prompt_logprobs: List[Optional[PromptLogprobs]] = [] result_prompt_logprobs: List[Optional[PromptLogprobs]] = []