From 91fce82c6f3e3efd705faa0edd3aa64d328c3c77 Mon Sep 17 00:00:00 2001
From: yhlskt23 <146050887+yhlskt23@users.noreply.github.com>
Date: Wed, 11 Oct 2023 11:37:42 +0900
Subject: [PATCH] change the timing of sorting logits (#1309)

---
 vllm/model_executor/layers/sampler.py | 40 +++++++++++----------------
 1 file changed, 16 insertions(+), 24 deletions(-)

diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py
index 76442eae680d2..a9f036fa2f24f 100644
--- a/vllm/model_executor/layers/sampler.py
+++ b/vllm/model_executor/layers/sampler.py
@@ -102,30 +102,24 @@ def _prune_hidden_states(
     hidden_states: torch.Tensor,
     input_metadata: InputMetadata,
 ) -> torch.Tensor:
-    last_token_indices = {t: [] for t in SamplingType}
+    last_token_indices = []
     start_idx = 0
     for i, seq_group in enumerate(input_metadata.seq_groups):
-        seq_ids, sampling_params = seq_group
-        sampling_type = sampling_params.sampling_type
+        seq_ids, _ = seq_group
         if i < input_metadata.num_prompts:
             assert len(seq_ids) == 1, "Prompt input should have only one seq."
             prompt_len = input_metadata.prompt_lens[i]
-            last_token_indices[sampling_type].append(start_idx + prompt_len -
-                                                     1)
+            last_token_indices.append(start_idx + prompt_len - 1)
             start_idx += prompt_len
         else:
             num_seqs = len(seq_ids)
-            last_token_indices[sampling_type].extend(
-                range(start_idx, start_idx + num_seqs))
+            last_token_indices.extend(range(start_idx, start_idx + num_seqs))
             start_idx += num_seqs
 
-    all_last_token_indices = []
-    for sampling_type in SamplingType:
-        all_last_token_indices.extend(last_token_indices[sampling_type])
-    all_last_token_indices = torch.tensor(all_last_token_indices,
-                                          dtype=torch.long,
-                                          device=hidden_states.device)
-    return hidden_states.index_select(0, all_last_token_indices)
+    last_token_indices = torch.tensor(last_token_indices,
+                                      dtype=torch.long,
+                                      device=hidden_states.device)
+    return hidden_states.index_select(0, last_token_indices)
 
 
 def _get_penalties(
@@ -424,27 +418,26 @@ def _sample(
     input_metadata: InputMetadata,
 ) -> SamplerOutput:
     categorized_seq_group_ids = {t: [] for t in SamplingType}
-    category_num_tokens = {t: 0 for t in SamplingType}
+    start_idx = 0
+    categorized_seq_ids = {t: [] for t in SamplingType}
     for i, seq_group in enumerate(input_metadata.seq_groups):
         seq_ids, sampling_params = seq_group
         sampling_type = sampling_params.sampling_type
         categorized_seq_group_ids[sampling_type].append(i)
         num_seqs = len(seq_ids)
-        category_num_tokens[sampling_type] += num_seqs
-
+        categorized_seq_ids[sampling_type].extend(
+            range(start_idx, start_idx + num_seqs))
+        start_idx += num_seqs
     seq_outputs_dict: Dict[int, List[SequenceOutputs]] = {}
-    category_start_idx = 0
     for sampling_type in SamplingType:
         seq_group_ids = categorized_seq_group_ids[sampling_type]
         seq_groups = [input_metadata.seq_groups[i] for i in seq_group_ids]
         is_prompts = [i < input_metadata.num_prompts for i in seq_group_ids]
-        num_tokens = category_num_tokens[sampling_type]
+        num_tokens = len(categorized_seq_ids[sampling_type])
         if num_tokens == 0:
             continue
-        category_logprobs = logprobs[category_start_idx:category_start_idx +
-                                     num_tokens]
-        category_probs = probs[category_start_idx:category_start_idx +
-                               num_tokens]
+        category_logprobs = logprobs[categorized_seq_ids[sampling_type]]
+        category_probs = probs[categorized_seq_ids[sampling_type]]
         if sampling_type == SamplingType.GREEDY:
             sample_results = _greedy_sample(seq_groups, category_logprobs)
         elif sampling_type == SamplingType.RANDOM:
@@ -497,6 +490,5 @@ def _sample(
             sample_idx += num_parent_seqs
             result_idx += num_results
         assert sample_idx == num_tokens
-        category_start_idx += num_tokens
 
     return [seq_outputs_dict[i] for i in range(len(input_metadata.seq_groups))]
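
Editor's note (not part of the patch): the behavioral change is that _prune_hidden_states no longer sorts the selected hidden states by sampling type, so _sample can no longer take contiguous start_idx-based slices of probs/logprobs per category; it instead gathers each category's rows with the index lists kept in categorized_seq_ids. A minimal, self-contained sketch of that gather pattern is below; the string keys and the example tensor are placeholders for illustration, not the actual SamplingType enum or real model output.

import torch

# Placeholder batch of per-sequence probabilities, kept in original batch order.
probs = torch.arange(12.0).reshape(6, 2)

# Placeholder index lists playing the role of categorized_seq_ids in the diff:
# the row positions of each sampling category within the (unsorted) batch.
categorized_seq_ids = {
    "greedy": [0, 3, 4],
    "random": [1, 2, 5],
}

for sampling_type, seq_ids in categorized_seq_ids.items():
    if not seq_ids:
        continue
    # List (advanced) indexing gathers the category's rows even though they are
    # not contiguous, which is what replaces the old contiguous slicing.
    category_probs = probs[seq_ids]
    print(sampling_type, category_probs.tolist())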