diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py
index 5179b8d94ae1..a17dfcfbaf16 100644
--- a/vllm/model_executor/layers/sampler.py
+++ b/vllm/model_executor/layers/sampler.py
@@ -133,37 +133,22 @@ def _get_penalties(
     # Collect the presence and frequency penalties.
     presence_penalties: List[float] = []
     frequency_penalties: List[float] = []
-    for i, seq_group in enumerate(input_metadata.seq_groups):
+    for seq_group in input_metadata.seq_groups:
         seq_ids, sampling_params = seq_group
         p = sampling_params.presence_penalty
         f = sampling_params.frequency_penalty
-        if i < input_metadata.num_prompts:
-            # A prompt input.
-            presence_penalties.append(p)
-            frequency_penalties.append(f)
-        else:
-            # A generation token.
-            presence_penalties += [p] * len(seq_ids)
-            frequency_penalties += [f] * len(seq_ids)
+        presence_penalties += [p] * len(seq_ids)
+        frequency_penalties += [f] * len(seq_ids)
     return presence_penalties, frequency_penalties


 def _get_output_tokens(input_metadata: InputMetadata) -> List[List[int]]:
     output_tokens: List[List[int]] = []
-    for i, seq_group in enumerate(input_metadata.seq_groups):
+    for seq_group in input_metadata.seq_groups:
         seq_ids, _ = seq_group
-        if i < input_metadata.num_prompts:
-            # A prompt input.
-            # NOTE: While the prompt input usually has no output tokens,
-            # it may have output tokens in the case of recomputation.
-            seq_id = seq_ids[0]
+        for seq_id in seq_ids:
             seq_data = input_metadata.seq_data[seq_id]
             output_tokens.append(seq_data.output_token_ids)
-        else:
-            # A generation token.
-            for seq_id in seq_ids:
-                seq_data = input_metadata.seq_data[seq_id]
-                output_tokens.append(seq_data.output_token_ids)
     return output_tokens


@@ -221,7 +206,7 @@ def _apply_penalties(
 def _get_temperatures(input_metadata: InputMetadata) -> List[float]:
     # Collect the temperatures for the logits.
     temperatures: List[float] = []
-    for i, seq_group in enumerate(input_metadata.seq_groups):
+    for seq_group in input_metadata.seq_groups:
         seq_ids, sampling_params = seq_group
         temperature = sampling_params.temperature
         if temperature < _SAMPLING_EPS:
@@ -229,13 +214,7 @@ def _get_temperatures(input_metadata: InputMetadata) -> List[float]:
             # (i.e., greedy sampling or beam search).
             # Set the temperature to 1 to avoid division by zero.
             temperature = 1.0
-
-        if i < input_metadata.num_prompts:
-            # A prompt input.
-            temperatures.append(temperature)
-        else:
-            # A generation token.
-            temperatures += [temperature] * len(seq_ids)
+        temperatures += [temperature] * len(seq_ids)
     return temperatures


@@ -245,21 +224,15 @@ def _get_top_p_top_k(
 ) -> Tuple[List[float], List[int]]:
     top_ps: List[float] = []
     top_ks: List[int] = []
-    for i, seq_group in enumerate(input_metadata.seq_groups):
+    for seq_group in input_metadata.seq_groups:
         seq_ids, sampling_params = seq_group
         top_p = sampling_params.top_p
         # k should not be greater than the vocab size.
         top_k = min(sampling_params.top_k, vocab_size)
         # k=-1 means no truncation.
         top_k = vocab_size if top_k == -1 else top_k
-        if i < input_metadata.num_prompts:
-            # A prompt input.
-            top_ps.append(top_p)
-            top_ks.append(top_k)
-        else:
-            # A generation token.
-            top_ps += [top_p] * len(seq_ids)
-            top_ks += [top_k] * len(seq_ids)
+        top_ps += [top_p] * len(seq_ids)
+        top_ks += [top_k] * len(seq_ids)
     return top_ps, top_ks
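
The diff converges every helper on one pattern: replicate each group's sampling parameter once per sequence id, with no special-casing of prompt versus generation groups. Below is a minimal standalone sketch of that pattern, not the vLLM API; the group/parameter types here are hypothetical stand-ins for InputMetadata.seq_groups entries, chosen only to show how the per-sequence expansion behaves.

from typing import Dict, List, Tuple

# Hypothetical stand-in: each group is (seq_ids, params), where params holds
# the per-group sampling settings (e.g. "temperature", "top_p").
SeqGroup = Tuple[List[int], Dict[str, float]]


def expand_per_sequence(seq_groups: List[SeqGroup], key: str,
                        default: float) -> List[float]:
    """Return one value per sequence, in group order, for the given key."""
    values: List[float] = []
    for seq_ids, params in seq_groups:
        # Uniform expansion: a prompt group with one sequence contributes one
        # entry, a generation group with N sequences contributes N entries.
        values += [params.get(key, default)] * len(seq_ids)
    return values


if __name__ == "__main__":
    groups = [
        ([0], {"temperature": 0.8}),        # prompt-style group, one sequence
        ([1, 2, 3], {"temperature": 1.0}),  # generation group, three sequences
    ]
    print(expand_per_sequence(groups, "temperature", 1.0))  # [0.8, 1.0, 1.0, 1.0]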