mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-31 20:27:08 +08:00
[BugFix] Fix min_tokens when eos_token_id is None (#4389)
Co-authored-by: DefTruth <31974251+deftruth@users.noreply.github.com>
This commit is contained in:
parent
dfea173148
commit
81661da7b2
@ -207,7 +207,7 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
|
|||||||
def create_sampling_params(min_tokens,
|
def create_sampling_params(min_tokens,
|
||||||
eos_token_id=0,
|
eos_token_id=0,
|
||||||
*,
|
*,
|
||||||
stop_token_ids: Optional[List[str]] = None,
|
stop_token_ids: Optional[List[int]] = None,
|
||||||
prompt_logprobs: Optional[int] = None):
|
prompt_logprobs: Optional[int] = None):
|
||||||
sampling_params = SamplingParams(
|
sampling_params = SamplingParams(
|
||||||
min_tokens=min_tokens,
|
min_tokens=min_tokens,
|
||||||
@ -216,7 +216,7 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
|
|||||||
# requesting prompt_logprobs changes the structure of `logits`
|
# requesting prompt_logprobs changes the structure of `logits`
|
||||||
prompt_logprobs=prompt_logprobs,
|
prompt_logprobs=prompt_logprobs,
|
||||||
)
|
)
|
||||||
sampling_params.eos_token_id = eos_token_id
|
sampling_params.all_stop_token_ids.add(eos_token_id)
|
||||||
return sampling_params
|
return sampling_params
|
||||||
|
|
||||||
def create_sequence_data(num_input=3, num_generated=0):
|
def create_sequence_data(num_input=3, num_generated=0):
|
||||||
@ -461,10 +461,7 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
|
|||||||
for logits_idx, (should_penalize, sampling_params) in enumerate(
|
for logits_idx, (should_penalize, sampling_params) in enumerate(
|
||||||
zip(expected_penalization, sampling_params_per_row)):
|
zip(expected_penalization, sampling_params_per_row)):
|
||||||
|
|
||||||
tokens_to_check = [sampling_params.eos_token_id]
|
tokens_to_check = sampling_params.all_stop_token_ids
|
||||||
if sampling_params.stop_token_ids:
|
|
||||||
tokens_to_check.extend(sampling_params.stop_token_ids)
|
|
||||||
tokens_to_check = set(tokens_to_check)
|
|
||||||
|
|
||||||
if should_penalize:
|
if should_penalize:
|
||||||
for token_id in tokens_to_check:
|
for token_id in tokens_to_check:
|
||||||
|
|||||||
@ -431,9 +431,10 @@ class LLMEngine:
|
|||||||
# Defensive copy of SamplingParams, which are used by the sampler,
|
# Defensive copy of SamplingParams, which are used by the sampler,
|
||||||
# this doesn't deep-copy LogitsProcessor objects
|
# this doesn't deep-copy LogitsProcessor objects
|
||||||
sampling_params = sampling_params.clone()
|
sampling_params = sampling_params.clone()
|
||||||
# inject the eos token id into the sampling_params to support min_tokens
|
# Add the eos token id into the sampling_params to support min_tokens
|
||||||
# processing
|
# processing
|
||||||
sampling_params.eos_token_id = seq.eos_token_id
|
if seq.eos_token_id is not None:
|
||||||
|
sampling_params.all_stop_token_ids.add(seq.eos_token_id)
|
||||||
sampling_params.update_from_generation_config(
|
sampling_params.update_from_generation_config(
|
||||||
self.generation_config_fields)
|
self.generation_config_fields)
|
||||||
|
|
||||||
|
|||||||
@ -169,19 +169,17 @@ def _apply_min_tokens_penalty(
|
|||||||
|
|
||||||
start_idx = sample_indices[0]
|
start_idx = sample_indices[0]
|
||||||
min_tokens = sampling_params.min_tokens
|
min_tokens = sampling_params.min_tokens
|
||||||
if min_tokens > 0:
|
token_ids_to_penalize = sampling_params.all_stop_token_ids
|
||||||
|
if min_tokens > 0 and token_ids_to_penalize:
|
||||||
seqs_to_penalize = []
|
seqs_to_penalize = []
|
||||||
for i, seq_id in enumerate(seq_ids):
|
for j, seq_id in enumerate(seq_ids):
|
||||||
seq_data = seq_group.seq_data[seq_id]
|
seq_data = seq_group.seq_data[seq_id]
|
||||||
if len(seq_data.output_token_ids) < min_tokens:
|
if len(seq_data.output_token_ids) < min_tokens:
|
||||||
seqs_to_penalize.append(i)
|
seqs_to_penalize.append(j)
|
||||||
|
|
||||||
if seqs_to_penalize:
|
if seqs_to_penalize:
|
||||||
# convert to the index into logits
|
# convert to the index into logits
|
||||||
seqs_to_penalize = [start_idx + i for i in seqs_to_penalize]
|
seqs_to_penalize = [start_idx + j for j in seqs_to_penalize]
|
||||||
# use set() to remove any duplicates
|
|
||||||
token_ids_to_penalize = set(sampling_params.stop_token_ids +
|
|
||||||
[sampling_params.eos_token_id])
|
|
||||||
# itertools.product pairs each seq index with every token id
|
# itertools.product pairs each seq index with every token id
|
||||||
logits_to_penalize.extend(
|
logits_to_penalize.extend(
|
||||||
itertools.product(seqs_to_penalize, token_ids_to_penalize))
|
itertools.product(seqs_to_penalize, token_ids_to_penalize))
|
||||||
|
|||||||
@ -185,8 +185,8 @@ class SamplingParams:
|
|||||||
self.top_k = -1
|
self.top_k = -1
|
||||||
self.min_p = 0.0
|
self.min_p = 0.0
|
||||||
self._verify_greedy_sampling()
|
self._verify_greedy_sampling()
|
||||||
# injected by the engine
|
# eos_token_id is added to this by the engine
|
||||||
self.eos_token_id = None
|
self.all_stop_token_ids = set(self.stop_token_ids)
|
||||||
|
|
||||||
def _verify_args(self) -> None:
|
def _verify_args(self) -> None:
|
||||||
if self.n < 1:
|
if self.n < 1:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user