Support per-request seed (#2514)

This commit is contained in:
Nick Hill 2024-02-21 11:47:00 -08:00 committed by GitHub
parent dc903e70ac
commit 7d2dcce175
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 296 additions and 91 deletions

View File

@ -1,10 +1,11 @@
import random
from typing import Tuple
from typing import Tuple, List
from unittest.mock import patch
import pytest
import torch
from transformers import GenerationConfig, GenerationMixin
from typing import Optional
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.utils import set_random_seed
@ -46,6 +47,34 @@ CUDA_DEVICES = [
]
def _do_sample(
    batch_size: int,
    input_tensor: torch.Tensor,
    sampler: MockLogitsSampler,
    model_runner: ModelRunner,
    sampling_params: SamplingParams,
):
    """Run ``sampler`` once over a batch of identical prompt requests.

    Builds ``batch_size`` prompt-phase sequence groups that all share the
    same ``sampling_params`` object and the fixed prompt ``[1, 2, 3]``,
    asks ``model_runner`` to prepare the sampling metadata for them, and
    returns whatever ``sampler`` produces for ``input_tensor``.
    """
    seq_group_metadata_list = [
        SequenceGroupMetadata(
            request_id=f"test_{idx}",
            is_prompt=True,
            seq_data={0: SequenceData([1, 2, 3])},
            sampling_params=sampling_params,
            block_tables={0: [1]},
        ) for idx in range(batch_size)
    ]
    # Each group holds exactly one sequence, stored under sequence id 0.
    prompt_lens = [
        metadata.seq_data[0].get_len() for metadata in seq_group_metadata_list
    ]

    # These are prompt runs, so the subquery lengths equal the prompt lengths.
    sampling_metadata = model_runner._prepare_sample(
        seq_group_metadata_list, prompt_lens, subquery_lens=prompt_lens)
    return sampler(
        embedding=None,
        hidden_states=input_tensor,
        sampling_metadata=sampling_metadata,
    )
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_greedy(seed: int, device: str):
@ -55,25 +84,9 @@ def test_sampler_all_greedy(seed: int, device: str):
input_tensor, fake_logits, sampler, model_runner = _prepare_test(
batch_size)
seq_group_metadata_list = []
prompt_lens = []
for i in range(batch_size):
seq_group_metadata_list.append(
SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
seq_data={0: SequenceData([1, 2, 3])},
sampling_params=SamplingParams(temperature=0, ),
block_tables={0: [1]},
))
prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
prompt_lens,
subquery_lens=prompt_lens)
sampler_output = sampler(embedding=None,
hidden_states=input_tensor,
sampling_metadata=sampling_metadata)
sampling_params = SamplingParams(temperature=0)
sampler_output = _do_sample(batch_size, input_tensor, sampler,
model_runner, sampling_params)
expected = torch.argmax(fake_logits, dim=-1)
for i, sequence_output in enumerate(sampler_output):
for nth_output in sequence_output.samples:
@ -94,28 +107,13 @@ def test_sampler_all_random(seed: int, device: str):
for i in range(batch_size):
fake_logits[i, i] = 1e2
seq_group_metadata_list = []
prompt_lens = []
for i in range(batch_size):
seq_group_metadata_list.append(
SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
seq_data={0: SequenceData([1, 2, 3])},
sampling_params=SamplingParams(
temperature=1.0,
n=random.randint(1, 10),
),
block_tables={0: [1]},
))
prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
sampling_params = SamplingParams(
temperature=1.0,
n=random.randint(1, 10),
)
sampler_output = _do_sample(batch_size, input_tensor, sampler,
model_runner, sampling_params)
sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
prompt_lens,
subquery_lens=prompt_lens)
sampler_output = sampler(embedding=None,
hidden_states=input_tensor,
sampling_metadata=sampling_metadata)
for i, sequence_output in enumerate(sampler_output):
for nth_output in sequence_output.samples:
assert nth_output.output_token == i
@ -123,6 +121,58 @@ def test_sampler_all_random(seed: int, device: str):
del model_runner
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_random_seed(seed: int, device: str):
    """Seeded random sampling must pick the token with the dominant logit."""
    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    input_tensor, fake_logits, sampler, model_runner = _prepare_test(
        batch_size)

    # Give row `row` one overwhelmingly large logit at token id `row`, so
    # the sampled token is predictable even though sampling is random.
    for row in range(batch_size):
        fake_logits[row, row] = 1e2

    sampler_output = _do_sample(
        batch_size,
        input_tensor,
        sampler,
        model_runner,
        SamplingParams(
            temperature=1.0,
            n=random.randint(1, 10),
            seed=random.randint(0, 10000),
        ),
    )

    for expected_token, sequence_output in enumerate(sampler_output):
        assert all(sample.output_token == expected_token
                   for sample in sequence_output.samples)

    del model_runner
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_random_seed_deterministic(seed: int, device: str):
    """Two sampler runs with the same per-request seed must agree exactly."""
    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    input_tensor, fake_logits, sampler, model_runner = _prepare_test(
        batch_size)

    sampling_params = SamplingParams(
        temperature=1.0,
        n=random.randint(1, 10),
        seed=random.randint(0, 10000),
    )

    # Sample the very same batch twice, one run after the other.
    sampler_outputs = [
        _do_sample(batch_size, input_tensor, sampler, model_runner,
                   sampling_params) for _ in range(2)
    ]
    first_sampler_output, second_sampler_output = sampler_outputs

    assert first_sampler_output == second_sampler_output

    del model_runner
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_beam(seed: int, device: str):
@ -131,29 +181,13 @@ def test_sampler_all_beam(seed: int, device: str):
batch_size = random.randint(1, 256)
input_tensor, _, sampler, model_runner = _prepare_test(batch_size)
seq_group_metadata_list = []
prompt_lens = []
for i in range(batch_size):
seq_group_metadata_list.append(
SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
seq_data={0: SequenceData([1, 2, 3])},
sampling_params=SamplingParams(
temperature=0,
best_of=2,
use_beam_search=True,
),
block_tables={0: [1]},
))
prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
prompt_lens,
subquery_lens=prompt_lens)
sampler(embedding=None,
hidden_states=input_tensor,
sampling_metadata=sampling_metadata)
sampling_params = SamplingParams(
temperature=0,
best_of=2,
use_beam_search=True,
)
_do_sample(batch_size, input_tensor, sampler, model_runner,
sampling_params)
# no assertion here as I am not sure how to determine whether
# the outputs are expected - in other words, this just tests
# whether there are no exceptions in the sampler
@ -171,14 +205,15 @@ def test_sampler_mixed(seed: int, device: str):
batch_size)
seq_group_metadata_list = []
expected_tokens = []
expected_tokens: List[Optional[List[int]]] = []
prompt_lens = []
for i in range(batch_size):
n = 1
sampling_type = random.randint(0, 2)
expected: Optional[List[int]] = None
sampling_type = random.randint(0, 3)
if sampling_type == 0:
sampling_params = SamplingParams(temperature=0)
elif sampling_type == 1:
expected = [torch.argmax(fake_logits[i], dim=-1).item()]
elif sampling_type in (1, 2):
n = random.randint(1, 10)
sampling_params = SamplingParams(
temperature=random.random() + 0.1,
@ -187,13 +222,17 @@ def test_sampler_mixed(seed: int, device: str):
n=n,
presence_penalty=random.randint(0, 1),
)
if sampling_type == 2:
sampling_params.seed = random.randint(0, 10000)
else:
for idx in range(n):
fake_logits[i, i + idx] = 1e2
expected = list(range(i, i + n))
else:
sampling_params = SamplingParams(temperature=0,
use_beam_search=True,
best_of=2)
for idx in range(n):
fake_logits[i, i + idx] = 1e2
expected_tokens.append(i + idx)
expected_tokens.append(expected)
seq_group_metadata_list.append(
SequenceGroupMetadata(
request_id=f"test_{i}",
@ -204,17 +243,50 @@ def test_sampler_mixed(seed: int, device: str):
))
prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
prompt_lens,
subquery_lens=prompt_lens)
sampler_output = sampler(embedding=None,
hidden_states=input_tensor,
sampling_metadata=sampling_metadata)
for i, sequence_output in enumerate(sampler_output):
if seq_group_metadata_list[i].sampling_params.use_beam_search:
continue
for nth_output in sequence_output.samples:
assert nth_output.output_token in expected_tokens
def test_sampling(model_runner: ModelRunner):
sampling_metadata = model_runner._prepare_sample(
seq_group_metadata_list, prompt_lens, subquery_lens=prompt_lens)
sampler_output = sampler(embedding=None,
hidden_states=input_tensor,
sampling_metadata=sampling_metadata)
for i, (sequence_output, metadata) in enumerate(
zip(sampler_output, seq_group_metadata_list)):
if metadata.sampling_params.use_beam_search:
continue
if metadata.sampling_params.seed is not None \
and expected_tokens[i] is None:
# Record seeded random result to compare with results of second invocation
expected_tokens[i] = [
nth_output.output_token
for nth_output in sequence_output.samples
]
continue
for n, nth_output in enumerate(sequence_output.samples):
if metadata.sampling_params.temperature == 0 or metadata.sampling_params.seed is not None:
# Ensure exact matches for greedy or random with seed
assert nth_output.output_token == expected_tokens[i][n]
else:
# For non-seeded random check that one of the high-logit tokens was chosen
assert nth_output.output_token in expected_tokens[i]
# Test batch
test_sampling(model_runner)
# Shuffle the batch and resample
target_index = list(range(batch_size))
for list_to_shuffle in (target_index, seq_group_metadata_list,
expected_tokens, prompt_lens):
random.Random(seed).shuffle(list_to_shuffle)
target_index = torch.tensor(target_index)
input_tensor.data = input_tensor.index_select(0, target_index)
fake_logits.data = fake_logits.index_select(0, target_index)
# This time, results of seeded random samples will be compared with the corresponding
# sample in the pre-shuffled batch
test_sampling(model_runner)
del model_runner

View File

@ -0,0 +1,82 @@
"""Verify that seeded random sampling is deterministic.
Run `pytest tests/samplers/test_seeded_generate.py --forked`.
"""
import copy
import random
from itertools import combinations
import pytest
from vllm.model_executor.utils import set_random_seed
from vllm import SamplingParams
MODEL = "facebook/opt-125m"
RANDOM_SEEDS = list(range(5))
@pytest.fixture
def vllm_model(vllm_runner):
    """Yield a half-precision runner for ``MODEL``; drop the reference after."""
    model = vllm_runner(MODEL, dtype="half")
    yield model
    # Teardown: release this fixture's reference to the runner.
    del model
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_random_sample_with_seed(
    vllm_model,
    example_prompts,
    seed: int,
) -> None:
    """Check that per-request seeds make random sampling reproducible.

    For every prompt six requests are submitted in a fixed order:
    unseeded, seed 100, seed 200, unseeded, seed 100, seed 200.
    Requests sharing a seed must produce identical token ids, while the
    unseeded requests and the requests with different seeds must all
    differ from one another.
    """
    set_random_seed(seed)

    sampling_params = SamplingParams(
        # Parameters to ensure sufficient randomness
        temperature=2.0,
        top_p=min(random.random() + 0.3, 1),
        top_k=random.randint(5, 20),
        n=random.randint(1, 10),
        presence_penalty=random.randint(0, 1),
        max_tokens=8,
        ignore_eos=True,
    )

    sampling_params_seed_1 = copy.deepcopy(sampling_params)
    sampling_params_seed_1.seed = 100
    sampling_params_seed_2 = copy.deepcopy(sampling_params)
    sampling_params_seed_2.seed = 200

    # Request pattern submitted for each prompt; the index positions are
    # relied upon by the assertions below.
    request_params = (
        sampling_params,
        sampling_params_seed_1,
        sampling_params_seed_2,
        sampling_params,
        sampling_params_seed_1,
        sampling_params_seed_2,
    )
    requests_per_prompt = len(request_params)

    llm = vllm_model.model

    for prompt in example_prompts:
        for params in request_params:
            llm._add_request(
                prompt=prompt,
                prompt_token_ids=None,
                sampling_params=params,
            )

    # NOTE(review): the grouping below assumes the engine returns results
    # in submission order - confirm against `_run_engine`.
    results = llm._run_engine(use_tqdm=False)
    all_outputs = [[out.token_ids for out in output.outputs]
                   for output in results]

    # Walk over *all* results in per-prompt groups. Stepping over
    # len(example_prompts) instead would only inspect the first ~1/6 of
    # the prompts, since there are `requests_per_prompt` results per prompt.
    for i in range(0, len(all_outputs), requests_per_prompt):
        outputs = all_outputs[i:i + requests_per_prompt]

        # verify that unseeded requests and requests with different seeds
        # all differ from each other
        for output_a, output_b in combinations(outputs[:4], 2):
            assert output_a != output_b

        # verify requests with the same seed match
        assert outputs[1] == outputs[4]
        assert outputs[2] == outputs[5]

View File

@ -387,6 +387,7 @@ class Scheduler:
block_tables=block_tables,
lora_request=seq_group.lora_request,
prefix=seq_group.prefix,
state=seq_group.state,
)
seq_group_metadata_list.append(seq_group_metadata)
return seq_group_metadata_list, scheduler_outputs

View File

@ -173,7 +173,6 @@ class EngineArgs:
default=EngineArgs.block_size,
choices=[8, 16, 32],
help='token block size')
# TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
parser.add_argument('--seed',
type=int,
default=EngineArgs.seed,

View File

@ -60,6 +60,7 @@ class ChatCompletionRequest(BaseModel):
top_p: Optional[float] = 1.0
n: Optional[int] = 1
max_tokens: Optional[int] = None
seed: Optional[int] = None
stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
stream: Optional[bool] = False
presence_penalty: Optional[float] = 0.0
@ -90,6 +91,7 @@ class ChatCompletionRequest(BaseModel):
temperature=self.temperature,
top_p=self.top_p,
min_p=self.min_p,
seed=self.seed,
stop=self.stop,
stop_token_ids=self.stop_token_ids,
max_tokens=self.max_tokens,
@ -117,6 +119,7 @@ class CompletionRequest(BaseModel):
logprobs: Optional[int] = None
echo: Optional[bool] = False
stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
seed: Optional[int] = None
presence_penalty: Optional[float] = 0.0
frequency_penalty: Optional[float] = 0.0
best_of: Optional[int] = None
@ -147,6 +150,7 @@ class CompletionRequest(BaseModel):
top_p=self.top_p,
top_k=self.top_k,
min_p=self.min_p,
seed=self.seed,
stop=self.stop,
stop_token_ids=self.stop_token_ids,
ignore_eos=self.ignore_eos,

View File

@ -342,7 +342,9 @@ def _beam_search_sample(
def _multinomial(
probs: torch.Tensor,
num_samples: int,
):
seq_groups: Optional[List[Tuple[List[int], SamplingParams]]] = None,
generators: Optional[List[torch.Generator]] = None,
) -> torch.Tensor:
if num_samples > 1:
# This is equivalent to torch.repeat_interleaved (which also
# forces a GPU<->CPU sync).
@ -352,7 +354,15 @@ def _multinomial(
probs = probs[:, None, :].expand(probs.shape[0], num_samples,
probs.shape[1]).contiguous().view(
-1, probs.shape[1])
q = torch.empty_like(probs).exponential_(1)
q = torch.empty_like(probs)
if seq_groups is None:
q.exponential_()
else:
sample_idx = 0
for (seq_ids, _), generator in zip(seq_groups, generators):
next_sample_idx = sample_idx + len(seq_ids) * num_samples
q[sample_idx:next_sample_idx].exponential_(generator=generator)
sample_idx = next_sample_idx
return probs.div_(q).argmax(dim=1).view(-1, num_samples)
@ -370,6 +380,7 @@ def _sample(
sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
sample_metadata = {}
multinomial_samples = {}
# Counterintuitively, having two loops here is actually faster.
# The first loop can run without waiting on GPU<->CPU sync.
@ -385,14 +396,18 @@ def _sample(
is_prompts, sample_indices)
if sampling_type == SamplingType.GREEDY:
greedy_samples = torch.argmax(logprobs[sample_indices], dim=-1)
elif sampling_type == SamplingType.RANDOM:
elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
max_best_of = 1
for seq_group, is_prompt in zip(seq_groups, is_prompts):
if is_prompt:
_, sampling_params = seq_group
max_best_of = max(max_best_of, sampling_params.best_of)
multinomial_samples = _multinomial(probs[sample_indices],
max_best_of)
seeded_args = {} if sampling_type == SamplingType.RANDOM else {
"seq_groups": seq_groups,
"generators": sampling_metadata.generators,
}
multinomial_samples[sampling_type] = _multinomial(
probs[sample_indices], max_best_of, **seeded_args)
elif sampling_type == SamplingType.BEAM:
beam_search_logprobs = logprobs[sample_indices]
else:
@ -407,9 +422,9 @@ def _sample(
sampling_type]
if sampling_type == SamplingType.GREEDY:
sample_results = _greedy_sample(seq_groups, greedy_samples)
elif sampling_type == SamplingType.RANDOM:
elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
sample_results = _random_sample(seq_groups, is_prompts,
multinomial_samples)
multinomial_samples[sampling_type])
elif sampling_type == SamplingType.BEAM:
sample_results = _beam_search_sample(seq_groups, is_prompts,
sampling_metadata.seq_data,

View File

@ -19,6 +19,7 @@ class SamplingMetadata:
prompt_lens: Lengths of prompts.
selected_token_indices: Token indices selected for sampling.
categorized_sample_indices: SamplingType -> token indices to sample.
generators: List of torch.Generators to use for seeded sampling
perform_sampling: Whether to perform sampling. This option is used to
make the sampling only happens in the driver worker, and disable
sampling in other worker processes.
@ -31,6 +32,7 @@ class SamplingMetadata:
prompt_lens: Optional[List[int]],
selected_token_indices: torch.Tensor,
categorized_sample_indices: Optional[Dict[SamplingType, torch.Tensor]],
generators: Optional[List[torch.Generator]] = None,
perform_sampling: bool = True,
) -> None:
self.seq_groups = seq_groups
@ -38,6 +40,7 @@ class SamplingMetadata:
self.prompt_lens = prompt_lens
self.selected_token_indices = selected_token_indices
self.categorized_sample_indices = categorized_sample_indices
self.generators = generators
self.perform_sampling = perform_sampling
self.num_prompts = len(prompt_lens) if prompt_lens is not None else 0

View File

@ -11,7 +11,8 @@ _SAMPLING_EPS = 1e-5
class SamplingType(IntEnum):
GREEDY = 0
RANDOM = 1
BEAM = 2
RANDOM_SEED = 2
BEAM = 3
LogitsProcessor = Callable[[List[int], torch.Tensor], torch.Tensor]
@ -56,6 +57,7 @@ class SamplingParams:
min_p: Float that represents the minimum probability for a token to be
considered, relative to the probability of the most likely token.
Must be in [0, 1]. Set to 0 to disable this.
seed: Random seed to use for the generation.
use_beam_search: Whether to use beam search instead of sampling.
length_penalty: Float that penalizes sequences based on their length.
Used in beam search.
@ -101,6 +103,7 @@ class SamplingParams:
top_p: float = 1.0,
top_k: int = -1,
min_p: float = 0.0,
seed: Optional[int] = None,
use_beam_search: bool = False,
length_penalty: float = 1.0,
early_stopping: Union[bool, str] = False,
@ -124,6 +127,7 @@ class SamplingParams:
self.top_p = top_p
self.top_k = top_k
self.min_p = min_p
self.seed = seed
self.use_beam_search = use_beam_search
self.length_penalty = length_penalty
self.early_stopping = early_stopping
@ -229,6 +233,8 @@ class SamplingParams:
return SamplingType.BEAM
if self.temperature < _SAMPLING_EPS:
return SamplingType.GREEDY
if self.seed is not None:
return SamplingType.RANDOM_SEED
return SamplingType.RANDOM
def __repr__(self) -> str:
@ -242,6 +248,7 @@ class SamplingParams:
f"top_p={self.top_p}, "
f"top_k={self.top_k}, "
f"min_p={self.min_p}, "
f"seed={self.seed}, "
f"use_beam_search={self.use_beam_search}, "
f"length_penalty={self.length_penalty}, "
f"early_stopping={self.early_stopping}, "

View File

@ -248,6 +248,14 @@ class Sequence:
f"num_blocks={len(self.logical_token_blocks)})")
@dataclass
class SequenceGroupState:
    """Mutable state tied to a specific sequence group"""

    # torch.Generator used in seeded sampling; stays None until one is
    # assigned for the group (done elsewhere when a request carries a seed).
    generator: Optional = None
class SequenceGroup:
"""A group of sequences that are generated from the same prompt.
@ -280,6 +288,7 @@ class SequenceGroup:
self.lora_request = lora_request
self.prefix: Optional[Prefix] = prefix
self.prompt_logprobs: Optional[PromptLogprobs] = None
self.state = SequenceGroupState()
@property
def prompt(self) -> str:
@ -397,6 +406,7 @@ class SequenceGroupMetadata:
sampling_params: The sampling parameters used to generate the outputs.
block_tables: The block tables. (Seq id -> list of physical block
numbers)
state: Internal state tied to this sequence group.
lora_request: LoRA request.
prefix: The prefix of the prompt of the sequence group.
"""
@ -410,6 +420,7 @@ class SequenceGroupMetadata:
block_tables: Dict[int, List[int]],
lora_request: Optional[LoRARequest] = None,
prefix: Optional[Prefix] = None,
state: Optional[SequenceGroupState] = None,
) -> None:
self.request_id = request_id
self.is_prompt = is_prompt
@ -418,6 +429,7 @@ class SequenceGroupMetadata:
self.block_tables = block_tables
self.lora_request = lora_request
self.prefix = prefix
self.state = SequenceGroupState() if state is None else state
@property
def lora_int_id(self) -> int:

View File

@ -389,6 +389,7 @@ class ModelRunner:
) -> SamplingMetadata:
seq_groups: List[Tuple[List[int], SamplingParams]] = []
selected_token_indices: List[int] = []
generators: List[torch.Generator] = []
selected_token_start_idx = 0
categorized_sample_indices = {t: [] for t in SamplingType}
categorized_sample_indices_start_idx = 0
@ -419,6 +420,10 @@ class ModelRunner:
selected_token_indices.append(selected_token_start_idx +
subquery_len - 1)
selected_token_start_idx += max_subquery_len
if sampling_params.seed is not None:
seq_group_metadata.state.generator = torch.Generator(
device="cuda").manual_seed(sampling_params.seed)
else:
num_seqs = len(seq_ids)
selected_token_indices.extend(
@ -432,6 +437,9 @@ class ModelRunner:
categorized_sample_indices_start_idx + num_seqs))
categorized_sample_indices_start_idx += num_seqs
if sampling_params.seed is not None:
generators.append(seq_group_metadata.state.generator)
selected_token_indices = _async_h2d(selected_token_indices,
dtype=torch.long,
target_device=self.device,
@ -454,6 +462,7 @@ class ModelRunner:
prompt_lens=prompt_lens,
selected_token_indices=selected_token_indices,
categorized_sample_indices=categorized_sample_indices,
generators=generators,
)
return sampling_metadata
@ -536,6 +545,7 @@ class ModelRunner:
prompt_lens=None,
selected_token_indices=metadata_dict["selected_token_indices"],
categorized_sample_indices=None,
generators=None,
perform_sampling=False,
)