Support per-request seed (#2514)

Nick Hill 2024-02-21 11:47:00 -08:00 committed by GitHub
parent dc903e70ac
commit 7d2dcce175
10 changed files with 296 additions and 91 deletions
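
For context before the per-file diffs: a minimal usage sketch of the feature this commit adds, assuming the offline LLM API and the facebook/opt-125m model that the new tests use (the sketch itself is not part of the diff). Two identical requests that share a seed should return identical tokens, while unseeded requests stay independent.

# Sketch only (not part of this commit): exercise the new per-request
# `seed` field through vLLM's offline LLM API.
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")
seeded = SamplingParams(temperature=1.0, max_tokens=16, seed=1234)

first = llm.generate(["The capital of France is"], seeded)
second = llm.generate(["The capital of France is"], seeded)

# Same request + same seed -> same tokens; dropping `seed` makes the two
# generations independent again.
assert first[0].outputs[0].token_ids == second[0].outputs[0].token_ids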

View File

@@ -1,10 +1,11 @@
 import random
-from typing import Tuple
+from typing import Tuple, List
 from unittest.mock import patch

 import pytest
 import torch
 from transformers import GenerationConfig, GenerationMixin
+from typing import Optional

 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.utils import set_random_seed
@@ -46,6 +47,34 @@ CUDA_DEVICES = [
 ]


+def _do_sample(
+    batch_size: int,
+    input_tensor: torch.Tensor,
+    sampler: MockLogitsSampler,
+    model_runner: ModelRunner,
+    sampling_params: SamplingParams,
+):
+    seq_group_metadata_list = []
+    prompt_lens = []
+    for i in range(batch_size):
+        seq_group_metadata_list.append(
+            SequenceGroupMetadata(
+                request_id=f"test_{i}",
+                is_prompt=True,
+                seq_data={0: SequenceData([1, 2, 3])},
+                sampling_params=sampling_params,
+                block_tables={0: [1]},
+            ))
+        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
+
+    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
+                                                     prompt_lens,
+                                                     subquery_lens=prompt_lens)
+    return sampler(embedding=None,
+                   hidden_states=input_tensor,
+                   sampling_metadata=sampling_metadata)
+
+
 @pytest.mark.parametrize("seed", RANDOM_SEEDS)
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 def test_sampler_all_greedy(seed: int, device: str):
@@ -55,25 +84,9 @@ def test_sampler_all_greedy(seed: int, device: str):
     input_tensor, fake_logits, sampler, model_runner = _prepare_test(
         batch_size)

-    seq_group_metadata_list = []
-    prompt_lens = []
-    for i in range(batch_size):
-        seq_group_metadata_list.append(
-            SequenceGroupMetadata(
-                request_id=f"test_{i}",
-                is_prompt=True,
-                seq_data={0: SequenceData([1, 2, 3])},
-                sampling_params=SamplingParams(temperature=0, ),
-                block_tables={0: [1]},
-            ))
-        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
-
-    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
-                                                     prompt_lens,
-                                                     subquery_lens=prompt_lens)
-    sampler_output = sampler(embedding=None,
-                             hidden_states=input_tensor,
-                             sampling_metadata=sampling_metadata)
+    sampling_params = SamplingParams(temperature=0)
+    sampler_output = _do_sample(batch_size, input_tensor, sampler,
+                                model_runner, sampling_params)
     expected = torch.argmax(fake_logits, dim=-1)
     for i, sequence_output in enumerate(sampler_output):
         for nth_output in sequence_output.samples:
@@ -94,28 +107,13 @@ def test_sampler_all_random(seed: int, device: str):
     for i in range(batch_size):
         fake_logits[i, i] = 1e2

-    seq_group_metadata_list = []
-    prompt_lens = []
-    for i in range(batch_size):
-        seq_group_metadata_list.append(
-            SequenceGroupMetadata(
-                request_id=f"test_{i}",
-                is_prompt=True,
-                seq_data={0: SequenceData([1, 2, 3])},
-                sampling_params=SamplingParams(
-                    temperature=1.0,
-                    n=random.randint(1, 10),
-                ),
-                block_tables={0: [1]},
-            ))
-        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
-
-    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
-                                                     prompt_lens,
-                                                     subquery_lens=prompt_lens)
-    sampler_output = sampler(embedding=None,
-                             hidden_states=input_tensor,
-                             sampling_metadata=sampling_metadata)
+    sampling_params = SamplingParams(
+        temperature=1.0,
+        n=random.randint(1, 10),
+    )
+    sampler_output = _do_sample(batch_size, input_tensor, sampler,
+                                model_runner, sampling_params)

     for i, sequence_output in enumerate(sampler_output):
         for nth_output in sequence_output.samples:
             assert nth_output.output_token == i
@@ -123,6 +121,58 @@ def test_sampler_all_random(seed: int, device: str):
     del model_runner


+@pytest.mark.parametrize("seed", RANDOM_SEEDS)
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+def test_sampler_all_random_seed(seed: int, device: str):
+    set_random_seed(seed)
+    torch.set_default_device(device)
+    batch_size = random.randint(1, 256)
+    input_tensor, fake_logits, sampler, model_runner = _prepare_test(
+        batch_size)
+
+    for i in range(batch_size):
+        fake_logits[i, i] = 1e2
+
+    sampling_params = SamplingParams(
+        temperature=1.0,
+        n=random.randint(1, 10),
+        seed=random.randint(0, 10000),
+    )
+    sampler_output = _do_sample(batch_size, input_tensor, sampler,
+                                model_runner, sampling_params)
+
+    for i, sequence_output in enumerate(sampler_output):
+        for nth_output in sequence_output.samples:
+            assert nth_output.output_token == i
+
+    del model_runner
+
+
+@pytest.mark.parametrize("seed", RANDOM_SEEDS)
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+def test_sampler_all_random_seed_deterministic(seed: int, device: str):
+    set_random_seed(seed)
+    torch.set_default_device(device)
+    batch_size = random.randint(1, 256)
+    input_tensor, fake_logits, sampler, model_runner = _prepare_test(
+        batch_size)
+
+    sampling_params = SamplingParams(
+        temperature=1.0,
+        n=random.randint(1, 10),
+        seed=random.randint(0, 10000),
+    )
+    first_sampler_output = _do_sample(batch_size, input_tensor, sampler,
+                                      model_runner, sampling_params)
+
+    second_sampler_output = _do_sample(batch_size, input_tensor, sampler,
+                                       model_runner, sampling_params)
+
+    assert first_sampler_output == second_sampler_output
+
+    del model_runner
+
+
 @pytest.mark.parametrize("seed", RANDOM_SEEDS)
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 def test_sampler_all_beam(seed: int, device: str):
@@ -131,29 +181,13 @@ def test_sampler_all_beam(seed: int, device: str):
     batch_size = random.randint(1, 256)
     input_tensor, _, sampler, model_runner = _prepare_test(batch_size)

-    seq_group_metadata_list = []
-    prompt_lens = []
-    for i in range(batch_size):
-        seq_group_metadata_list.append(
-            SequenceGroupMetadata(
-                request_id=f"test_{i}",
-                is_prompt=True,
-                seq_data={0: SequenceData([1, 2, 3])},
-                sampling_params=SamplingParams(
-                    temperature=0,
-                    best_of=2,
-                    use_beam_search=True,
-                ),
-                block_tables={0: [1]},
-            ))
-        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
-
-    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
-                                                     prompt_lens,
-                                                     subquery_lens=prompt_lens)
-    sampler(embedding=None,
-            hidden_states=input_tensor,
-            sampling_metadata=sampling_metadata)
+    sampling_params = SamplingParams(
+        temperature=0,
+        best_of=2,
+        use_beam_search=True,
+    )
+    _do_sample(batch_size, input_tensor, sampler, model_runner,
+               sampling_params)
     # no assertion here as I am not sure how to determine whether
     # the outputs are expected - in other words, this just tests
     # whether there are no exceptions in the sampler
@@ -171,14 +205,15 @@ def test_sampler_mixed(seed: int, device: str):
         batch_size)

     seq_group_metadata_list = []
-    expected_tokens = []
+    expected_tokens: List[Optional[List[int]]] = []
     prompt_lens = []
     for i in range(batch_size):
-        n = 1
-        sampling_type = random.randint(0, 2)
+        expected: Optional[List[int]] = None
+        sampling_type = random.randint(0, 3)
         if sampling_type == 0:
             sampling_params = SamplingParams(temperature=0)
-        elif sampling_type == 1:
+            expected = [torch.argmax(fake_logits[i], dim=-1).item()]
+        elif sampling_type in (1, 2):
             n = random.randint(1, 10)
             sampling_params = SamplingParams(
                 temperature=random.random() + 0.1,
@@ -187,13 +222,17 @@ def test_sampler_mixed(seed: int, device: str):
                 n=n,
                 presence_penalty=random.randint(0, 1),
             )
+            if sampling_type == 2:
+                sampling_params.seed = random.randint(0, 10000)
+            else:
+                for idx in range(n):
+                    fake_logits[i, i + idx] = 1e2
+                expected = list(range(i, i + n))
         else:
             sampling_params = SamplingParams(temperature=0,
                                              use_beam_search=True,
                                              best_of=2)
-        for idx in range(n):
-            fake_logits[i, i + idx] = 1e2
-            expected_tokens.append(i + idx)
+        expected_tokens.append(expected)
         seq_group_metadata_list.append(
             SequenceGroupMetadata(
                 request_id=f"test_{i}",
@@ -204,17 +243,50 @@ def test_sampler_mixed(seed: int, device: str):
             ))
         prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

-    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
-                                                     prompt_lens,
-                                                     subquery_lens=prompt_lens)
-    sampler_output = sampler(embedding=None,
-                             hidden_states=input_tensor,
-                             sampling_metadata=sampling_metadata)
-    for i, sequence_output in enumerate(sampler_output):
-        if seq_group_metadata_list[i].sampling_params.use_beam_search:
-            continue
-        for nth_output in sequence_output.samples:
-            assert nth_output.output_token in expected_tokens
+    def test_sampling(model_runner: ModelRunner):
+        sampling_metadata = model_runner._prepare_sample(
+            seq_group_metadata_list, prompt_lens, subquery_lens=prompt_lens)
+        sampler_output = sampler(embedding=None,
+                                 hidden_states=input_tensor,
+                                 sampling_metadata=sampling_metadata)
+
+        for i, (sequence_output, metadata) in enumerate(
+                zip(sampler_output, seq_group_metadata_list)):
+            if metadata.sampling_params.use_beam_search:
+                continue
+
+            if metadata.sampling_params.seed is not None \
+                    and expected_tokens[i] is None:
+                # Record seeded random result to compare with results of second invocation
+                expected_tokens[i] = [
+                    nth_output.output_token
+                    for nth_output in sequence_output.samples
+                ]
+                continue
+
+            for n, nth_output in enumerate(sequence_output.samples):
+                if metadata.sampling_params.temperature == 0 or metadata.sampling_params.seed is not None:
+                    # Ensure exact matches for greedy or random with seed
+                    assert nth_output.output_token == expected_tokens[i][n]
+                else:
+                    # For non-seeded random check that one of the high-logit tokens were chosen
+                    assert nth_output.output_token in expected_tokens[i]
+
+    # Test batch
+    test_sampling(model_runner)
+
+    # Shuffle the batch and resample
+    target_index = list(range(batch_size))
+    for list_to_shuffle in (target_index, seq_group_metadata_list,
+                            expected_tokens, prompt_lens):
+        random.Random(seed).shuffle(list_to_shuffle)
+    target_index = torch.tensor(target_index)
+    input_tensor.data = input_tensor.index_select(0, target_index)
+    fake_logits.data = fake_logits.index_select(0, target_index)
+
+    # This time, results of seeded random samples will be compared with the corresponding
+    # sample in the pre-shuffled batch
+    test_sampling(model_runner)

     del model_runner

View File

@@ -0,0 +1,82 @@
+"""Verify that seeded random sampling is deterministic.
+
+Run `pytest tests/samplers/test_seeded_generate.py --forked`.
+"""
+import copy
+import random
+from itertools import combinations
+
+import pytest
+
+from vllm.model_executor.utils import set_random_seed
+from vllm import SamplingParams
+
+MODEL = "facebook/opt-125m"
+RANDOM_SEEDS = list(range(5))
+
+
+@pytest.fixture
+def vllm_model(vllm_runner):
+    vllm_model = vllm_runner(MODEL, dtype="half")
+    yield vllm_model
+    del vllm_model
+
+
+@pytest.mark.parametrize("seed", RANDOM_SEEDS)
+def test_random_sample_with_seed(
+    vllm_model,
+    example_prompts,
+    seed: int,
+) -> None:
+    set_random_seed(seed)
+
+    sampling_params = SamplingParams(
+        # Parameters to ensure sufficient randomness
+        temperature=2.0,
+        top_p=min(random.random() + 0.3, 1),
+        top_k=random.randint(5, 20),
+        n=random.randint(1, 10),
+        presence_penalty=random.randint(0, 1),
+        max_tokens=8,
+        ignore_eos=True,
+    )
+
+    sampling_params_seed_1 = copy.deepcopy(sampling_params)
+    sampling_params_seed_1.seed = 100
+    sampling_params_seed_2 = copy.deepcopy(sampling_params)
+    sampling_params_seed_2.seed = 200
+
+    llm = vllm_model.model
+
+    for prompt in example_prompts:
+        for params in (
+                sampling_params,
+                sampling_params_seed_1,
+                sampling_params_seed_2,
+                sampling_params,
+                sampling_params_seed_1,
+                sampling_params_seed_2,
+        ):
+            llm._add_request(
+                prompt=prompt,
+                prompt_token_ids=None,
+                sampling_params=params,
+            )
+
+    results = llm._run_engine(use_tqdm=False)
+    all_outputs = [[out.token_ids for out in output.outputs]
+                   for output in results]
+
+    for i in range(0, len(example_prompts), 6):
+        outputs = all_outputs[i:i + 6]
+
+        # verify all non-seeded requests differ
+        for output_a, output_b in combinations(
+            (outputs[0], outputs[1], outputs[2], outputs[3]),
+            2,
+        ):
+            assert output_a != output_b
+
+        # verify requests with the same seed match
+        assert outputs[1] == outputs[4]
+        assert outputs[2] == outputs[5]

View File

@@ -387,6 +387,7 @@ class Scheduler:
                 block_tables=block_tables,
                 lora_request=seq_group.lora_request,
                 prefix=seq_group.prefix,
+                state=seq_group.state,
             )
             seq_group_metadata_list.append(seq_group_metadata)
         return seq_group_metadata_list, scheduler_outputs

View File

@@ -173,7 +173,6 @@ class EngineArgs:
                             default=EngineArgs.block_size,
                             choices=[8, 16, 32],
                             help='token block size')
-        # TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
        parser.add_argument('--seed',
                            type=int,
                            default=EngineArgs.seed,

View File

@@ -60,6 +60,7 @@ class ChatCompletionRequest(BaseModel):
     top_p: Optional[float] = 1.0
     n: Optional[int] = 1
     max_tokens: Optional[int] = None
+    seed: Optional[int] = None
     stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
     stream: Optional[bool] = False
     presence_penalty: Optional[float] = 0.0
@@ -90,6 +91,7 @@ class ChatCompletionRequest(BaseModel):
             temperature=self.temperature,
             top_p=self.top_p,
             min_p=self.min_p,
+            seed=self.seed,
             stop=self.stop,
             stop_token_ids=self.stop_token_ids,
             max_tokens=self.max_tokens,
@@ -117,6 +119,7 @@ class CompletionRequest(BaseModel):
     logprobs: Optional[int] = None
     echo: Optional[bool] = False
     stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
+    seed: Optional[int] = None
     presence_penalty: Optional[float] = 0.0
     frequency_penalty: Optional[float] = 0.0
     best_of: Optional[int] = None
@@ -147,6 +150,7 @@ class CompletionRequest(BaseModel):
             top_p=self.top_p,
             top_k=self.top_k,
             min_p=self.min_p,
+            seed=self.seed,
             stop=self.stop,
             stop_token_ids=self.stop_token_ids,
             ignore_eos=self.ignore_eos,
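
The new seed field on CompletionRequest and ChatCompletionRequest is passed straight through to SamplingParams, so it can also be set per request against the OpenAI-compatible server. A hedged sketch, assuming a server is already running at localhost:8000 with facebook/opt-125m loaded (not part of this diff):

# Sketch only: send the same seeded completion request twice and expect
# identical text back; an unseeded request would generally differ.
import requests

url = "http://localhost:8000/v1/completions"
payload = {
    "model": "facebook/opt-125m",
    "prompt": "The capital of France is",
    "temperature": 1.0,
    "max_tokens": 16,
    "seed": 1234,  # per-request seed added by this change
}

first = requests.post(url, json=payload).json()
second = requests.post(url, json=payload).json()
assert first["choices"][0]["text"] == second["choices"][0]["text"]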

View File

@@ -342,7 +342,9 @@ def _beam_search_sample(
 def _multinomial(
     probs: torch.Tensor,
     num_samples: int,
-):
+    seq_groups: Optional[List[Tuple[List[int], SamplingParams]]] = None,
+    generators: Optional[List[torch.Generator]] = None,
+) -> torch.Tensor:
     if num_samples > 1:
         # This is equivalent to torch.repeat_interleaved (which also
         # forces a GPU<->CPU sync).
@@ -352,7 +354,15 @@ def _multinomial(
         probs = probs[:, None, :].expand(probs.shape[0], num_samples,
                                          probs.shape[1]).contiguous().view(
                                              -1, probs.shape[1])
-    q = torch.empty_like(probs).exponential_(1)
+    q = torch.empty_like(probs)
+    if seq_groups is None:
+        q.exponential_()
+    else:
+        sample_idx = 0
+        for (seq_ids, _), generator in zip(seq_groups, generators):
+            next_sample_idx = sample_idx + len(seq_ids) * num_samples
+            q[sample_idx:next_sample_idx].exponential_(generator=generator)
+            sample_idx = next_sample_idx
     return probs.div_(q).argmax(dim=1).view(-1, num_samples)

@@ -370,6 +380,7 @@ def _sample(
     sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
     sample_metadata = {}
+    multinomial_samples = {}

     # Counterintiutively, having two loops here is actually faster.
     # The first loop can run without waiting on GPU<->CPU sync.
@@ -385,14 +396,18 @@ def _sample(
                                                 is_prompts, sample_indices)
         if sampling_type == SamplingType.GREEDY:
             greedy_samples = torch.argmax(logprobs[sample_indices], dim=-1)
-        elif sampling_type == SamplingType.RANDOM:
+        elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
             max_best_of = 1
             for seq_group, is_prompt in zip(seq_groups, is_prompts):
                 if is_prompt:
                     _, sampling_params = seq_group
                     max_best_of = max(max_best_of, sampling_params.best_of)
-            multinomial_samples = _multinomial(probs[sample_indices],
-                                               max_best_of)
+            seeded_args = {} if sampling_type == SamplingType.RANDOM else {
+                "seq_groups": seq_groups,
+                "generators": sampling_metadata.generators,
+            }
+            multinomial_samples[sampling_type] = _multinomial(
+                probs[sample_indices], max_best_of, **seeded_args)
         elif sampling_type == SamplingType.BEAM:
             beam_search_logprobs = logprobs[sample_indices]
         else:
@@ -407,9 +422,9 @@ def _sample(
             sampling_type]
         if sampling_type == SamplingType.GREEDY:
             sample_results = _greedy_sample(seq_groups, greedy_samples)
-        elif sampling_type == SamplingType.RANDOM:
+        elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
             sample_results = _random_sample(seq_groups, is_prompts,
-                                            multinomial_samples)
+                                            multinomial_samples[sampling_type])
         elif sampling_type == SamplingType.BEAM:
             sample_results = _beam_search_sample(seq_groups, is_prompts,
                                                  sampling_metadata.seq_data,
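
A note on the sampling trick that the seeded path above reuses: dividing the probabilities by i.i.d. Exp(1) noise and taking the argmax draws from the categorical distribution (the exponential-clocks form of the Gumbel-max trick), so it is enough to seed the noise through a per-request torch.Generator; the global RNG is never touched. A standalone sketch in plain PyTorch (not vLLM code):

import torch


def seeded_multinomial(probs: torch.Tensor, num_samples: int,
                       generator: torch.Generator) -> torch.Tensor:
    """Draw num_samples tokens per row of probs using only `generator`."""
    # Give every sample its own row so each draw gets independent noise.
    expanded = probs[:, None, :].expand(probs.shape[0], num_samples,
                                        probs.shape[1]).reshape(
                                            -1, probs.shape[1])
    q = torch.empty_like(expanded).exponential_(generator=generator)
    return expanded.div(q).argmax(dim=1).view(-1, num_samples)


probs = torch.softmax(torch.randn(4, 32), dim=-1)
gen_a = torch.Generator().manual_seed(1234)
gen_b = torch.Generator().manual_seed(1234)

# Same seed -> same samples, independent of anything else in the process.
assert torch.equal(seeded_multinomial(probs, 3, gen_a),
                   seeded_multinomial(probs, 3, gen_b))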

View File

@@ -19,6 +19,7 @@ class SamplingMetadata:
         prompt_lens: Lengths of prompts.
         selected_token_indices: Token indices selected for sampling.
         categorized_sample_indices: SamplingType -> token indices to sample.
+        generators: List of torch.Generators to use for seeded sampling
         perform_sampling: Whether to perform sampling. This option is used to
             make the sampling only happens in the driver worker, and disable
             sampling in other worker processes.
@@ -31,6 +32,7 @@ class SamplingMetadata:
         prompt_lens: Optional[List[int]],
         selected_token_indices: torch.Tensor,
         categorized_sample_indices: Optional[Dict[SamplingType, torch.Tensor]],
+        generators: Optional[List[torch.Generator]] = None,
         perform_sampling: bool = True,
     ) -> None:
         self.seq_groups = seq_groups
@@ -38,6 +40,7 @@ class SamplingMetadata:
         self.prompt_lens = prompt_lens
         self.selected_token_indices = selected_token_indices
         self.categorized_sample_indices = categorized_sample_indices
+        self.generators = generators
         self.perform_sampling = perform_sampling
         self.num_prompts = len(prompt_lens) if prompt_lens is not None else 0

View File

@@ -11,7 +11,8 @@ _SAMPLING_EPS = 1e-5
 class SamplingType(IntEnum):
     GREEDY = 0
     RANDOM = 1
-    BEAM = 2
+    RANDOM_SEED = 2
+    BEAM = 3


 LogitsProcessor = Callable[[List[int], torch.Tensor], torch.Tensor]
@@ -56,6 +57,7 @@ class SamplingParams:
         min_p: Float that represents the minimum probability for a token to be
             considered, relative to the probability of the most likely token.
             Must be in [0, 1]. Set to 0 to disable this.
+        seed: Random seed to use for the generation.
         use_beam_search: Whether to use beam search instead of sampling.
         length_penalty: Float that penalizes sequences based on their length.
             Used in beam search.
@@ -101,6 +103,7 @@ class SamplingParams:
         top_p: float = 1.0,
         top_k: int = -1,
         min_p: float = 0.0,
+        seed: Optional[int] = None,
         use_beam_search: bool = False,
         length_penalty: float = 1.0,
         early_stopping: Union[bool, str] = False,
@@ -124,6 +127,7 @@ class SamplingParams:
         self.top_p = top_p
         self.top_k = top_k
         self.min_p = min_p
+        self.seed = seed
         self.use_beam_search = use_beam_search
         self.length_penalty = length_penalty
         self.early_stopping = early_stopping
@@ -229,6 +233,8 @@ class SamplingParams:
             return SamplingType.BEAM
         if self.temperature < _SAMPLING_EPS:
             return SamplingType.GREEDY
+        if self.seed is not None:
+            return SamplingType.RANDOM_SEED
         return SamplingType.RANDOM

     def __repr__(self) -> str:
@@ -242,6 +248,7 @@ class SamplingParams:
                 f"top_p={self.top_p}, "
                 f"top_k={self.top_k}, "
                 f"min_p={self.min_p}, "
+                f"seed={self.seed}, "
                 f"use_beam_search={self.use_beam_search}, "
                 f"length_penalty={self.length_penalty}, "
                 f"early_stopping={self.early_stopping}, "

View File

@@ -248,6 +248,14 @@ class Sequence:
                 f"num_blocks={len(self.logical_token_blocks)})")


+@dataclass
+class SequenceGroupState:
+    """Mutable state tied to a specific sequence group"""
+
+    # torch.Generator used in seeded sampling
+    generator: Optional = None
+
+
 class SequenceGroup:
     """A group of sequences that are generated from the same prompt.
@@ -280,6 +288,7 @@ class SequenceGroup:
         self.lora_request = lora_request
         self.prefix: Optional[Prefix] = prefix
         self.prompt_logprobs: Optional[PromptLogprobs] = None
+        self.state = SequenceGroupState()

     @property
     def prompt(self) -> str:
@@ -397,6 +406,7 @@ class SequenceGroupMetadata:
         sampling_params: The sampling parameters used to generate the outputs.
         block_tables: The block tables. (Seq id -> list of physical block
             numbers)
+        state: Internal state tied to this sequence group.
         lora_request: LoRA request.
         prefix: The prefix of the prompt of the sequence group.
     """
@@ -410,6 +420,7 @@ class SequenceGroupMetadata:
         block_tables: Dict[int, List[int]],
         lora_request: Optional[LoRARequest] = None,
         prefix: Optional[Prefix] = None,
+        state: Optional[SequenceGroupState] = None,
     ) -> None:
         self.request_id = request_id
         self.is_prompt = is_prompt
@@ -418,6 +429,7 @@ class SequenceGroupMetadata:
         self.block_tables = block_tables
         self.lora_request = lora_request
         self.prefix = prefix
+        self.state = SequenceGroupState() if state is None else state

     @property
     def lora_int_id(self) -> int:

View File

@@ -389,6 +389,7 @@ class ModelRunner:
     ) -> SamplingMetadata:
         seq_groups: List[Tuple[List[int], SamplingParams]] = []
         selected_token_indices: List[int] = []
+        generators: List[torch.Generator] = []
         selected_token_start_idx = 0
         categorized_sample_indices = {t: [] for t in SamplingType}
         categorized_sample_indices_start_idx = 0
@@ -419,6 +420,10 @@ class ModelRunner:
                 selected_token_indices.append(selected_token_start_idx +
                                               subquery_len - 1)
                 selected_token_start_idx += max_subquery_len
+
+                if sampling_params.seed is not None:
+                    seq_group_metadata.state.generator = torch.Generator(
+                        device="cuda").manual_seed(sampling_params.seed)
             else:
                 num_seqs = len(seq_ids)
                 selected_token_indices.extend(
@@ -432,6 +437,9 @@ class ModelRunner:
                         categorized_sample_indices_start_idx + num_seqs))
                 categorized_sample_indices_start_idx += num_seqs

+            if sampling_params.seed is not None:
+                generators.append(seq_group_metadata.state.generator)
+
         selected_token_indices = _async_h2d(selected_token_indices,
                                             dtype=torch.long,
                                             target_device=self.device,
@@ -454,6 +462,7 @@ class ModelRunner:
             prompt_lens=prompt_lens,
             selected_token_indices=selected_token_indices,
             categorized_sample_indices=categorized_sample_indices,
+            generators=generators,
         )
         return sampling_metadata
@@ -536,6 +545,7 @@ class ModelRunner:
                 prompt_lens=None,
                 selected_token_indices=metadata_dict["selected_token_indices"],
                 categorized_sample_indices=None,
+                generators=None,
                 perform_sampling=False,
             )
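
Finally, on why the generator is stored on SequenceGroupState and forwarded by the scheduler instead of being rebuilt from the seed at each step: a torch.Generator advances as it is consumed, so carrying the same object across decode steps keeps the whole generation reproducible without restarting the random stream every iteration. A small sketch of that behaviour in plain PyTorch (not vLLM code):

import torch

weights = torch.ones(16)

# One generator carried across steps: the random stream advances,
# so successive steps draw fresh randomness.
gen = torch.Generator().manual_seed(1234)
step1 = torch.multinomial(weights, 1, generator=gen)
step2 = torch.multinomial(weights, 1, generator=gen)

# Re-running the whole sequence with a freshly seeded generator
# reproduces the exact same stream of draws.
gen_replay = torch.Generator().manual_seed(1234)
assert torch.equal(step1, torch.multinomial(weights, 1, generator=gen_replay))
assert torch.equal(step2, torch.multinomial(weights, 1, generator=gen_replay))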