[Bugfix] handle prompt_logprobs in _apply_min_tokens_penalty (#3876)
Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>
parent b3104b2a10 · commit 0258b7a94b
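In short: when a prompt request also sets prompt_logprobs, the logits tensor handed to the Sampler contains one row per prompt token rather than a single next-token row per sequence, and _apply_min_tokens_penalty previously walked the rows as if those extra prompt rows did not exist. The diff below makes the penalty skip the prompt rows and extends test_sampler_min_tokens_penalty to cover the new layout. As a minimal sketch of the row accounting (plain Python; expected_logits_rows is a hypothetical helper for illustration, not vLLM API):

# Hypothetical helper: how many logits rows a seq_group contributes.
def expected_logits_rows(num_seqs: int, is_prompt: bool, prompt_len: int,
                         wants_prompt_logprobs: bool) -> int:
    # Normally one row per sequence, holding logits for its next token.
    # A prompt request with prompt_logprobs gets one row per prompt token;
    # only the last of those rows is for the token about to be sampled.
    if is_prompt and wants_prompt_logprobs:
        return prompt_len
    return num_seqs

assert expected_logits_rows(1, True, 3, False) == 1    # plain prompt
assert expected_logits_rows(1, True, 3, True) == 3     # prompt + prompt_logprobs
assert expected_logits_rows(4, False, 17, False) == 4  # decode group of 4 seqs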
@@ -1,3 +1,4 @@
+import itertools
 import random
 from typing import List, Optional, Tuple
 from unittest.mock import patch
@@ -194,11 +195,15 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
 
     def create_sampling_params(min_tokens,
                                eos_token_id=0,
-                               stop_token_ids=None):
+                               *,
+                               stop_token_ids: Optional[List[int]] = None,
+                               prompt_logprobs: Optional[int] = None):
         sampling_params = SamplingParams(
             min_tokens=min_tokens,
             max_tokens=9999,  # keep higher than max of min_tokens
             stop_token_ids=stop_token_ids,
+            # requesting prompt_logprobs changes the structure of `logits`
+            prompt_logprobs=prompt_logprobs,
         )
         sampling_params.eos_token_id = eos_token_id
         return sampling_params
@@ -217,9 +222,9 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
 
         expected_penalization = []
         sequence_metadata_list = []
+        # 20% chance to generate seq group metadata list with all prompts
+        is_prompt = random.random() < 0.2
         while batch_size > 0:
-            # 20% chance to generate prompt seq group with single sequence
-            is_prompt = random.random() < 0.2
             num_seqs = 1 if is_prompt else random.randint(1, batch_size)
 
             eos_token_id = random.randint(0, VOCAB_SIZE - 1)
@@ -240,7 +245,7 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
             seq_group_penalization = []
             for _ in range(num_seqs):
                 num_input = random.randint(1, 100)
-                num_generated = random.randint(1, 100) if not is_prompt else 0
+                num_generated = 0 if is_prompt else random.randint(1, 100)
                 seq_data[next(seq_id_counter)] = create_sequence_data(
                     num_input=num_input, num_generated=num_generated)
                 seq_group_penalization.append(num_generated < min_tokens)
@@ -292,6 +297,21 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
             ]
         }
 
+        prompt_with_penalization_and_prompt_logprobs = {
+            "expected_penalization": [False, False, True],
+            "seq_group_metadata_list": [
+                SequenceGroupMetadata(
+                    request_id="test_1",
+                    is_prompt=True,
+                    seq_data={
+                        next(seq_id_counter): create_sequence_data(num_input=3),
+                    },
+                    sampling_params=create_sampling_params(1, prompt_logprobs=3),
+                    block_tables={},
+                ),
+            ]
+        }
+
         stop_penalizing_after_min_tokens = {
             "expected_penalization": [False],
             "seq_group_metadata_list": [
@@ -309,8 +329,34 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
         }
 
         stop_token_ids = [42, 99, 42, 0]  # intentional duplication
-        simple_combination = {
-            "expected_penalization": [True, False, False],
+        prompt_combination = {
+            "expected_penalization": [False, True, False],
+            "seq_group_metadata_list": [
+                SequenceGroupMetadata(
+                    request_id="test_2",
+                    is_prompt=True,
+                    seq_data={
+                        next(seq_id_counter): create_sequence_data(num_input=2),
+                    },
+                    sampling_params=create_sampling_params(1, prompt_logprobs=3),
+                    block_tables={},
+                ),
+                SequenceGroupMetadata(
+                    request_id="test_3",
+                    is_prompt=True,
+                    seq_data={
+                        next(seq_id_counter): create_sequence_data(),
+                    },
+                    sampling_params=create_sampling_params(
+                        0, stop_token_ids=stop_token_ids),
+                    block_tables={},
+                )
+            ]
+        }
+
+        stop_token_ids = [1, 999, 37, 37]  # intentional duplication
+        decode_combination = {
+            "expected_penalization": [True, False, False, True, False],
             "seq_group_metadata_list": [
                 SequenceGroupMetadata(
                     request_id="test_1",
@@ -327,14 +373,19 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
                 ),
                 SequenceGroupMetadata(
                     request_id="test_2",
-                    is_prompt=True,
+                    is_prompt=False,
                     seq_data={
-                        next(seq_id_counter): create_sequence_data(),
+                        next(seq_id_counter):
+                        create_sequence_data(num_generated=20),
+                        next(seq_id_counter):
+                        create_sequence_data(num_generated=1),
+                        next(seq_id_counter):
+                        create_sequence_data(num_generated=10),
                     },
                     sampling_params=create_sampling_params(
-                        0, stop_token_ids=stop_token_ids),
+                        10, prompt_logprobs=5, stop_token_ids=stop_token_ids),
                     block_tables={},
-                )
+                ),
             ]
         }
 
@@ -342,8 +393,10 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
         test_cases = [
            prompt_without_penalization,
            prompt_with_penalization,
+           prompt_with_penalization_and_prompt_logprobs,
            stop_penalizing_after_min_tokens,
-           simple_combination,
+           prompt_combination,
+           decode_combination,
         ]
     else:
         test_cases = [generate_test_case()]
@@ -351,30 +404,49 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
     def run_test_case(*,
                       expected_penalization=None,
                       seq_group_metadata_list=None):
-        assert expected_penalization, "Invalid test case"
-        assert seq_group_metadata_list, "Invalid test case"
+        assert expected_penalization, \
+            "Invalid test case, need expected_penalization"
+        assert seq_group_metadata_list, \
+            "Invalid test case, need seq_group_metadata_list"
 
         batch_size = 0
         prompt_lens = []
-        sampling_params_per_seq = []
+        sampling_params_per_row = []
         for sgm in seq_group_metadata_list:
-            num_seqs = len(sgm.seq_data)
-            batch_size += num_seqs
             sampling_params = sgm.sampling_params
-            for seq_id in sgm.seq_data:
-                prompt_lens.append(sgm.seq_data[seq_id].get_prompt_len())
-                sampling_params_per_seq.append(sampling_params)
+
+            num_rows = len(sgm.seq_data)
+            if sgm.is_prompt:
+                # a prompt seq_group has only one sequence
+                seq_data = next(iter(sgm.seq_data.values()))
+                prompt_len = seq_data.get_prompt_len()
+                prompt_lens.append(prompt_len)
+
+                if sgm.sampling_params.prompt_logprobs:
+                    # with prompt_logprobs each token in the prompt has a row in
+                    # logits
+                    num_rows = prompt_len
+
+            batch_size += num_rows
+            sampling_params_per_row.extend(
+                itertools.repeat(sampling_params, num_rows))
+
+        assert len(
+            expected_penalization
+        ) == batch_size, \
+            ("Invalid test case, expected_penalization does not match computed "
+             "batch size")
 
         _, fake_logits, sampler, model_runner = _prepare_test(batch_size)
         sampling_metadata = model_runner._prepare_sample(
             seq_group_metadata_list,
-            prompt_lens=prompt_lens,
-            subquery_lens=prompt_lens)
+            prompt_lens=prompt_lens if prompt_lens else None,
+            subquery_lens=prompt_lens if prompt_lens else None)
         # the logits tensor is modified in-place by the sampler
         _ = sampler(logits=fake_logits, sampling_metadata=sampling_metadata)
 
         for logits_idx, (should_penalize, sampling_params) in enumerate(
-                zip(expected_penalization, sampling_params_per_seq)):
+                zip(expected_penalization, sampling_params_per_row)):
 
             tokens_to_check = [sampling_params.eos_token_id]
             if sampling_params.stop_token_ids:
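For intuition on the new prompt_with_penalization_and_prompt_logprobs case above, a one-line check (plain Python, values taken from the test case): a 3-token prompt with prompt_logprobs occupies three logits rows, and only the last row, the one the next token is sampled from, is eligible for the min_tokens penalty:

prompt_len = 3  # create_sequence_data(num_input=3) in the test case
expected = [False] * (prompt_len - 1) + [True]  # penalize only the sampling row
assert expected == [False, False, True]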
@@ -27,6 +27,12 @@ class Sampler(nn.Module):
     6. Sample the next tokens.
     Here, each sequence group within the batch can have different sampling
     parameters (e.g., sampling method, temperature, top-p, top-k, etc.).
+
+    The structure of the logits tensor is coupled with the seq_groups in
+    sampling_metadata. Typically, each sequence in each seq_group has one row
+    in logits for the next token to be sampled; however, for a seq_group with
+    a prompt request with the prompt_logprobs sampling parameter, there are
+    rows in logits for each token in the input prompt.
     """
 
     def forward(
@@ -106,7 +112,16 @@ def _apply_min_tokens_penalty(
     # list of indices in logits that will be set to -inf
     logits_to_penalize = []
     start_idx = 0
-    for seq_ids, sampling_params in sampling_metadata.seq_groups:
+    for i, seq_group in enumerate(sampling_metadata.seq_groups):
+        seq_ids, sampling_params = seq_group
+
+        # handle prompt_logprobs by skipping rows in logits added for the prompt
+        # tokens (prompt logprobs are not penalized)
+        if (i < sampling_metadata.num_prompts
+                and sampling_params.prompt_logprobs is not None):
+            assert len(seq_ids) == 1
+            start_idx += sampling_metadata.prompt_lens[i] - 1
+
         min_tokens = sampling_params.min_tokens
         if min_tokens > 0:
             seqs_to_penalize = []
@@ -132,6 +147,8 @@ def _apply_min_tokens_penalty(
         # eg. [ (1,2), (1,3), (5,6) ] -> ( (1,1,5), (2,3,6) )
         logits[tuple(zip(*logits_to_penalize))] = -float("inf")
 
+    # verifies that no rows in logits were missed unexpectedly
+    assert start_idx == logits.shape[0]
     return logits
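To make the fixed control flow easy to try outside vLLM, here is a self-contained sketch (PyTorch; plain tuples and dicts stand in for vLLM's SamplingMetadata and SamplingParams, so every name and field below is illustrative, not the real API):

import torch

def apply_min_tokens_penalty(logits, seq_groups, prompt_lens, num_prompts):
    logits_to_penalize = []  # (row, token) indices to set to -inf
    start_idx = 0
    for i, (seq_ids, params) in enumerate(seq_groups):
        # skip the extra rows a prompt_logprobs prompt adds for its prompt
        # tokens; only the final row belongs to the token being sampled
        if i < num_prompts and params["prompt_logprobs"] is not None:
            assert len(seq_ids) == 1
            start_idx += prompt_lens[i] - 1

        if params["min_tokens"] > 0:
            for j, num_generated in enumerate(params["num_generated"]):
                if num_generated < params["min_tokens"]:
                    for token_id in params["stop_token_ids"]:
                        logits_to_penalize.append((start_idx + j, token_id))
        start_idx += len(seq_ids)

    if logits_to_penalize:
        logits[tuple(zip(*logits_to_penalize))] = -float("inf")
    assert start_idx == logits.shape[0]  # every row accounted for
    return logits

# One prompt seq_group: 3-token prompt with prompt_logprobs requested and
# min_tokens not yet satisfied -> only row 2 (the sampling row) is penalized.
logits = torch.zeros(3, 8)
groups = [(["seq0"], {"prompt_logprobs": 2, "min_tokens": 1,
                      "num_generated": [0], "stop_token_ids": [5]})]
apply_min_tokens_penalty(logits, groups, prompt_lens=[3], num_prompts=1)
assert logits[2, 5].item() == float("-inf")
assert torch.isfinite(logits[:2]).all()

The closing assert mirrors the new check in the diff: after the walk, start_idx must equal logits.shape[0], which is exactly the invariant that catches mis-accounted prompt rows.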