mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 05:34:57 +08:00
[Bugfix] handle prompt_logprobs in _apply_min_tokens_penalty (#3876)
Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>
This commit is contained in:
parent
b3104b2a10
commit
0258b7a94b
@ -1,3 +1,4 @@
|
||||
import itertools
|
||||
import random
|
||||
from typing import List, Optional, Tuple
|
||||
from unittest.mock import patch
|
||||
@ -194,11 +195,15 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
|
||||
|
||||
def create_sampling_params(min_tokens,
|
||||
eos_token_id=0,
|
||||
stop_token_ids=None):
|
||||
*,
|
||||
stop_token_ids: Optional[List[str]] = None,
|
||||
prompt_logprobs: Optional[int] = None):
|
||||
sampling_params = SamplingParams(
|
||||
min_tokens=min_tokens,
|
||||
max_tokens=9999, # keep higher than max of min_tokens
|
||||
stop_token_ids=stop_token_ids,
|
||||
# requesting prompt_logprobs changes the structure of `logits`
|
||||
prompt_logprobs=prompt_logprobs,
|
||||
)
|
||||
sampling_params.eos_token_id = eos_token_id
|
||||
return sampling_params
|
||||
@ -217,9 +222,9 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
|
||||
|
||||
expected_penalization = []
|
||||
sequence_metadata_list = []
|
||||
# 20% chance to generate seq group metadata list with all prompts
|
||||
is_prompt = random.random() < 0.2
|
||||
while batch_size > 0:
|
||||
# 20% chance to generate prompt seq group with single sequence
|
||||
is_prompt = random.random() < 0.2
|
||||
num_seqs = 1 if is_prompt else random.randint(1, batch_size)
|
||||
|
||||
eos_token_id = random.randint(0, VOCAB_SIZE - 1)
|
||||
@ -240,7 +245,7 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
|
||||
seq_group_penalization = []
|
||||
for _ in range(num_seqs):
|
||||
num_input = random.randint(1, 100)
|
||||
num_generated = random.randint(1, 100) if not is_prompt else 0
|
||||
num_generated = 0 if is_prompt else random.randint(1, 100)
|
||||
seq_data[next(seq_id_counter)] = create_sequence_data(
|
||||
num_input=num_input, num_generated=num_generated)
|
||||
seq_group_penalization.append(num_generated < min_tokens)
|
||||
@ -292,6 +297,21 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
|
||||
]
|
||||
}
|
||||
|
||||
prompt_with_penalization_and_prompt_logprobs = {
|
||||
"expected_penalization": [False, False, True],
|
||||
"seq_group_metadata_list": [
|
||||
SequenceGroupMetadata(
|
||||
request_id="test_1",
|
||||
is_prompt=True,
|
||||
seq_data={
|
||||
next(seq_id_counter): create_sequence_data(num_input=3),
|
||||
},
|
||||
sampling_params=create_sampling_params(1, prompt_logprobs=3),
|
||||
block_tables={},
|
||||
),
|
||||
]
|
||||
}
|
||||
|
||||
stop_penalizing_after_min_tokens = {
|
||||
"expected_penalization": [False],
|
||||
"seq_group_metadata_list": [
|
||||
@ -309,8 +329,34 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
|
||||
}
|
||||
|
||||
stop_token_ids = [42, 99, 42, 0] # intentional duplication
|
||||
simple_combination = {
|
||||
"expected_penalization": [True, False, False],
|
||||
prompt_combination = {
|
||||
"expected_penalization": [False, True, False],
|
||||
"seq_group_metadata_list": [
|
||||
SequenceGroupMetadata(
|
||||
request_id="test_2",
|
||||
is_prompt=True,
|
||||
seq_data={
|
||||
next(seq_id_counter): create_sequence_data(num_input=2),
|
||||
},
|
||||
sampling_params=create_sampling_params(1, prompt_logprobs=3),
|
||||
block_tables={},
|
||||
),
|
||||
SequenceGroupMetadata(
|
||||
request_id="test_3",
|
||||
is_prompt=True,
|
||||
seq_data={
|
||||
next(seq_id_counter): create_sequence_data(),
|
||||
},
|
||||
sampling_params=create_sampling_params(
|
||||
0, stop_token_ids=stop_token_ids),
|
||||
block_tables={},
|
||||
)
|
||||
]
|
||||
}
|
||||
|
||||
stop_token_ids = [1, 999, 37, 37] # intentional duplication
|
||||
decode_combination = {
|
||||
"expected_penalization": [True, False, False, True, False],
|
||||
"seq_group_metadata_list": [
|
||||
SequenceGroupMetadata(
|
||||
request_id="test_1",
|
||||
@ -327,14 +373,19 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
|
||||
),
|
||||
SequenceGroupMetadata(
|
||||
request_id="test_2",
|
||||
is_prompt=True,
|
||||
is_prompt=False,
|
||||
seq_data={
|
||||
next(seq_id_counter): create_sequence_data(),
|
||||
next(seq_id_counter):
|
||||
create_sequence_data(num_generated=20),
|
||||
next(seq_id_counter):
|
||||
create_sequence_data(num_generated=1),
|
||||
next(seq_id_counter):
|
||||
create_sequence_data(num_generated=10),
|
||||
},
|
||||
sampling_params=create_sampling_params(
|
||||
0, stop_token_ids=stop_token_ids),
|
||||
10, prompt_logprobs=5, stop_token_ids=stop_token_ids),
|
||||
block_tables={},
|
||||
)
|
||||
),
|
||||
]
|
||||
}
|
||||
|
||||
@ -342,8 +393,10 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
|
||||
test_cases = [
|
||||
prompt_without_penalization,
|
||||
prompt_with_penalization,
|
||||
prompt_with_penalization_and_prompt_logprobs,
|
||||
stop_penalizing_after_min_tokens,
|
||||
simple_combination,
|
||||
prompt_combination,
|
||||
decode_combination,
|
||||
]
|
||||
else:
|
||||
test_cases = [generate_test_case()]
|
||||
@ -351,30 +404,49 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
|
||||
def run_test_case(*,
|
||||
expected_penalization=None,
|
||||
seq_group_metadata_list=None):
|
||||
assert expected_penalization, "Invalid test case"
|
||||
assert seq_group_metadata_list, "Invalid test case"
|
||||
assert expected_penalization, \
|
||||
"Invalid test case, need expected_penalization"
|
||||
assert seq_group_metadata_list, \
|
||||
"Invalid test case, need seq_group_metadata_list"
|
||||
|
||||
batch_size = 0
|
||||
prompt_lens = []
|
||||
sampling_params_per_seq = []
|
||||
sampling_params_per_row = []
|
||||
for sgm in seq_group_metadata_list:
|
||||
num_seqs = len(sgm.seq_data)
|
||||
batch_size += num_seqs
|
||||
sampling_params = sgm.sampling_params
|
||||
for seq_id in sgm.seq_data:
|
||||
prompt_lens.append(sgm.seq_data[seq_id].get_prompt_len())
|
||||
sampling_params_per_seq.append(sampling_params)
|
||||
|
||||
num_rows = len(sgm.seq_data)
|
||||
if sgm.is_prompt:
|
||||
# a prompt seq_group has only one sequence
|
||||
seq_data = next(iter(sgm.seq_data.values()))
|
||||
prompt_len = seq_data.get_prompt_len()
|
||||
prompt_lens.append(prompt_len)
|
||||
|
||||
if sgm.sampling_params.prompt_logprobs:
|
||||
# with prompt_logprobs each token in the prompt has a row in
|
||||
# logits
|
||||
num_rows = prompt_len
|
||||
|
||||
batch_size += num_rows
|
||||
sampling_params_per_row.extend(
|
||||
itertools.repeat(sampling_params, num_rows))
|
||||
|
||||
assert len(
|
||||
expected_penalization
|
||||
) == batch_size, \
|
||||
("Invalid test case, expected_penalization does not match computed"
|
||||
"batch size")
|
||||
|
||||
_, fake_logits, sampler, model_runner = _prepare_test(batch_size)
|
||||
sampling_metadata = model_runner._prepare_sample(
|
||||
seq_group_metadata_list,
|
||||
prompt_lens=prompt_lens,
|
||||
subquery_lens=prompt_lens)
|
||||
prompt_lens=prompt_lens if prompt_lens else None,
|
||||
subquery_lens=prompt_lens if prompt_lens else None)
|
||||
# the logits tensor is modified in-place by the sampler
|
||||
_ = sampler(logits=fake_logits, sampling_metadata=sampling_metadata)
|
||||
|
||||
for logits_idx, (should_penalize, sampling_params) in enumerate(
|
||||
zip(expected_penalization, sampling_params_per_seq)):
|
||||
zip(expected_penalization, sampling_params_per_row)):
|
||||
|
||||
tokens_to_check = [sampling_params.eos_token_id]
|
||||
if sampling_params.stop_token_ids:
|
||||
|
||||
@ -27,6 +27,12 @@ class Sampler(nn.Module):
|
||||
6. Sample the next tokens.
|
||||
Here, each sequence group within the batch can have different sampling
|
||||
parameters (e.g., sampling method, temperature, top-p, top-k, etc.).
|
||||
|
||||
The structure of the logits tensor is coupled with the seq_groups in
|
||||
sampling_metadata. Typically, each sequence in each seq_group has one row in
|
||||
logits for the next token to be sampled; however, for a seq_group with a
|
||||
prompt request with the prompt_logprobs sampling parameter, there are rows
|
||||
in logits for each token in the input prompt.
|
||||
"""
|
||||
|
||||
def forward(
|
||||
@ -106,7 +112,16 @@ def _apply_min_tokens_penalty(
|
||||
# list of indices in logits that will be set to -inf
|
||||
logits_to_penalize = []
|
||||
start_idx = 0
|
||||
for seq_ids, sampling_params in sampling_metadata.seq_groups:
|
||||
for i, seq_group in enumerate(sampling_metadata.seq_groups):
|
||||
seq_ids, sampling_params = seq_group
|
||||
|
||||
# handle prompt_logprobs by skipping rows in logits added for the prompt
|
||||
# tokens (prompt logprobs are not penalized)
|
||||
if (i < sampling_metadata.num_prompts
|
||||
and sampling_params.prompt_logprobs is not None):
|
||||
assert len(seq_ids) == 1
|
||||
start_idx += sampling_metadata.prompt_lens[i] - 1
|
||||
|
||||
min_tokens = sampling_params.min_tokens
|
||||
if min_tokens > 0:
|
||||
seqs_to_penalize = []
|
||||
@ -132,6 +147,8 @@ def _apply_min_tokens_penalty(
|
||||
# eg. [ (1,2), (1,3), (5,6) ] -> ( (1,1,5), (2,3,6) )
|
||||
logits[tuple(zip(*logits_to_penalize))] = -float("inf")
|
||||
|
||||
# verifies that no rows in logits were missed unexpectedly
|
||||
assert start_idx == logits.shape[0]
|
||||
return logits
|
||||
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user