mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 07:24:56 +08:00
[FIX] Simplify sampler logic (#1156)
This commit is contained in:
parent
947b794146
commit
f187877945
@ -133,37 +133,22 @@ def _get_penalties(
|
||||
# Collect the presence and frequency penalties.
|
||||
presence_penalties: List[float] = []
|
||||
frequency_penalties: List[float] = []
|
||||
for i, seq_group in enumerate(input_metadata.seq_groups):
|
||||
for seq_group in input_metadata.seq_groups:
|
||||
seq_ids, sampling_params = seq_group
|
||||
p = sampling_params.presence_penalty
|
||||
f = sampling_params.frequency_penalty
|
||||
if i < input_metadata.num_prompts:
|
||||
# A prompt input.
|
||||
presence_penalties.append(p)
|
||||
frequency_penalties.append(f)
|
||||
else:
|
||||
# A generation token.
|
||||
presence_penalties += [p] * len(seq_ids)
|
||||
frequency_penalties += [f] * len(seq_ids)
|
||||
presence_penalties += [p] * len(seq_ids)
|
||||
frequency_penalties += [f] * len(seq_ids)
|
||||
return presence_penalties, frequency_penalties
|
||||
|
||||
|
||||
def _get_output_tokens(input_metadata: InputMetadata) -> List[List[int]]:
|
||||
output_tokens: List[List[int]] = []
|
||||
for i, seq_group in enumerate(input_metadata.seq_groups):
|
||||
for seq_group in input_metadata.seq_groups:
|
||||
seq_ids, _ = seq_group
|
||||
if i < input_metadata.num_prompts:
|
||||
# A prompt input.
|
||||
# NOTE: While the prompt input usually has no output tokens,
|
||||
# it may have output tokens in the case of recomputation.
|
||||
seq_id = seq_ids[0]
|
||||
for seq_id in seq_ids:
|
||||
seq_data = input_metadata.seq_data[seq_id]
|
||||
output_tokens.append(seq_data.output_token_ids)
|
||||
else:
|
||||
# A generation token.
|
||||
for seq_id in seq_ids:
|
||||
seq_data = input_metadata.seq_data[seq_id]
|
||||
output_tokens.append(seq_data.output_token_ids)
|
||||
return output_tokens
|
||||
|
||||
|
||||
@ -221,7 +206,7 @@ def _apply_penalties(
|
||||
def _get_temperatures(input_metadata: InputMetadata) -> List[float]:
|
||||
# Collect the temperatures for the logits.
|
||||
temperatures: List[float] = []
|
||||
for i, seq_group in enumerate(input_metadata.seq_groups):
|
||||
for seq_group in input_metadata.seq_groups:
|
||||
seq_ids, sampling_params = seq_group
|
||||
temperature = sampling_params.temperature
|
||||
if temperature < _SAMPLING_EPS:
|
||||
@ -229,13 +214,7 @@ def _get_temperatures(input_metadata: InputMetadata) -> List[float]:
|
||||
# (i.e., greedy sampling or beam search).
|
||||
# Set the temperature to 1 to avoid division by zero.
|
||||
temperature = 1.0
|
||||
|
||||
if i < input_metadata.num_prompts:
|
||||
# A prompt input.
|
||||
temperatures.append(temperature)
|
||||
else:
|
||||
# A generation token.
|
||||
temperatures += [temperature] * len(seq_ids)
|
||||
temperatures += [temperature] * len(seq_ids)
|
||||
return temperatures
|
||||
|
||||
|
||||
@ -245,21 +224,15 @@ def _get_top_p_top_k(
|
||||
) -> Tuple[List[float], List[int]]:
|
||||
top_ps: List[float] = []
|
||||
top_ks: List[int] = []
|
||||
for i, seq_group in enumerate(input_metadata.seq_groups):
|
||||
for seq_group in input_metadata.seq_groups:
|
||||
seq_ids, sampling_params = seq_group
|
||||
top_p = sampling_params.top_p
|
||||
# k should not be greater than the vocab size.
|
||||
top_k = min(sampling_params.top_k, vocab_size)
|
||||
# k=-1 means no truncation.
|
||||
top_k = vocab_size if top_k == -1 else top_k
|
||||
if i < input_metadata.num_prompts:
|
||||
# A prompt input.
|
||||
top_ps.append(top_p)
|
||||
top_ks.append(top_k)
|
||||
else:
|
||||
# A generation token.
|
||||
top_ps += [top_p] * len(seq_ids)
|
||||
top_ks += [top_k] * len(seq_ids)
|
||||
top_ps += [top_p] * len(seq_ids)
|
||||
top_ks += [top_k] * len(seq_ids)
|
||||
return top_ps, top_ks
|
||||
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user