feat: implement the min_tokens sampling parameter (#3124)
Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>
Co-authored-by: Nick Hill <nickhill@us.ibm.com>

This commit is contained in:
parent 819924e749
commit c13ad1b7bd
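For orientation, here is a minimal usage sketch of the new parameter with vLLM's offline `LLM` API. The model name, prompt, and stop token id are placeholders, not part of this commit; the behaviour follows from the diff below: until `min_tokens` output tokens exist, the sampler masks EOS and any `stop_token_ids` to -inf and the engine defers its stop checks.

```python
# Illustrative only: model name, prompt, and stop token id are assumptions.
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")
params = SamplingParams(
    min_tokens=16,           # suppress EOS / stop tokens for the first 16 tokens
    max_tokens=64,           # must be >= min_tokens, otherwise SamplingParams raises
    stop_token_ids=[50118],  # example stop token id
)
outputs = llm.generate(["Write a short poem about the sea."], params)
print(outputs[0].outputs[0].text)
```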
@@ -10,6 +10,7 @@ from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.utils import set_random_seed
 from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
 from vllm.worker.model_runner import ModelRunner
+from vllm.utils import Counter


 class MockLogitsSampler(Sampler):
@@ -25,9 +26,8 @@ class MockLogitsSampler(Sampler):
 def _prepare_test(
     batch_size: int
 ) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsSampler, ModelRunner]:
-    vocab_size = 32000
     input_tensor = torch.rand((batch_size, 1024), dtype=torch.float16)
-    fake_logits = torch.full((batch_size, vocab_size),
+    fake_logits = torch.full((batch_size, VOCAB_SIZE),
                              1e-2,
                              dtype=input_tensor.dtype)
     sampler = MockLogitsSampler(fake_logits)
@@ -35,6 +35,7 @@ def _prepare_test(
     return input_tensor, fake_logits, sampler, model_runner


+VOCAB_SIZE = 32000
 RANDOM_SEEDS = list(range(128))
 CUDA_DEVICES = [
     f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
@@ -184,6 +185,225 @@ def test_sampler_all_beam(seed: int, device: str):
     del model_runner


+@pytest.mark.parametrize("seed", RANDOM_SEEDS)
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+def test_sampler_min_tokens_penalty(seed: int, device: str):
+    seq_id_counter = Counter(start=random.randint(0, 100))
+    set_random_seed(seed)
+    torch.set_default_device(device)
+
+    def create_sampling_params(min_tokens,
+                               eos_token_id=0,
+                               stop_token_ids=None):
+        sampling_params = SamplingParams(
+            min_tokens=min_tokens,
+            max_tokens=9999,  # keep higher than max of min_tokens
+            stop_token_ids=stop_token_ids,
+        )
+        sampling_params.eos_token_id = eos_token_id
+        return sampling_params
+
+    def create_sequence_data(num_input=3, num_generated=0):
+        seq_data = SequenceData(
+            random.choices(range(0, VOCAB_SIZE), k=num_input))
+        if num_generated > 0:
+            seq_data.output_token_ids = random.choices(range(0, VOCAB_SIZE),
+                                                       k=num_generated)
+        return seq_data
+
+    def generate_test_case():
+        # generate multiple seq groups but limit total batch size
+        batch_size = random.randint(1, 128)
+
+        expected_penalization = []
+        sequence_metadata_list = []
+        while batch_size > 0:
+            # 20% chance to generate prompt seq group with single sequence
+            is_prompt = random.random() < 0.2
+            num_seqs = 1 if is_prompt else random.randint(1, batch_size)
+
+            eos_token_id = random.randint(0, VOCAB_SIZE - 1)
+            min_tokens = random.randint(0, 50)
+            num_stop_tokens = random.randint(0, 8)
+            if num_stop_tokens > 0:
+                stop_token_ids = random.choices(range(0, VOCAB_SIZE - 1),
+                                                k=num_stop_tokens)
+            else:
+                stop_token_ids = None
+
+            sampling_params = create_sampling_params(
+                min_tokens=min_tokens,
+                eos_token_id=eos_token_id,
+                stop_token_ids=stop_token_ids)
+
+            seq_data = {}
+            seq_group_penalization = []
+            for _ in range(num_seqs):
+                num_input = random.randint(1, 100)
+                num_generated = random.randint(1, 100) if not is_prompt else 0
+                seq_data[next(seq_id_counter)] = create_sequence_data(
+                    num_input=num_input, num_generated=num_generated)
+                seq_group_penalization.append(num_generated < min_tokens)
+
+            expected_penalization.extend(seq_group_penalization)
+            sequence_metadata_list.append(
+                SequenceGroupMetadata(
+                    request_id=f"test_{batch_size}",
+                    is_prompt=is_prompt,
+                    seq_data=seq_data,
+                    sampling_params=sampling_params,
+                    block_tables={},
+                ))
+            batch_size -= num_seqs
+
+        return {
+            "expected_penalization": expected_penalization,
+            "seq_group_metadata_list": sequence_metadata_list,
+        }
+
+    # define some explicit test cases for edge case behavior
+    prompt_without_penalization = {
+        "expected_penalization": [False],
+        "seq_group_metadata_list": [
+            SequenceGroupMetadata(
+                request_id="test_1",
+                is_prompt=True,
+                seq_data={
+                    next(seq_id_counter): create_sequence_data(),
+                },
+                sampling_params=create_sampling_params(0),
+                block_tables={},
+            ),
+        ]
+    }
+
+    prompt_with_penalization = {
+        "expected_penalization": [True],
+        "seq_group_metadata_list": [
+            SequenceGroupMetadata(
+                request_id="test_1",
+                is_prompt=True,
+                seq_data={
+                    next(seq_id_counter): create_sequence_data(),
+                },
+                sampling_params=create_sampling_params(1),
+                block_tables={},
+            ),
+        ]
+    }
+
+    stop_penalizing_after_min_tokens = {
+        "expected_penalization": [False],
+        "seq_group_metadata_list": [
+            SequenceGroupMetadata(
+                request_id="test_1",
+                is_prompt=False,
+                seq_data={
+                    next(seq_id_counter):
+                    create_sequence_data(num_generated=1),
+                },
+                sampling_params=create_sampling_params(1),
+                block_tables={},
+            )
+        ]
+    }
+
+    stop_token_ids = [42, 99, 42, 0]  # intentional duplication
+    simple_combination = {
+        "expected_penalization": [True, False, False],
+        "seq_group_metadata_list": [
+            SequenceGroupMetadata(
+                request_id="test_1",
+                is_prompt=False,
+                seq_data={
+                    next(seq_id_counter):
+                    create_sequence_data(num_generated=1),
+                    next(seq_id_counter):
+                    create_sequence_data(num_generated=100),
+                },
+                sampling_params=create_sampling_params(
+                    2, stop_token_ids=stop_token_ids),
+                block_tables={},
+            ),
+            SequenceGroupMetadata(
+                request_id="test_2",
+                is_prompt=True,
+                seq_data={
+                    next(seq_id_counter): create_sequence_data(),
+                },
+                sampling_params=create_sampling_params(
+                    0, stop_token_ids=stop_token_ids),
+                block_tables={},
+            )
+        ]
+    }
+
+    if seed == 0:
+        test_cases = [
+            prompt_without_penalization,
+            prompt_with_penalization,
+            stop_penalizing_after_min_tokens,
+            simple_combination,
+        ]
+    else:
+        test_cases = [generate_test_case()]
+
+    def run_test_case(*,
+                      expected_penalization=None,
+                      seq_group_metadata_list=None):
+        assert expected_penalization, "Invalid test case"
+        assert seq_group_metadata_list, "Invalid test case"
+
+        batch_size = 0
+        prompt_lens = []
+        sampling_params_per_seq = []
+        for sgm in seq_group_metadata_list:
+            num_seqs = len(sgm.seq_data)
+            batch_size += num_seqs
+            sampling_params = sgm.sampling_params
+            for seq_id in sgm.seq_data:
+                prompt_lens.append(sgm.seq_data[seq_id].get_prompt_len())
+                sampling_params_per_seq.append(sampling_params)
+
+        _, fake_logits, sampler, model_runner = _prepare_test(batch_size)
+        sampling_metadata = model_runner._prepare_sample(
+            seq_group_metadata_list,
+            prompt_lens=prompt_lens,
+            subquery_lens=prompt_lens)
+        # the logits tensor is modified in-place by the sampler
+        _ = sampler(logits=fake_logits, sampling_metadata=sampling_metadata)
+
+        for logits_idx, (should_penalize, sampling_params) in enumerate(
+                zip(expected_penalization, sampling_params_per_seq)):
+
+            tokens_to_check = [sampling_params.eos_token_id]
+            if sampling_params.stop_token_ids:
+                tokens_to_check.extend(sampling_params.stop_token_ids)
+            tokens_to_check = set(tokens_to_check)
+
+            if should_penalize:
+                for token_id in tokens_to_check:
+                    assert fake_logits[logits_idx, token_id] == -float(
+                        'inf'
+                    ), f"Expected token {token_id} for logits row {logits_idx}"
+                    " to be penalized"
+                # no other tokens should be set to -inf
+                assert torch.count_nonzero(
+                    fake_logits[logits_idx, :] == -float('inf')) == len(
+                        tokens_to_check
+                    ), f"Expected only {len(tokens_to_check)} to be penalized"
+            else:
+                # no tokens should be set to -inf
+                assert torch.count_nonzero(
+                    fake_logits[logits_idx, :] ==
+                    -float('inf')) == 0, "No tokens should have been penalized"
+
+        del model_runner
+
+    for test_case in test_cases:
+        run_test_case(**test_case)
+
+
 @pytest.mark.parametrize("seed", RANDOM_SEEDS)
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 def test_sampler_mixed(seed: int, device: str):
@@ -282,6 +282,9 @@ class LLMEngine:
         # Defensive copy of SamplingParams, which are used by the sampler,
         # this doesn't deep-copy LogitsProcessor objects
         sampling_params = sampling_params.clone()
+        # inject the eos token id into the sampling_params to support min_tokens
+        # processing
+        sampling_params.eos_token_id = seq.eos_token_id

         # Create the sequence group.
         seq_group = SequenceGroup(request_id, [seq], sampling_params,
@@ -713,6 +716,21 @@ class LLMEngine:
     def _check_stop(self, seq: Sequence,
                     sampling_params: SamplingParams) -> None:
         """Stop the finished sequences."""
+        # Check if the sequence has reached max_model_len.
+        if seq.get_len() > self.scheduler_config.max_model_len:
+            seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
+            return
+
+        # Check if the sequence has reached max_tokens.
+        if seq.get_output_len() == sampling_params.max_tokens:
+            seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
+            return
+
+        # Check if the minimum number of tokens has been generated yet;
+        # skip the stop string/token checks if not
+        if seq.get_output_len() < sampling_params.min_tokens:
+            return
+
         for stop_str in sampling_params.stop:
             if seq.output_text.endswith(stop_str):
                 self._finalize_sequence(seq, sampling_params, stop_str)
@@ -725,16 +743,6 @@ class LLMEngine:
             seq.status = SequenceStatus.FINISHED_STOPPED
             return

-        # Check if the sequence has reached max_model_len.
-        if seq.get_len() > self.scheduler_config.max_model_len:
-            seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
-            return
-
-        # Check if the sequence has reached max_tokens.
-        if seq.get_output_len() == sampling_params.max_tokens:
-            seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
-            return
-
         # Check if the sequence has generated the EOS token.
         if ((not sampling_params.ignore_eos)
                 and seq.get_last_token_id() == seq.eos_token_id):
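The `_check_stop` hunks above move the length-cap checks ahead of a new `min_tokens` gate: a sequence can still be cut off by `max_model_len` or `max_tokens`, but stop strings, stop tokens, and EOS are not honoured until at least `min_tokens` output tokens exist. A simplified, standalone sketch of that ordering (function name, parameters, and return values are illustrative, not the engine's API):

```python
# Simplified sketch of the stop-check ordering introduced above.
def check_stop(output_len, total_len, *, max_model_len, max_tokens,
               min_tokens, last_token_id, eos_token_id, stop_token_ids):
    # Length caps apply even before min_tokens is reached.
    if total_len > max_model_len or output_len == max_tokens:
        return "length_capped"
    # Defer all stop-token / EOS handling until min_tokens is satisfied.
    if output_len < min_tokens:
        return None
    if last_token_id == eos_token_id or last_token_id in stop_token_ids:
        return "stopped"
    return None
```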
@@ -88,6 +88,7 @@ class ChatCompletionRequest(BaseModel):
     length_penalty: Optional[float] = 1.0
     early_stopping: Optional[bool] = False
     ignore_eos: Optional[bool] = False
+    min_tokens: Optional[int] = 0
     stop_token_ids: Optional[List[int]] = Field(default_factory=list)
     skip_special_tokens: Optional[bool] = True
     spaces_between_special_tokens: Optional[bool] = True
@@ -165,6 +166,7 @@ class ChatCompletionRequest(BaseModel):
             stop=self.stop,
             stop_token_ids=self.stop_token_ids,
             max_tokens=self.max_tokens,
+            min_tokens=self.min_tokens,
             logprobs=self.top_logprobs if self.logprobs else None,
             prompt_logprobs=self.top_logprobs if self.echo else None,
             best_of=self.best_of,
@@ -224,6 +226,7 @@ class CompletionRequest(BaseModel):
     early_stopping: Optional[bool] = False
     stop_token_ids: Optional[List[int]] = Field(default_factory=list)
     ignore_eos: Optional[bool] = False
+    min_tokens: Optional[int] = 0
     skip_special_tokens: Optional[bool] = True
     spaces_between_special_tokens: Optional[bool] = True
     # doc: end-completion-sampling-params
@@ -296,6 +299,7 @@ class CompletionRequest(BaseModel):
             stop_token_ids=self.stop_token_ids,
             ignore_eos=self.ignore_eos,
             max_tokens=self.max_tokens if not echo_without_generation else 1,
+            min_tokens=self.min_tokens,
             logprobs=self.logprobs,
             use_beam_search=self.use_beam_search,
             early_stopping=self.early_stopping,
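With `min_tokens` added to both `ChatCompletionRequest` and `CompletionRequest`, the field can be sent to the OpenAI-compatible server like any other sampling parameter. A hedged example request; the host, port, model name, and prompt are assumptions, not part of this diff:

```python
# Assumes a vLLM OpenAI-compatible server is already running locally.
import requests

resp = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "model": "facebook/opt-125m",  # whatever model the server was launched with
        "prompt": "Summarize the plot of Hamlet.",
        "min_tokens": 32,              # new field introduced by this commit
        "max_tokens": 128,
        "stop": ["\n\n"],              # stop strings are also deferred until min_tokens
    },
    timeout=60,
)
print(resp.json()["choices"][0]["text"])
```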
@@ -1,4 +1,5 @@
 """A layer that samples the next tokens from the model's outputs."""
+import itertools
 from typing import Dict, List, Optional, Tuple

 import torch
@@ -36,6 +37,10 @@ class Sampler(nn.Module):
         assert logits is not None
         _, vocab_size = logits.shape

+        # Apply min_tokens penalty which sets stop tokens to -inf if min_tokens
+        # have not been generated yet
+        logits = _apply_min_tokens_penalty(logits, sampling_metadata)
+
         # Prepare sampling tensors with pinned memory to avoid blocking.
         (sampling_tensors, do_penalties, do_top_p_top_k,
          do_min_p) = SamplingTensors.from_sampling_metadata(
@@ -94,6 +99,42 @@ def _get_bin_counts_and_mask(
     return bin_counts, mask


+def _apply_min_tokens_penalty(
+    logits: torch.Tensor,
+    sampling_metadata: SamplingMetadata,
+) -> torch.Tensor:
+    # list of indices in logits that will be set to -inf
+    logits_to_penalize = []
+    start_idx = 0
+    for seq_ids, sampling_params in sampling_metadata.seq_groups:
+        min_tokens = sampling_params.min_tokens
+        if min_tokens > 0:
+            seqs_to_penalize = []
+            for i, seq_id in enumerate(seq_ids):
+                seq_data = sampling_metadata.seq_data[seq_id]
+                if len(seq_data.output_token_ids) < min_tokens:
+                    seqs_to_penalize.append(i)
+
+            if seqs_to_penalize:
+                # convert to the index into logits
+                seqs_to_penalize = [start_idx + i for i in seqs_to_penalize]
+                # use set() to remove any duplicates
+                token_ids_to_penalize = set(sampling_params.stop_token_ids +
+                                            [sampling_params.eos_token_id])
+                # itertools.product pairs each seq index with every token id
+                logits_to_penalize.extend(
+                    itertools.product(seqs_to_penalize, token_ids_to_penalize))
+
+        start_idx += len(seq_ids)
+
+    if logits_to_penalize:
+        # use zip and * to group indices along each dimension
+        # eg. [ (1,2), (1,3), (5,6) ] -> ( (1,1,5), (2,3,6) )
+        logits[tuple(zip(*logits_to_penalize))] = -float("inf")
+
+    return logits
+
+
 def _apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
                      output_tokens_tensor: torch.Tensor,
                      presence_penalties: torch.Tensor,
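`_apply_min_tokens_penalty` collects `(row, token_id)` pairs across the whole batch and applies them with a single advanced-indexing assignment. A self-contained illustration of that `itertools.product` plus `zip(*...)` pattern, with toy sizes chosen for the example:

```python
import itertools

import torch

logits = torch.zeros(4, 10)   # 4 sequences in the batch, vocabulary of 10
rows_to_penalize = [0, 2]     # rows still below their min_tokens
token_ids = {7, 9}            # eos + stop token ids, deduplicated via set()

# Pair every penalized row with every banned token id:
# (0, 7), (0, 9), (2, 7), (2, 9)
pairs = list(itertools.product(rows_to_penalize, token_ids))

# zip(*pairs) regroups the pairs into one index tuple per dimension,
# e.g. [(0, 7), (0, 9), (2, 7), (2, 9)] -> ((0, 0, 2, 2), (7, 9, 7, 9))
logits[tuple(zip(*pairs))] = -float("inf")

assert torch.isinf(logits[0, 7]) and torch.isinf(logits[2, 9])
assert torch.isfinite(logits[1]).all() and torch.isfinite(logits[3]).all()
```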
@@ -79,6 +79,8 @@ class SamplingParams:
         ignore_eos: Whether to ignore the EOS token and continue generating
             tokens after the EOS token is generated.
         max_tokens: Maximum number of tokens to generate per output sequence.
+        min_tokens: Minimum number of tokens to generate per output sequence
+            before EOS or stop_token_ids can be generated
         logprobs: Number of log probabilities to return per output token.
             Note that the implementation follows the OpenAI API: The return
             result includes the log probabilities on the `logprobs` most likely
@@ -113,6 +115,7 @@ class SamplingParams:
         include_stop_str_in_output: bool = False,
         ignore_eos: bool = False,
         max_tokens: Optional[int] = 16,
+        min_tokens: int = 0,
         logprobs: Optional[int] = None,
         prompt_logprobs: Optional[int] = None,
         skip_special_tokens: bool = True,
@@ -144,6 +147,7 @@ class SamplingParams:
         self.stop_token_ids = list(stop_token_ids)
         self.ignore_eos = ignore_eos
         self.max_tokens = max_tokens
+        self.min_tokens = min_tokens
         self.logprobs = logprobs
         self.prompt_logprobs = prompt_logprobs
         self.skip_special_tokens = skip_special_tokens
@@ -161,6 +165,8 @@ class SamplingParams:
             self.top_k = -1
             self.min_p = 0.0
             self._verify_greedy_sampling()
+        # injected by the engine
+        self.eos_token_id = None

     def _verify_args(self) -> None:
         if self.n < 1:
@@ -191,6 +197,13 @@ class SamplingParams:
         if self.max_tokens is not None and self.max_tokens < 1:
             raise ValueError(
                 f"max_tokens must be at least 1, got {self.max_tokens}.")
+        if self.min_tokens < 0:
+            raise ValueError(f"min_tokens must be greater than or equal to 0, "
+                             f"got {self.min_tokens}.")
+        if self.max_tokens is not None and self.min_tokens > self.max_tokens:
+            raise ValueError(
+                f"min_tokens must be less than or equal to "
+                f"max_tokens={self.max_tokens}, got {self.min_tokens}.")
         if self.logprobs is not None and self.logprobs < 0:
             raise ValueError(
                 f"logprobs must be non-negative, got {self.logprobs}.")
@@ -272,6 +285,7 @@ class SamplingParams:
             f"include_stop_str_in_output={self.include_stop_str_in_output}, "
             f"ignore_eos={self.ignore_eos}, "
             f"max_tokens={self.max_tokens}, "
+            f"min_tokens={self.min_tokens}, "
             f"logprobs={self.logprobs}, "
             f"prompt_logprobs={self.prompt_logprobs}, "
             f"skip_special_tokens={self.skip_special_tokens}, "
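The new `_verify_args` checks reject inconsistent values as soon as a `SamplingParams` object is constructed. A brief illustration of the expected behaviour, with values chosen for the example:

```python
from vllm import SamplingParams

# Valid: min_tokens is non-negative and does not exceed max_tokens.
params = SamplingParams(min_tokens=8, max_tokens=32)

try:
    SamplingParams(min_tokens=64, max_tokens=32)
except ValueError as err:
    # "min_tokens must be less than or equal to max_tokens=32, got 64."
    print(err)
```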