[Bugfix] handle prompt_logprobs in _apply_min_tokens_penalty (#3876)

Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>
This commit is contained in:
Travis Johnson 2024-04-10 02:39:56 -06:00 committed by GitHub
parent b3104b2a10
commit 0258b7a94b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 112 additions and 23 deletions

View File

@ -1,3 +1,4 @@
import itertools
import random
from typing import List, Optional, Tuple
from unittest.mock import patch
@ -194,11 +195,15 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
def create_sampling_params(min_tokens,
eos_token_id=0,
stop_token_ids=None):
*,
stop_token_ids: Optional[List[str]] = None,
prompt_logprobs: Optional[int] = None):
sampling_params = SamplingParams(
min_tokens=min_tokens,
max_tokens=9999, # keep higher than max of min_tokens
stop_token_ids=stop_token_ids,
# requesting prompt_logprobs changes the structure of `logits`
prompt_logprobs=prompt_logprobs,
)
sampling_params.eos_token_id = eos_token_id
return sampling_params
@ -217,9 +222,9 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
expected_penalization = []
sequence_metadata_list = []
# 20% chance to generate seq group metadata list with all prompts
is_prompt = random.random() < 0.2
while batch_size > 0:
# 20% chance to generate prompt seq group with single sequence
is_prompt = random.random() < 0.2
num_seqs = 1 if is_prompt else random.randint(1, batch_size)
eos_token_id = random.randint(0, VOCAB_SIZE - 1)
@ -240,7 +245,7 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
seq_group_penalization = []
for _ in range(num_seqs):
num_input = random.randint(1, 100)
num_generated = random.randint(1, 100) if not is_prompt else 0
num_generated = 0 if is_prompt else random.randint(1, 100)
seq_data[next(seq_id_counter)] = create_sequence_data(
num_input=num_input, num_generated=num_generated)
seq_group_penalization.append(num_generated < min_tokens)
@ -292,6 +297,21 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
]
}
prompt_with_penalization_and_prompt_logprobs = {
"expected_penalization": [False, False, True],
"seq_group_metadata_list": [
SequenceGroupMetadata(
request_id="test_1",
is_prompt=True,
seq_data={
next(seq_id_counter): create_sequence_data(num_input=3),
},
sampling_params=create_sampling_params(1, prompt_logprobs=3),
block_tables={},
),
]
}
stop_penalizing_after_min_tokens = {
"expected_penalization": [False],
"seq_group_metadata_list": [
@ -309,8 +329,34 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
}
stop_token_ids = [42, 99, 42, 0] # intentional duplication
simple_combination = {
"expected_penalization": [True, False, False],
prompt_combination = {
"expected_penalization": [False, True, False],
"seq_group_metadata_list": [
SequenceGroupMetadata(
request_id="test_2",
is_prompt=True,
seq_data={
next(seq_id_counter): create_sequence_data(num_input=2),
},
sampling_params=create_sampling_params(1, prompt_logprobs=3),
block_tables={},
),
SequenceGroupMetadata(
request_id="test_3",
is_prompt=True,
seq_data={
next(seq_id_counter): create_sequence_data(),
},
sampling_params=create_sampling_params(
0, stop_token_ids=stop_token_ids),
block_tables={},
)
]
}
stop_token_ids = [1, 999, 37, 37] # intentional duplication
decode_combination = {
"expected_penalization": [True, False, False, True, False],
"seq_group_metadata_list": [
SequenceGroupMetadata(
request_id="test_1",
@ -327,14 +373,19 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
),
SequenceGroupMetadata(
request_id="test_2",
is_prompt=True,
is_prompt=False,
seq_data={
next(seq_id_counter): create_sequence_data(),
next(seq_id_counter):
create_sequence_data(num_generated=20),
next(seq_id_counter):
create_sequence_data(num_generated=1),
next(seq_id_counter):
create_sequence_data(num_generated=10),
},
sampling_params=create_sampling_params(
0, stop_token_ids=stop_token_ids),
10, prompt_logprobs=5, stop_token_ids=stop_token_ids),
block_tables={},
)
),
]
}
@ -342,8 +393,10 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
test_cases = [
prompt_without_penalization,
prompt_with_penalization,
prompt_with_penalization_and_prompt_logprobs,
stop_penalizing_after_min_tokens,
simple_combination,
prompt_combination,
decode_combination,
]
else:
test_cases = [generate_test_case()]
@ -351,30 +404,49 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
def run_test_case(*,
expected_penalization=None,
seq_group_metadata_list=None):
assert expected_penalization, "Invalid test case"
assert seq_group_metadata_list, "Invalid test case"
assert expected_penalization, \
"Invalid test case, need expected_penalization"
assert seq_group_metadata_list, \
"Invalid test case, need seq_group_metadata_list"
batch_size = 0
prompt_lens = []
sampling_params_per_seq = []
sampling_params_per_row = []
for sgm in seq_group_metadata_list:
num_seqs = len(sgm.seq_data)
batch_size += num_seqs
sampling_params = sgm.sampling_params
for seq_id in sgm.seq_data:
prompt_lens.append(sgm.seq_data[seq_id].get_prompt_len())
sampling_params_per_seq.append(sampling_params)
num_rows = len(sgm.seq_data)
if sgm.is_prompt:
# a prompt seq_group has only one sequence
seq_data = next(iter(sgm.seq_data.values()))
prompt_len = seq_data.get_prompt_len()
prompt_lens.append(prompt_len)
if sgm.sampling_params.prompt_logprobs:
# with prompt_logprobs each token in the prompt has a row in
# logits
num_rows = prompt_len
batch_size += num_rows
sampling_params_per_row.extend(
itertools.repeat(sampling_params, num_rows))
assert len(
expected_penalization
) == batch_size, \
("Invalid test case, expected_penalization does not match computed"
"batch size")
_, fake_logits, sampler, model_runner = _prepare_test(batch_size)
sampling_metadata = model_runner._prepare_sample(
seq_group_metadata_list,
prompt_lens=prompt_lens,
subquery_lens=prompt_lens)
prompt_lens=prompt_lens if prompt_lens else None,
subquery_lens=prompt_lens if prompt_lens else None)
# the logits tensor is modified in-place by the sampler
_ = sampler(logits=fake_logits, sampling_metadata=sampling_metadata)
for logits_idx, (should_penalize, sampling_params) in enumerate(
zip(expected_penalization, sampling_params_per_seq)):
zip(expected_penalization, sampling_params_per_row)):
tokens_to_check = [sampling_params.eos_token_id]
if sampling_params.stop_token_ids:

View File

@ -27,6 +27,12 @@ class Sampler(nn.Module):
6. Sample the next tokens.
Here, each sequence group within the batch can have different sampling
parameters (e.g., sampling method, temperature, top-p, top-k, etc.).
The structure of the logits tensor is coupled with the seq_groups in
sampling_metadata. Typically, each sequence in each seq_group has one row in
logits for the next token to be sampled; however, for a seq_group with a
prompt request with the prompt_logprobs sampling parameter, there are rows
in logits for each token in the input prompt.
"""
def forward(
@ -106,7 +112,16 @@ def _apply_min_tokens_penalty(
# list of indices in logits that will be set to -inf
logits_to_penalize = []
start_idx = 0
for seq_ids, sampling_params in sampling_metadata.seq_groups:
for i, seq_group in enumerate(sampling_metadata.seq_groups):
seq_ids, sampling_params = seq_group
# handle prompt_logprobs by skipping rows in logits added for the prompt
# tokens (prompt logprobs are not penalized)
if (i < sampling_metadata.num_prompts
and sampling_params.prompt_logprobs is not None):
assert len(seq_ids) == 1
start_idx += sampling_metadata.prompt_lens[i] - 1
min_tokens = sampling_params.min_tokens
if min_tokens > 0:
seqs_to_penalize = []
@ -132,6 +147,8 @@ def _apply_min_tokens_penalty(
# eg. [ (1,2), (1,3), (5,6) ] -> ( (1,1,5), (2,3,6) )
logits[tuple(zip(*logits_to_penalize))] = -float("inf")
# verifies that no rows in logits were missed unexpectedly
assert start_idx == logits.shape[0]
return logits