[FIX] Simplify sampler logic (#1156)
parent 947b794146
commit f187877945
@@ -133,37 +133,22 @@ def _get_penalties(
     # Collect the presence and frequency penalties.
     presence_penalties: List[float] = []
     frequency_penalties: List[float] = []
-    for i, seq_group in enumerate(input_metadata.seq_groups):
+    for seq_group in input_metadata.seq_groups:
         seq_ids, sampling_params = seq_group
         p = sampling_params.presence_penalty
         f = sampling_params.frequency_penalty
-        if i < input_metadata.num_prompts:
-            # A prompt input.
-            presence_penalties.append(p)
-            frequency_penalties.append(f)
-        else:
-            # A generation token.
-            presence_penalties += [p] * len(seq_ids)
-            frequency_penalties += [f] * len(seq_ids)
+        presence_penalties += [p] * len(seq_ids)
+        frequency_penalties += [f] * len(seq_ids)
     return presence_penalties, frequency_penalties


 def _get_output_tokens(input_metadata: InputMetadata) -> List[List[int]]:
     output_tokens: List[List[int]] = []
-    for i, seq_group in enumerate(input_metadata.seq_groups):
+    for seq_group in input_metadata.seq_groups:
         seq_ids, _ = seq_group
-        if i < input_metadata.num_prompts:
-            # A prompt input.
-            # NOTE: While the prompt input usually has no output tokens,
-            # it may have output tokens in the case of recomputation.
-            seq_id = seq_ids[0]
-            seq_data = input_metadata.seq_data[seq_id]
-            output_tokens.append(seq_data.output_token_ids)
-        else:
-            # A generation token.
-            for seq_id in seq_ids:
-                seq_data = input_metadata.seq_data[seq_id]
-                output_tokens.append(seq_data.output_token_ids)
+        for seq_id in seq_ids:
+            seq_data = input_metadata.seq_data[seq_id]
+            output_tokens.append(seq_data.output_token_ids)
     return output_tokens
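The change drops the prompt-vs-generation branching: every sequence group now contributes one entry per sequence id regardless of phase, so the penalty rows and the output-token rows line up one-to-one. Below is a minimal, self-contained sketch of the two simplified helpers. The `_SeqData`, `_SamplingParams`, and `_InputMetadata` classes are hypothetical stand-ins for the vLLM types, reduced to only the fields these loops read; they are assumptions for illustration, not the library's API.

from typing import Dict, List, Tuple

class _SeqData:
    # Hypothetical stand-in for vLLM's per-sequence data.
    def __init__(self, output_token_ids: List[int]) -> None:
        self.output_token_ids = output_token_ids

class _SamplingParams:
    # Hypothetical stand-in carrying only the two penalty fields.
    def __init__(self, presence_penalty: float, frequency_penalty: float) -> None:
        self.presence_penalty = presence_penalty
        self.frequency_penalty = frequency_penalty

class _InputMetadata:
    # Hypothetical stand-in with the two fields the loops read.
    def __init__(self,
                 seq_groups: List[Tuple[List[int], _SamplingParams]],
                 seq_data: Dict[int, _SeqData]) -> None:
        self.seq_groups = seq_groups
        self.seq_data = seq_data

def _get_penalties(input_metadata: _InputMetadata) -> Tuple[List[float], List[float]]:
    # One entry per sequence id, for prompt and generation groups alike.
    presence_penalties: List[float] = []
    frequency_penalties: List[float] = []
    for seq_ids, sampling_params in input_metadata.seq_groups:
        presence_penalties += [sampling_params.presence_penalty] * len(seq_ids)
        frequency_penalties += [sampling_params.frequency_penalty] * len(seq_ids)
    return presence_penalties, frequency_penalties

def _get_output_tokens(input_metadata: _InputMetadata) -> List[List[int]]:
    # Likewise one output-token list per sequence id.
    output_tokens: List[List[int]] = []
    for seq_ids, _ in input_metadata.seq_groups:
        for seq_id in seq_ids:
            output_tokens.append(input_metadata.seq_data[seq_id].output_token_ids)
    return output_tokens

# Two groups: one single-sequence, one with two sequences (e.g. best_of=2).
meta = _InputMetadata(
    seq_groups=[([0], _SamplingParams(0.5, 0.1)),
                ([1, 2], _SamplingParams(0.0, 0.2))],
    seq_data={0: _SeqData([7]), 1: _SeqData([8]), 2: _SeqData([9])},
)
assert _get_penalties(meta) == ([0.5, 0.0, 0.0], [0.1, 0.2, 0.2])
assert _get_output_tokens(meta) == [[7], [8], [9]]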
@@ -221,7 +206,7 @@ def _apply_penalties(
 def _get_temperatures(input_metadata: InputMetadata) -> List[float]:
     # Collect the temperatures for the logits.
     temperatures: List[float] = []
-    for i, seq_group in enumerate(input_metadata.seq_groups):
+    for seq_group in input_metadata.seq_groups:
         seq_ids, sampling_params = seq_group
         temperature = sampling_params.temperature
         if temperature < _SAMPLING_EPS:
@@ -229,13 +214,7 @@ def _get_temperatures(input_metadata: InputMetadata) -> List[float]:
             # (i.e., greedy sampling or beam search).
             # Set the temperature to 1 to avoid division by zero.
             temperature = 1.0
-        if i < input_metadata.num_prompts:
-            # A prompt input.
-            temperatures.append(temperature)
-        else:
-            # A generation token.
-            temperatures += [temperature] * len(seq_ids)
+        temperatures += [temperature] * len(seq_ids)
     return temperatures
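The temperature path keeps its division-by-zero guard: a temperature below the epsilon means greedy sampling or beam search, and since the logits are later divided by the temperature, it is clamped to 1.0 (a no-op division). A minimal sketch of the behavior after this change, with seq_groups simplified here to (seq_ids, temperature) pairs as an assumption for illustration:

from typing import List, Tuple

_SAMPLING_EPS = 1e-5  # mirrors the module-level constant

def _get_temperatures(seq_groups: List[Tuple[List[int], float]]) -> List[float]:
    temperatures: List[float] = []
    for seq_ids, temperature in seq_groups:
        if temperature < _SAMPLING_EPS:
            # Zero temperature means deterministic sampling; the logits are
            # divided by the temperature later, so use 1.0 to avoid
            # division by zero.
            temperature = 1.0
        temperatures += [temperature] * len(seq_ids)
    return temperatures

assert _get_temperatures([([0], 0.0), ([1, 2], 0.7)]) == [1.0, 0.7, 0.7]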
@@ -245,21 +224,15 @@ def _get_top_p_top_k(
 ) -> Tuple[List[float], List[int]]:
     top_ps: List[float] = []
     top_ks: List[int] = []
-    for i, seq_group in enumerate(input_metadata.seq_groups):
+    for seq_group in input_metadata.seq_groups:
         seq_ids, sampling_params = seq_group
         top_p = sampling_params.top_p
         # k should not be greater than the vocab size.
         top_k = min(sampling_params.top_k, vocab_size)
         # k=-1 means no truncation.
         top_k = vocab_size if top_k == -1 else top_k
-        if i < input_metadata.num_prompts:
-            # A prompt input.
-            top_ps.append(top_p)
-            top_ks.append(top_k)
-        else:
-            # A generation token.
-            top_ps += [top_p] * len(seq_ids)
-            top_ks += [top_k] * len(seq_ids)
+        top_ps += [top_p] * len(seq_ids)
+        top_ks += [top_k] * len(seq_ids)
     return top_ps, top_ks
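The top-k normalization is unchanged: k is clamped to the vocab size, and the sentinel k=-1 (no truncation) expands to the full vocabulary. A sketch under the same simplifying assumption, with seq_groups reduced to (seq_ids, top_p, top_k) tuples for illustration:

from typing import List, Tuple

def _get_top_p_top_k(
    seq_groups: List[Tuple[List[int], float, int]],
    vocab_size: int,
) -> Tuple[List[float], List[int]]:
    top_ps: List[float] = []
    top_ks: List[int] = []
    for seq_ids, top_p, top_k in seq_groups:
        # k must not exceed the vocab size, and k=-1 disables truncation.
        top_k = min(top_k, vocab_size)
        top_k = vocab_size if top_k == -1 else top_k
        top_ps += [top_p] * len(seq_ids)
        top_ks += [top_k] * len(seq_ids)
    return top_ps, top_ks

# k=-1 expands to the full vocab; an oversized k is clamped down.
expected = ([1.0, 0.9], [32000, 32000])
assert _get_top_p_top_k([([0], 1.0, -1), ([1], 0.9, 10**9)], 32000) == expected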