change the timing of sorting logits (#1309)

Author: yhlskt23 (committed by GitHub)
Date: 2023-10-11 11:37:42 +09:00
Parent: ac5cf86aa6
Commit: 91fce82c6f

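In short: _prune_hidden_states previously grouped last-token indices by SamplingType, so the pruned hidden states (and thus the logits, probs, and logprobs derived from them) came out pre-sorted by sampling category. After this change it keeps them in batch order, and _sample instead records each category's row indices explicitly and gathers its rows from probs/logprobs at sampling time.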

@@ -102,30 +102,24 @@ def _prune_hidden_states(
     hidden_states: torch.Tensor,
     input_metadata: InputMetadata,
 ) -> torch.Tensor:
-    last_token_indices = {t: [] for t in SamplingType}
+    last_token_indices = []
     start_idx = 0
     for i, seq_group in enumerate(input_metadata.seq_groups):
-        seq_ids, sampling_params = seq_group
-        sampling_type = sampling_params.sampling_type
+        seq_ids, _ = seq_group
         if i < input_metadata.num_prompts:
             assert len(seq_ids) == 1, "Prompt input should have only one seq."
             prompt_len = input_metadata.prompt_lens[i]
-            last_token_indices[sampling_type].append(start_idx + prompt_len -
-                                                     1)
+            last_token_indices.append(start_idx + prompt_len - 1)
             start_idx += prompt_len
         else:
             num_seqs = len(seq_ids)
-            last_token_indices[sampling_type].extend(
-                range(start_idx, start_idx + num_seqs))
+            last_token_indices.extend(range(start_idx, start_idx + num_seqs))
             start_idx += num_seqs
 
-    all_last_token_indices = []
-    for sampling_type in SamplingType:
-        all_last_token_indices.extend(last_token_indices[sampling_type])
-    all_last_token_indices = torch.tensor(all_last_token_indices,
-                                          dtype=torch.long,
-                                          device=hidden_states.device)
-    return hidden_states.index_select(0, all_last_token_indices)
+    last_token_indices = torch.tensor(last_token_indices,
+                                      dtype=torch.long,
+                                      device=hidden_states.device)
+    return hidden_states.index_select(0, last_token_indices)
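To make the new pruning concrete, here is a minimal, self-contained sketch of the gather it performs. The tensor sizes and prompt lengths are made up for illustration; only the index_select pattern mirrors the hunk above.

import torch

# Toy batch: two prompts of lengths 3 and 2, flattened into one
# [num_tokens, hidden_size] tensor, as the sampler's input would be.
hidden_states = torch.arange(5 * 4, dtype=torch.float32).reshape(5, 4)
prompt_lens = [3, 2]

# Collect the last-token index of each prompt, in batch order
# (no grouping by sampling type anymore).
last_token_indices = []
start_idx = 0
for prompt_len in prompt_lens:
    last_token_indices.append(start_idx + prompt_len - 1)
    start_idx += prompt_len

indices = torch.tensor(last_token_indices, dtype=torch.long,
                       device=hidden_states.device)
print(hidden_states.index_select(0, indices))  # keeps rows 2 and 4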
@@ -424,27 +418,26 @@ def _sample(
     input_metadata: InputMetadata,
 ) -> SamplerOutput:
     categorized_seq_group_ids = {t: [] for t in SamplingType}
-    category_num_tokens = {t: 0 for t in SamplingType}
+    start_idx = 0
+    categorized_seq_ids = {t: [] for t in SamplingType}
     for i, seq_group in enumerate(input_metadata.seq_groups):
         seq_ids, sampling_params = seq_group
         sampling_type = sampling_params.sampling_type
         categorized_seq_group_ids[sampling_type].append(i)
         num_seqs = len(seq_ids)
-        category_num_tokens[sampling_type] += num_seqs
+        categorized_seq_ids[sampling_type].extend(
+            range(start_idx, start_idx + num_seqs))
+        start_idx += num_seqs
 
     seq_outputs_dict: Dict[int, List[SequenceOutputs]] = {}
-    category_start_idx = 0
     for sampling_type in SamplingType:
         seq_group_ids = categorized_seq_group_ids[sampling_type]
         seq_groups = [input_metadata.seq_groups[i] for i in seq_group_ids]
         is_prompts = [i < input_metadata.num_prompts for i in seq_group_ids]
-        num_tokens = category_num_tokens[sampling_type]
+        num_tokens = len(categorized_seq_ids[sampling_type])
         if num_tokens == 0:
             continue
-        category_logprobs = logprobs[category_start_idx:category_start_idx +
-                                     num_tokens]
-        category_probs = probs[category_start_idx:category_start_idx +
-                               num_tokens]
+        category_logprobs = logprobs[categorized_seq_ids[sampling_type]]
+        category_probs = probs[categorized_seq_ids[sampling_type]]
         if sampling_type == SamplingType.GREEDY:
             sample_results = _greedy_sample(seq_groups, category_logprobs)
         elif sampling_type == SamplingType.RANDOM:
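The hunk above swaps contiguous slicing for index-list gathering. A small sketch of the two access patterns on a toy probability tensor; the shapes and index sets are hypothetical, not from the commit.

import torch

probs = torch.rand(6, 32)  # toy [num_seqs, vocab_size] probabilities

# Old pattern: rows pre-sorted by category, so a category is a slice.
category_start_idx, num_tokens = 0, 4
sliced = probs[category_start_idx:category_start_idx + num_tokens]

# New pattern: rows stay in batch order; a category's rows are
# gathered with an explicit index list (PyTorch fancy indexing).
categorized_seq_ids = [0, 2, 3, 5]  # hypothetical greedy-sampled rows
gathered = probs[categorized_seq_ids]

assert sliced.shape == gathered.shape == (4, 32)

Note that fancy indexing copies the selected rows while slicing returns a view; that copy is the price of deferring the per-category sort until sampling time.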
@@ -497,6 +490,5 @@ def _sample(
             sample_idx += num_parent_seqs
             result_idx += num_results
         assert sample_idx == num_tokens
-        category_start_idx += num_tokens
 
     return [seq_outputs_dict[i] for i in range(len(input_metadata.seq_groups))]
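With per-category index lists, the running category_start_idx cursor removed here becomes redundant: the lists jointly cover every row exactly once. A toy check of that invariant, using a hypothetical stand-in for the SamplingType enum.

from enum import Enum

class SamplingType(Enum):  # hypothetical stand-in for the real enum
    GREEDY = 0
    RANDOM = 1
    BEAM = 2

# (sampling_type, num_seqs) per seq_group, in batch order.
seq_groups = [(SamplingType.GREEDY, 1), (SamplingType.RANDOM, 2),
              (SamplingType.GREEDY, 3)]

categorized_seq_ids = {t: [] for t in SamplingType}
start_idx = 0
for sampling_type, num_seqs in seq_groups:
    categorized_seq_ids[sampling_type].extend(
        range(start_idx, start_idx + num_seqs))
    start_idx += num_seqs

# Every row index appears exactly once across all categories, so no
# per-category start cursor is needed when indexing probs/logprobs.
all_ids = sorted(i for ids in categorized_seq_ids.values() for i in ids)
assert all_ids == list(range(start_idx))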