[Bugfix] handle prompt_logprobs in _apply_min_tokens_penalty (#3876)

Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>
Author: Travis Johnson, 2024-04-10 02:39:56 -06:00, committed by GitHub
parent b3104b2a10
commit 0258b7a94b
2 changed files with 112 additions and 23 deletions

tests/samplers/test_sampler.py

@@ -1,3 +1,4 @@
+import itertools
 import random
 from typing import List, Optional, Tuple
 from unittest.mock import patch
@@ -194,11 +195,15 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
     def create_sampling_params(min_tokens,
                                eos_token_id=0,
-                               stop_token_ids=None):
+                               *,
+                               stop_token_ids: Optional[List[int]] = None,
+                               prompt_logprobs: Optional[int] = None):
         sampling_params = SamplingParams(
             min_tokens=min_tokens,
             max_tokens=9999,  # keep higher than max of min_tokens
             stop_token_ids=stop_token_ids,
+            # requesting prompt_logprobs changes the structure of `logits`
+            prompt_logprobs=prompt_logprobs,
         )
         sampling_params.eos_token_id = eos_token_id
         return sampling_params
@@ -217,9 +222,9 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
         expected_penalization = []
         sequence_metadata_list = []
-        # 20% chance to generate seq group metadata list with all prompts
-        is_prompt = random.random() < 0.2
         while batch_size > 0:
+            # 20% chance to generate prompt seq group with single sequence
+            is_prompt = random.random() < 0.2
             num_seqs = 1 if is_prompt else random.randint(1, batch_size)
             eos_token_id = random.randint(0, VOCAB_SIZE - 1)
@@ -240,7 +245,7 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
             seq_group_penalization = []
             for _ in range(num_seqs):
                 num_input = random.randint(1, 100)
-                num_generated = random.randint(1, 100) if not is_prompt else 0
+                num_generated = 0 if is_prompt else random.randint(1, 100)
                 seq_data[next(seq_id_counter)] = create_sequence_data(
                     num_input=num_input, num_generated=num_generated)
                 seq_group_penalization.append(num_generated < min_tokens)
@@ -292,6 +297,21 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
             ]
         }

+        prompt_with_penalization_and_prompt_logprobs = {
+            "expected_penalization": [False, False, True],
+            "seq_group_metadata_list": [
+                SequenceGroupMetadata(
+                    request_id="test_1",
+                    is_prompt=True,
+                    seq_data={
+                        next(seq_id_counter): create_sequence_data(num_input=3),
+                    },
+                    sampling_params=create_sampling_params(1, prompt_logprobs=3),
+                    block_tables={},
+                ),
+            ]
+        }
+
         stop_penalizing_after_min_tokens = {
             "expected_penalization": [False],
             "seq_group_metadata_list": [
@@ -309,8 +329,34 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
         }

         stop_token_ids = [42, 99, 42, 0]  # intentional duplication
-        simple_combination = {
-            "expected_penalization": [True, False, False],
+        prompt_combination = {
+            "expected_penalization": [False, True, False],
+            "seq_group_metadata_list": [
+                SequenceGroupMetadata(
+                    request_id="test_2",
+                    is_prompt=True,
+                    seq_data={
+                        next(seq_id_counter): create_sequence_data(num_input=2),
+                    },
+                    sampling_params=create_sampling_params(1, prompt_logprobs=3),
+                    block_tables={},
+                ),
+                SequenceGroupMetadata(
+                    request_id="test_3",
+                    is_prompt=True,
+                    seq_data={
+                        next(seq_id_counter): create_sequence_data(),
+                    },
+                    sampling_params=create_sampling_params(
+                        0, stop_token_ids=stop_token_ids),
+                    block_tables={},
+                )
+            ]
+        }
+
+        stop_token_ids = [1, 999, 37, 37]  # intentional duplication
+        decode_combination = {
+            "expected_penalization": [True, False, False, True, False],
             "seq_group_metadata_list": [
                 SequenceGroupMetadata(
                     request_id="test_1",
@@ -327,14 +373,19 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
                 ),
                 SequenceGroupMetadata(
                     request_id="test_2",
-                    is_prompt=True,
+                    is_prompt=False,
                     seq_data={
-                        next(seq_id_counter): create_sequence_data(),
+                        next(seq_id_counter):
+                        create_sequence_data(num_generated=20),
+                        next(seq_id_counter):
+                        create_sequence_data(num_generated=1),
+                        next(seq_id_counter):
+                        create_sequence_data(num_generated=10),
                     },
                     sampling_params=create_sampling_params(
-                        0, stop_token_ids=stop_token_ids),
+                        10, prompt_logprobs=5, stop_token_ids=stop_token_ids),
                     block_tables={},
-                )
+                ),
             ]
         }
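Worth noting in `decode_combination`: the second group sets prompt_logprobs=5 but is_prompt=False, so on this reading it must not expand the logits rows; its three sequences contribute one row each, penalized exactly where fewer than min_tokens=10 tokens have been generated. A quick sanity check of the expected tail, assuming dict insertion order maps sequences to rows:

min_tokens = 10
num_generated = [20, 1, 10]  # the three decode sequences above, in order
assert [n < min_tokens for n in num_generated] == [False, True, False]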
@@ -342,8 +393,10 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
         test_cases = [
             prompt_without_penalization,
             prompt_with_penalization,
+            prompt_with_penalization_and_prompt_logprobs,
             stop_penalizing_after_min_tokens,
-            simple_combination,
+            prompt_combination,
+            decode_combination,
         ]
     else:
         test_cases = [generate_test_case()]
@@ -351,30 +404,49 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
     def run_test_case(*,
                       expected_penalization=None,
                       seq_group_metadata_list=None):
-        assert expected_penalization, "Invalid test case"
-        assert seq_group_metadata_list, "Invalid test case"
+        assert expected_penalization, \
+            "Invalid test case, need expected_penalization"
+        assert seq_group_metadata_list, \
+            "Invalid test case, need seq_group_metadata_list"

         batch_size = 0
         prompt_lens = []
-        sampling_params_per_seq = []
+        sampling_params_per_row = []
         for sgm in seq_group_metadata_list:
-            num_seqs = len(sgm.seq_data)
-            batch_size += num_seqs
-
             sampling_params = sgm.sampling_params
-            for seq_id in sgm.seq_data:
-                prompt_lens.append(sgm.seq_data[seq_id].get_prompt_len())
-                sampling_params_per_seq.append(sampling_params)
+
+            num_rows = len(sgm.seq_data)
+            if sgm.is_prompt:
+                # a prompt seq_group has only one sequence
+                seq_data = next(iter(sgm.seq_data.values()))
+                prompt_len = seq_data.get_prompt_len()
+                prompt_lens.append(prompt_len)
+
+                if sgm.sampling_params.prompt_logprobs:
+                    # with prompt_logprobs each token in the prompt has a row
+                    # in logits
+                    num_rows = prompt_len
+
+            batch_size += num_rows
+            sampling_params_per_row.extend(
+                itertools.repeat(sampling_params, num_rows))
+
+        assert len(expected_penalization) == batch_size, \
+            ("Invalid test case, expected_penalization does not match "
+             "computed batch size")

         _, fake_logits, sampler, model_runner = _prepare_test(batch_size)
         sampling_metadata = model_runner._prepare_sample(
             seq_group_metadata_list,
-            prompt_lens=prompt_lens,
-            subquery_lens=prompt_lens)
+            prompt_lens=prompt_lens if prompt_lens else None,
+            subquery_lens=prompt_lens if prompt_lens else None)
         # the logits tensor is modified in-place by the sampler
         _ = sampler(logits=fake_logits, sampling_metadata=sampling_metadata)

         for logits_idx, (should_penalize, sampling_params) in enumerate(
-                zip(expected_penalization, sampling_params_per_seq)):
+                zip(expected_penalization, sampling_params_per_row)):

             tokens_to_check = [sampling_params.eos_token_id]
             if sampling_params.stop_token_ids:
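The bookkeeping that `run_test_case` now performs can be sketched standalone: build one sampling_params entry per logits row, expanding prompt groups that request prompt_logprobs. This simplified sketch uses plain tuples instead of vLLM's SequenceGroupMetadata and only illustrates the accounting, not the test itself:

import itertools

# Each entry: (is_prompt, prompt_len, num_seqs, prompt_logprobs, params)
def rows_for_batch(seq_groups):
    params_per_row = []
    for is_prompt, prompt_len, num_seqs, prompt_logprobs, params in seq_groups:
        num_rows = (prompt_len if is_prompt and prompt_logprobs is not None
                    else num_seqs)
        params_per_row.extend(itertools.repeat(params, num_rows))
    return params_per_row

rows = rows_for_batch([
    (True, 3, 1, 3, "prompt_params"),      # prompt + prompt_logprobs -> 3 rows
    (False, 5, 2, None, "decode_params"),  # decode group, 2 seqs -> 2 rows
])
assert rows == ["prompt_params"] * 3 + ["decode_params"] * 2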

vllm/model_executor/layers/sampler.py

@@ -27,6 +27,12 @@ class Sampler(nn.Module):
     6. Sample the next tokens.
     Here, each sequence group within the batch can have different sampling
     parameters (e.g., sampling method, temperature, top-p, top-k, etc.).
+
+    The structure of the logits tensor is coupled with the seq_groups in
+    sampling_metadata. Typically, each sequence in each seq_group has one row
+    in logits for the next token to be sampled; however, for a seq_group with
+    a prompt request with the prompt_logprobs sampling parameter, there are
+    rows in logits for each token in the input prompt.
     """

     def forward(
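To make the new docstring concrete, here is a small worked example of the row layout it describes, assuming a batch with one prompt group (prompt length 4, prompt_logprobs requested) followed by a decode group of two sequences; the layout is inferred from the description above, not taken verbatim from vLLM:

prompt_len, num_decode_seqs = 4, 2
# rows 0..2: logprob rows for the prompt tokens after the first
# row 3:     the prompt group's next-token sampling row
# rows 4..5: one sampling row per decode sequence
num_rows = (prompt_len - 1) + 1 + num_decode_seqs
assert num_rows == 6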
@@ -106,7 +112,16 @@ def _apply_min_tokens_penalty(
     # list of indices in logits that will be set to -inf
     logits_to_penalize = []
     start_idx = 0
-    for seq_ids, sampling_params in sampling_metadata.seq_groups:
+    for i, seq_group in enumerate(sampling_metadata.seq_groups):
+        seq_ids, sampling_params = seq_group
+
+        # handle prompt_logprobs by skipping rows in logits added for the
+        # prompt tokens (prompt logprobs are not penalized)
+        if (i < sampling_metadata.num_prompts
+                and sampling_params.prompt_logprobs is not None):
+            assert len(seq_ids) == 1
+            start_idx += sampling_metadata.prompt_lens[i] - 1
+
         min_tokens = sampling_params.min_tokens
         if min_tokens > 0:
             seqs_to_penalize = []
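The core of the fix is the `start_idx` adjustment: rows contributed by prompt tokens are skipped before any min_tokens indices are computed, so only next-token sampling rows can be masked to -inf. A simplified standalone sketch of that indexing, using plain tuples in place of vLLM's metadata types:

def min_tokens_penalty_rows(seq_groups, num_prompts, prompt_lens):
    # seq_groups: list of (seq_ids, min_tokens, prompt_logprobs, generated),
    # where generated[j] is how many tokens sequence j has produced so far
    start_idx = 0
    rows = []
    for i, (seq_ids, min_tokens, prompt_logprobs,
            generated) in enumerate(seq_groups):
        # skip the extra rows a prompt group with prompt_logprobs adds;
        # prompt-token rows are never penalized
        if i < num_prompts and prompt_logprobs is not None:
            start_idx += prompt_lens[i] - 1
        if min_tokens > 0:
            for j, num_generated in enumerate(generated):
                if num_generated < min_tokens:
                    rows.append(start_idx + j)
        start_idx += len(seq_ids)
    return rows, start_idx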
@@ -132,6 +147,8 @@ def _apply_min_tokens_penalty(
         # eg. [ (1,2), (1,3), (5,6) ] -> ( (1,1,5), (2,3,6) )
         logits[tuple(zip(*logits_to_penalize))] = -float("inf")

+    # verifies that no rows in logits were missed unexpectedly
+    assert start_idx == logits.shape[0]
     return logits
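Exercising the sketch above with the same shape as the new unit test (a 3-token prompt with prompt_logprobs set, min_tokens=1, nothing generated yet) shows both the fix and the new invariant: only the sampling row is penalized, and the final index lands exactly on the number of logits rows:

rows, end_idx = min_tokens_penalty_rows(
    seq_groups=[([0], 1, 3, [0])],  # 1 seq, min_tokens=1, prompt_logprobs=3
    num_prompts=1,
    prompt_lens=[3])
assert rows == [2]   # prompt rows 0-1 are skipped; only row 2 is penalized
assert end_idx == 3  # == logits.shape[0] for this batch, the new invariant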