mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-29 23:05:17 +08:00
change the timing of sorting logits (#1309)
This commit is contained in:
parent
ac5cf86aa6
commit
91fce82c6f
@ -102,30 +102,24 @@ def _prune_hidden_states(
|
||||
hidden_states: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
) -> torch.Tensor:
|
||||
last_token_indices = {t: [] for t in SamplingType}
|
||||
last_token_indices = []
|
||||
start_idx = 0
|
||||
for i, seq_group in enumerate(input_metadata.seq_groups):
|
||||
seq_ids, sampling_params = seq_group
|
||||
sampling_type = sampling_params.sampling_type
|
||||
seq_ids, _ = seq_group
|
||||
if i < input_metadata.num_prompts:
|
||||
assert len(seq_ids) == 1, "Prompt input should have only one seq."
|
||||
prompt_len = input_metadata.prompt_lens[i]
|
||||
last_token_indices[sampling_type].append(start_idx + prompt_len -
|
||||
1)
|
||||
last_token_indices.append(start_idx + prompt_len - 1)
|
||||
start_idx += prompt_len
|
||||
else:
|
||||
num_seqs = len(seq_ids)
|
||||
last_token_indices[sampling_type].extend(
|
||||
range(start_idx, start_idx + num_seqs))
|
||||
last_token_indices.extend(range(start_idx, start_idx + num_seqs))
|
||||
start_idx += num_seqs
|
||||
|
||||
all_last_token_indices = []
|
||||
for sampling_type in SamplingType:
|
||||
all_last_token_indices.extend(last_token_indices[sampling_type])
|
||||
all_last_token_indices = torch.tensor(all_last_token_indices,
|
||||
dtype=torch.long,
|
||||
device=hidden_states.device)
|
||||
return hidden_states.index_select(0, all_last_token_indices)
|
||||
last_token_indices = torch.tensor(last_token_indices,
|
||||
dtype=torch.long,
|
||||
device=hidden_states.device)
|
||||
return hidden_states.index_select(0, last_token_indices)
|
||||
|
||||
|
||||
def _get_penalties(
|
||||
@ -424,27 +418,26 @@ def _sample(
|
||||
input_metadata: InputMetadata,
|
||||
) -> SamplerOutput:
|
||||
categorized_seq_group_ids = {t: [] for t in SamplingType}
|
||||
category_num_tokens = {t: 0 for t in SamplingType}
|
||||
start_idx = 0
|
||||
categorized_seq_ids = {t: [] for t in SamplingType}
|
||||
for i, seq_group in enumerate(input_metadata.seq_groups):
|
||||
seq_ids, sampling_params = seq_group
|
||||
sampling_type = sampling_params.sampling_type
|
||||
categorized_seq_group_ids[sampling_type].append(i)
|
||||
num_seqs = len(seq_ids)
|
||||
category_num_tokens[sampling_type] += num_seqs
|
||||
|
||||
categorized_seq_ids[sampling_type].extend(
|
||||
range(start_idx, start_idx + num_seqs))
|
||||
start_idx += num_seqs
|
||||
seq_outputs_dict: Dict[int, List[SequenceOutputs]] = {}
|
||||
category_start_idx = 0
|
||||
for sampling_type in SamplingType:
|
||||
seq_group_ids = categorized_seq_group_ids[sampling_type]
|
||||
seq_groups = [input_metadata.seq_groups[i] for i in seq_group_ids]
|
||||
is_prompts = [i < input_metadata.num_prompts for i in seq_group_ids]
|
||||
num_tokens = category_num_tokens[sampling_type]
|
||||
num_tokens = len(categorized_seq_ids[sampling_type])
|
||||
if num_tokens == 0:
|
||||
continue
|
||||
category_logprobs = logprobs[category_start_idx:category_start_idx +
|
||||
num_tokens]
|
||||
category_probs = probs[category_start_idx:category_start_idx +
|
||||
num_tokens]
|
||||
category_logprobs = logprobs[categorized_seq_ids[sampling_type]]
|
||||
category_probs = probs[categorized_seq_ids[sampling_type]]
|
||||
if sampling_type == SamplingType.GREEDY:
|
||||
sample_results = _greedy_sample(seq_groups, category_logprobs)
|
||||
elif sampling_type == SamplingType.RANDOM:
|
||||
@ -497,6 +490,5 @@ def _sample(
|
||||
sample_idx += num_parent_seqs
|
||||
result_idx += num_results
|
||||
assert sample_idx == num_tokens
|
||||
category_start_idx += num_tokens
|
||||
|
||||
return [seq_outputs_dict[i] for i in range(len(input_metadata.seq_groups))]
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user