feat: implement the min_tokens sampling parameter (#3124)

Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>
Co-authored-by: Nick Hill <nickhill@us.ibm.com>
Travis Johnson authored on 2024-03-25 11:14:26 -06:00; committed by GitHub
parent 819924e749
commit c13ad1b7bd
5 changed files with 299 additions and 12 deletions
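As a usage sketch (the model name and prompt below are placeholders, not part of this change): with min_tokens set, the sampler masks the EOS token id and any stop_token_ids to -inf until that many tokens have been generated, and the engine skips stop-string and stop-token checks until then.

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")  # placeholder model

params = SamplingParams(
    min_tokens=16,  # EOS and stop tokens are suppressed for the first 16 output tokens
    max_tokens=64,  # must be >= min_tokens (enforced by _verify_args below)
    stop=["\n\n"],  # stop strings are also ignored until min_tokens is reached
)
outputs = llm.generate(["Write a short story about a robot."], params)
print(outputs[0].outputs[0].text)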

View File

@@ -10,6 +10,7 @@ from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
from vllm.worker.model_runner import ModelRunner
from vllm.utils import Counter
class MockLogitsSampler(Sampler):
@@ -25,9 +26,8 @@ class MockLogitsSampler(Sampler):
def _prepare_test(
        batch_size: int
) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsSampler, ModelRunner]:
vocab_size = 32000
    input_tensor = torch.rand((batch_size, 1024), dtype=torch.float16)
    fake_logits = torch.full((batch_size, VOCAB_SIZE),
                             1e-2,
                             dtype=input_tensor.dtype)
    sampler = MockLogitsSampler(fake_logits)
@@ -35,6 +35,7 @@ def _prepare_test(
    return input_tensor, fake_logits, sampler, model_runner
VOCAB_SIZE = 32000
RANDOM_SEEDS = list(range(128))
CUDA_DEVICES = [
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
@@ -184,6 +185,225 @@ def test_sampler_all_beam(seed: int, device: str):
    del model_runner
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_min_tokens_penalty(seed: int, device: str):
seq_id_counter = Counter(start=random.randint(0, 100))
set_random_seed(seed)
torch.set_default_device(device)
def create_sampling_params(min_tokens,
eos_token_id=0,
stop_token_ids=None):
sampling_params = SamplingParams(
min_tokens=min_tokens,
max_tokens=9999, # keep higher than max of min_tokens
stop_token_ids=stop_token_ids,
)
sampling_params.eos_token_id = eos_token_id
return sampling_params
def create_sequence_data(num_input=3, num_generated=0):
seq_data = SequenceData(
random.choices(range(0, VOCAB_SIZE), k=num_input))
if num_generated > 0:
seq_data.output_token_ids = random.choices(range(0, VOCAB_SIZE),
k=num_generated)
return seq_data
def generate_test_case():
# generate multiple seq groups but limit total batch size
batch_size = random.randint(1, 128)
expected_penalization = []
sequence_metadata_list = []
while batch_size > 0:
# 20% chance to generate prompt seq group with single sequence
is_prompt = random.random() < 0.2
num_seqs = 1 if is_prompt else random.randint(1, batch_size)
eos_token_id = random.randint(0, VOCAB_SIZE - 1)
min_tokens = random.randint(0, 50)
num_stop_tokens = random.randint(0, 8)
if num_stop_tokens > 0:
stop_token_ids = random.choices(range(0, VOCAB_SIZE - 1),
k=num_stop_tokens)
else:
stop_token_ids = None
sampling_params = create_sampling_params(
min_tokens=min_tokens,
eos_token_id=eos_token_id,
stop_token_ids=stop_token_ids)
seq_data = {}
seq_group_penalization = []
for _ in range(num_seqs):
num_input = random.randint(1, 100)
num_generated = random.randint(1, 100) if not is_prompt else 0
seq_data[next(seq_id_counter)] = create_sequence_data(
num_input=num_input, num_generated=num_generated)
seq_group_penalization.append(num_generated < min_tokens)
expected_penalization.extend(seq_group_penalization)
sequence_metadata_list.append(
SequenceGroupMetadata(
request_id=f"test_{batch_size}",
is_prompt=is_prompt,
seq_data=seq_data,
sampling_params=sampling_params,
block_tables={},
))
batch_size -= num_seqs
return {
"expected_penalization": expected_penalization,
"seq_group_metadata_list": sequence_metadata_list,
}
# define some explicit test cases for edge case behavior
prompt_without_penalization = {
"expected_penalization": [False],
"seq_group_metadata_list": [
SequenceGroupMetadata(
request_id="test_1",
is_prompt=True,
seq_data={
next(seq_id_counter): create_sequence_data(),
},
sampling_params=create_sampling_params(0),
block_tables={},
),
]
}
prompt_with_penalization = {
"expected_penalization": [True],
"seq_group_metadata_list": [
SequenceGroupMetadata(
request_id="test_1",
is_prompt=True,
seq_data={
next(seq_id_counter): create_sequence_data(),
},
sampling_params=create_sampling_params(1),
block_tables={},
),
]
}
stop_penalizing_after_min_tokens = {
"expected_penalization": [False],
"seq_group_metadata_list": [
SequenceGroupMetadata(
request_id="test_1",
is_prompt=False,
seq_data={
next(seq_id_counter):
create_sequence_data(num_generated=1),
},
sampling_params=create_sampling_params(1),
block_tables={},
)
]
}
stop_token_ids = [42, 99, 42, 0] # intentional duplication
simple_combination = {
"expected_penalization": [True, False, False],
"seq_group_metadata_list": [
SequenceGroupMetadata(
request_id="test_1",
is_prompt=False,
seq_data={
next(seq_id_counter):
create_sequence_data(num_generated=1),
next(seq_id_counter):
create_sequence_data(num_generated=100),
},
sampling_params=create_sampling_params(
2, stop_token_ids=stop_token_ids),
block_tables={},
),
SequenceGroupMetadata(
request_id="test_2",
is_prompt=True,
seq_data={
next(seq_id_counter): create_sequence_data(),
},
sampling_params=create_sampling_params(
0, stop_token_ids=stop_token_ids),
block_tables={},
)
]
}
if seed == 0:
test_cases = [
prompt_without_penalization,
prompt_with_penalization,
stop_penalizing_after_min_tokens,
simple_combination,
]
else:
test_cases = [generate_test_case()]
def run_test_case(*,
expected_penalization=None,
seq_group_metadata_list=None):
assert expected_penalization, "Invalid test case"
assert seq_group_metadata_list, "Invalid test case"
batch_size = 0
prompt_lens = []
sampling_params_per_seq = []
for sgm in seq_group_metadata_list:
num_seqs = len(sgm.seq_data)
batch_size += num_seqs
sampling_params = sgm.sampling_params
for seq_id in sgm.seq_data:
prompt_lens.append(sgm.seq_data[seq_id].get_prompt_len())
sampling_params_per_seq.append(sampling_params)
_, fake_logits, sampler, model_runner = _prepare_test(batch_size)
sampling_metadata = model_runner._prepare_sample(
seq_group_metadata_list,
prompt_lens=prompt_lens,
subquery_lens=prompt_lens)
# the logits tensor is modified in-place by the sampler
_ = sampler(logits=fake_logits, sampling_metadata=sampling_metadata)
for logits_idx, (should_penalize, sampling_params) in enumerate(
zip(expected_penalization, sampling_params_per_seq)):
tokens_to_check = [sampling_params.eos_token_id]
if sampling_params.stop_token_ids:
tokens_to_check.extend(sampling_params.stop_token_ids)
tokens_to_check = set(tokens_to_check)
if should_penalize:
for token_id in tokens_to_check:
assert fake_logits[logits_idx, token_id] == -float(
'inf'
), f"Expected token {token_id} for logits row {logits_idx}"
" to be penalized"
# no other tokens should be set to -inf
assert torch.count_nonzero(
fake_logits[logits_idx, :] == -float('inf')) == len(
tokens_to_check
), f"Expected only {len(tokens_to_check)} to be penalized"
else:
# no tokens should be set to -inf
assert torch.count_nonzero(
fake_logits[logits_idx, :] ==
-float('inf')) == 0, "No tokens should have been penalized"
del model_runner
for test_case in test_cases:
run_test_case(**test_case)
@pytest.mark.parametrize("seed", RANDOM_SEEDS) @pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_mixed(seed: int, device: str): def test_sampler_mixed(seed: int, device: str):

View File

@@ -282,6 +282,9 @@ class LLMEngine:
        # Defensive copy of SamplingParams, which are used by the sampler,
        # this doesn't deep-copy LogitsProcessor objects
        sampling_params = sampling_params.clone()
# inject the eos token id into the sampling_params to support min_tokens
# processing
sampling_params.eos_token_id = seq.eos_token_id
        # Create the sequence group.
        seq_group = SequenceGroup(request_id, [seq], sampling_params,
@@ -713,6 +716,21 @@
    def _check_stop(self, seq: Sequence,
                    sampling_params: SamplingParams) -> None:
        """Stop the finished sequences."""
# Check if the sequence has reached max_model_len.
if seq.get_len() > self.scheduler_config.max_model_len:
seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
return
# Check if the sequence has reached max_tokens.
if seq.get_output_len() == sampling_params.max_tokens:
seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
return
# Check if the minimum number of tokens has been generated yet;
# skip the stop string/token checks if not
if seq.get_output_len() < sampling_params.min_tokens:
return
        for stop_str in sampling_params.stop:
            if seq.output_text.endswith(stop_str):
                self._finalize_sequence(seq, sampling_params, stop_str)
@@ -725,16 +743,6 @@
                seq.status = SequenceStatus.FINISHED_STOPPED
                return
# Check if the sequence has reached max_model_len.
if seq.get_len() > self.scheduler_config.max_model_len:
seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
return
# Check if the sequence has reached max_tokens.
if seq.get_output_len() == sampling_params.max_tokens:
seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
return
        # Check if the sequence has generated the EOS token.
        if ((not sampling_params.ignore_eos)
                and seq.get_last_token_id() == seq.eos_token_id):
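Reading the two hunks above together, the resulting check order in _check_stop is roughly the following sketch (paraphrased, not the verbatim method body; max_model_len is taken from scheduler_config):

from vllm.sampling_params import SamplingParams
from vllm.sequence import Sequence, SequenceStatus


def _check_stop_sketch(seq: Sequence, sampling_params: SamplingParams,
                       max_model_len: int) -> None:
    # Length caps always apply, even before min_tokens is reached.
    if seq.get_len() > max_model_len:
        seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
        return
    if seq.get_output_len() == sampling_params.max_tokens:
        seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
        return

    # Below min_tokens, skip the stop-string / stop-token-id / EOS checks.
    if seq.get_output_len() < sampling_params.min_tokens:
        return

    # ... the existing stop-string, stop-token-id, and EOS handling follows.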

View File

@@ -88,6 +88,7 @@ class ChatCompletionRequest(BaseModel):
    length_penalty: Optional[float] = 1.0
    early_stopping: Optional[bool] = False
    ignore_eos: Optional[bool] = False
min_tokens: Optional[int] = 0
    stop_token_ids: Optional[List[int]] = Field(default_factory=list)
    skip_special_tokens: Optional[bool] = True
    spaces_between_special_tokens: Optional[bool] = True
@@ -165,6 +166,7 @@ class ChatCompletionRequest(BaseModel):
            stop=self.stop,
            stop_token_ids=self.stop_token_ids,
            max_tokens=self.max_tokens,
min_tokens=self.min_tokens,
            logprobs=self.top_logprobs if self.logprobs else None,
            prompt_logprobs=self.top_logprobs if self.echo else None,
            best_of=self.best_of,
@@ -224,6 +226,7 @@ class CompletionRequest(BaseModel):
    early_stopping: Optional[bool] = False
    stop_token_ids: Optional[List[int]] = Field(default_factory=list)
    ignore_eos: Optional[bool] = False
min_tokens: Optional[int] = 0
    skip_special_tokens: Optional[bool] = True
    spaces_between_special_tokens: Optional[bool] = True
    # doc: end-completion-sampling-params
@@ -296,6 +299,7 @@ class CompletionRequest(BaseModel):
            stop_token_ids=self.stop_token_ids,
            ignore_eos=self.ignore_eos,
            max_tokens=self.max_tokens if not echo_without_generation else 1,
min_tokens=self.min_tokens,
            logprobs=self.logprobs,
            use_beam_search=self.use_beam_search,
            early_stopping=self.early_stopping,
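A request sketch for the OpenAI-compatible server (URL and model name are placeholders; min_tokens is a vLLM-specific extension to the request schema, not an upstream OpenAI field):

import requests

resp = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "model": "facebook/opt-125m",   # placeholder model name
        "prompt": "Summarize the plot of Hamlet.",
        "min_tokens": 32,               # new field added in this change
        "max_tokens": 128,
        "stop": ["\n\n"],
    },
    timeout=60,
)
print(resp.json()["choices"][0]["text"])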

View File

@@ -1,4 +1,5 @@
"""A layer that samples the next tokens from the model's outputs."""
import itertools
from typing import Dict, List, Optional, Tuple
import torch
@@ -36,6 +37,10 @@ class Sampler(nn.Module):
        assert logits is not None
        _, vocab_size = logits.shape
# Apply min_tokens penalty which sets stop tokens to -inf if min_tokens
# have not been generated yet
logits = _apply_min_tokens_penalty(logits, sampling_metadata)
        # Prepare sampling tensors with pinned memory to avoid blocking.
        (sampling_tensors, do_penalties, do_top_p_top_k,
         do_min_p) = SamplingTensors.from_sampling_metadata(
@@ -94,6 +99,42 @@ def _get_bin_counts_and_mask(
    return bin_counts, mask
def _apply_min_tokens_penalty(
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
# list of indices in logits that will be set to -inf
logits_to_penalize = []
start_idx = 0
for seq_ids, sampling_params in sampling_metadata.seq_groups:
min_tokens = sampling_params.min_tokens
if min_tokens > 0:
seqs_to_penalize = []
for i, seq_id in enumerate(seq_ids):
seq_data = sampling_metadata.seq_data[seq_id]
if len(seq_data.output_token_ids) < min_tokens:
seqs_to_penalize.append(i)
if seqs_to_penalize:
# convert to the index into logits
seqs_to_penalize = [start_idx + i for i in seqs_to_penalize]
# use set() to remove any duplicates
token_ids_to_penalize = set(sampling_params.stop_token_ids +
[sampling_params.eos_token_id])
# itertools.product pairs each seq index with every token id
logits_to_penalize.extend(
itertools.product(seqs_to_penalize, token_ids_to_penalize))
start_idx += len(seq_ids)
if logits_to_penalize:
# use zip and * to group indices along each dimension
# eg. [ (1,2), (1,3), (5,6) ] -> ( (1,1,5), (2,3,6) )
logits[tuple(zip(*logits_to_penalize))] = -float("inf")
return logits
def _apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
                     output_tokens_tensor: torch.Tensor,
                     presence_penalties: torch.Tensor,
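The pair bookkeeping in _apply_min_tokens_penalty relies on itertools.product plus zip(*) transposition feeding PyTorch advanced indexing; a standalone illustration with made-up numbers:

import itertools

import torch

logits = torch.zeros(4, 10)
rows_to_penalize = [1, 3]        # sequences that have not reached min_tokens yet
token_ids_to_penalize = {0, 7}   # eos_token_id plus stop_token_ids, de-duplicated

pairs = list(itertools.product(rows_to_penalize, token_ids_to_penalize))
# pairs covers every (row, token) combination, e.g. [(1, 0), (1, 7), (3, 0), (3, 7)]
rows, cols = zip(*pairs)         # transpose into per-dimension index tuples
logits[rows, cols] = -float("inf")

assert torch.isinf(logits[1, 0]) and torch.isinf(logits[3, 7])
assert not torch.isinf(logits[0, 0])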

View File

@@ -79,6 +79,8 @@ class SamplingParams:
        ignore_eos: Whether to ignore the EOS token and continue generating
            tokens after the EOS token is generated.
        max_tokens: Maximum number of tokens to generate per output sequence.
min_tokens: Minimum number of tokens to generate per output sequence
before EOS or stop_token_ids can be generated
        logprobs: Number of log probabilities to return per output token.
            Note that the implementation follows the OpenAI API: The return
            result includes the log probabilities on the `logprobs` most likely
@@ -113,6 +115,7 @@ class SamplingParams:
        include_stop_str_in_output: bool = False,
        ignore_eos: bool = False,
        max_tokens: Optional[int] = 16,
min_tokens: int = 0,
        logprobs: Optional[int] = None,
        prompt_logprobs: Optional[int] = None,
        skip_special_tokens: bool = True,
@@ -144,6 +147,7 @@ class SamplingParams:
        self.stop_token_ids = list(stop_token_ids)
        self.ignore_eos = ignore_eos
        self.max_tokens = max_tokens
self.min_tokens = min_tokens
        self.logprobs = logprobs
        self.prompt_logprobs = prompt_logprobs
        self.skip_special_tokens = skip_special_tokens
@@ -161,6 +165,8 @@ class SamplingParams:
            self.top_k = -1
            self.min_p = 0.0
            self._verify_greedy_sampling()
# injected by the engine
self.eos_token_id = None
    def _verify_args(self) -> None:
        if self.n < 1:
@@ -191,6 +197,13 @@
        if self.max_tokens is not None and self.max_tokens < 1:
            raise ValueError(
                f"max_tokens must be at least 1, got {self.max_tokens}.")
if self.min_tokens < 0:
raise ValueError(f"min_tokens must be greater than or equal to 0, "
f"got {self.min_tokens}.")
if self.max_tokens is not None and self.min_tokens > self.max_tokens:
raise ValueError(
f"min_tokens must be less than or equal to "
f"max_tokens={self.max_tokens}, got {self.min_tokens}.")
        if self.logprobs is not None and self.logprobs < 0:
            raise ValueError(
                f"logprobs must be non-negative, got {self.logprobs}.")
@@ -272,6 +285,7 @@
                f"include_stop_str_in_output={self.include_stop_str_in_output}, "
                f"ignore_eos={self.ignore_eos}, "
                f"max_tokens={self.max_tokens}, "
f"min_tokens={self.min_tokens}, "
f"logprobs={self.logprobs}, " f"logprobs={self.logprobs}, "
f"prompt_logprobs={self.prompt_logprobs}, " f"prompt_logprobs={self.prompt_logprobs}, "
f"skip_special_tokens={self.skip_special_tokens}, " f"skip_special_tokens={self.skip_special_tokens}, "