mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-10 01:35:01 +08:00)
Support per-request seed (#2514)
This commit is contained in:
parent dc903e70ac
commit 7d2dcce175
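Summary: this change adds an optional per-request `seed` to SamplingParams and to the OpenAI-compatible ChatCompletion/Completion request schemas. A seeded request is classified as the new SamplingType.RANDOM_SEED, gets a dedicated torch.Generator stored on its SequenceGroupState, and that generator is threaded through SamplingMetadata into the sampler's multinomial draw, so the same request with the same seed reproduces the same samples. A minimal usage sketch (not part of the diff; it assumes the offline LLM entry point, with the model name taken from the new test):

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")  # model name taken from the new test
params = SamplingParams(temperature=1.0, max_tokens=8, seed=1234)

# Two runs with the same per-request seed should yield identical token ids,
# while omitting `seed` keeps the previous non-deterministic behavior.
out_a = llm.generate(["Hello, my name is"], params)
out_b = llm.generate(["Hello, my name is"], params)
assert out_a[0].outputs[0].token_ids == out_b[0].outputs[0].token_ids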
@@ -1,10 +1,11 @@
 import random
-from typing import Tuple
+from typing import Tuple, List
 from unittest.mock import patch
 
 import pytest
 import torch
 from transformers import GenerationConfig, GenerationMixin
+from typing import Optional
 
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.utils import set_random_seed
@@ -46,6 +47,34 @@ CUDA_DEVICES = [
 ]
 
 
+def _do_sample(
+    batch_size: int,
+    input_tensor: torch.Tensor,
+    sampler: MockLogitsSampler,
+    model_runner: ModelRunner,
+    sampling_params: SamplingParams,
+):
+    seq_group_metadata_list = []
+    prompt_lens = []
+    for i in range(batch_size):
+        seq_group_metadata_list.append(
+            SequenceGroupMetadata(
+                request_id=f"test_{i}",
+                is_prompt=True,
+                seq_data={0: SequenceData([1, 2, 3])},
+                sampling_params=sampling_params,
+                block_tables={0: [1]},
+            ))
+        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
+
+    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
+                                                     prompt_lens,
+                                                     subquery_lens=prompt_lens)
+    return sampler(embedding=None,
+                   hidden_states=input_tensor,
+                   sampling_metadata=sampling_metadata)
+
+
 @pytest.mark.parametrize("seed", RANDOM_SEEDS)
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 def test_sampler_all_greedy(seed: int, device: str):
@@ -55,25 +84,9 @@ def test_sampler_all_greedy(seed: int, device: str):
     input_tensor, fake_logits, sampler, model_runner = _prepare_test(
         batch_size)
 
-    seq_group_metadata_list = []
-    prompt_lens = []
-    for i in range(batch_size):
-        seq_group_metadata_list.append(
-            SequenceGroupMetadata(
-                request_id=f"test_{i}",
-                is_prompt=True,
-                seq_data={0: SequenceData([1, 2, 3])},
-                sampling_params=SamplingParams(temperature=0, ),
-                block_tables={0: [1]},
-            ))
-        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
-
-    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
-                                                     prompt_lens,
-                                                     subquery_lens=prompt_lens)
-    sampler_output = sampler(embedding=None,
-                             hidden_states=input_tensor,
-                             sampling_metadata=sampling_metadata)
+    sampling_params = SamplingParams(temperature=0)
+    sampler_output = _do_sample(batch_size, input_tensor, sampler,
+                                model_runner, sampling_params)
     expected = torch.argmax(fake_logits, dim=-1)
     for i, sequence_output in enumerate(sampler_output):
         for nth_output in sequence_output.samples:
@@ -94,28 +107,13 @@ def test_sampler_all_random(seed: int, device: str):
     for i in range(batch_size):
         fake_logits[i, i] = 1e2
 
-    seq_group_metadata_list = []
-    prompt_lens = []
-    for i in range(batch_size):
-        seq_group_metadata_list.append(
-            SequenceGroupMetadata(
-                request_id=f"test_{i}",
-                is_prompt=True,
-                seq_data={0: SequenceData([1, 2, 3])},
-                sampling_params=SamplingParams(
-                    temperature=1.0,
-                    n=random.randint(1, 10),
-                ),
-                block_tables={0: [1]},
-            ))
-        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
+    sampling_params = SamplingParams(
+        temperature=1.0,
+        n=random.randint(1, 10),
+    )
+    sampler_output = _do_sample(batch_size, input_tensor, sampler,
+                                model_runner, sampling_params)
 
-    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
-                                                     prompt_lens,
-                                                     subquery_lens=prompt_lens)
-    sampler_output = sampler(embedding=None,
-                             hidden_states=input_tensor,
-                             sampling_metadata=sampling_metadata)
     for i, sequence_output in enumerate(sampler_output):
         for nth_output in sequence_output.samples:
             assert nth_output.output_token == i
@@ -123,6 +121,58 @@ def test_sampler_all_random(seed: int, device: str):
     del model_runner
 
 
+@pytest.mark.parametrize("seed", RANDOM_SEEDS)
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+def test_sampler_all_random_seed(seed: int, device: str):
+    set_random_seed(seed)
+    torch.set_default_device(device)
+    batch_size = random.randint(1, 256)
+    input_tensor, fake_logits, sampler, model_runner = _prepare_test(
+        batch_size)
+
+    for i in range(batch_size):
+        fake_logits[i, i] = 1e2
+
+    sampling_params = SamplingParams(
+        temperature=1.0,
+        n=random.randint(1, 10),
+        seed=random.randint(0, 10000),
+    )
+    sampler_output = _do_sample(batch_size, input_tensor, sampler,
+                                model_runner, sampling_params)
+
+    for i, sequence_output in enumerate(sampler_output):
+        for nth_output in sequence_output.samples:
+            assert nth_output.output_token == i
+
+    del model_runner
+
+
+@pytest.mark.parametrize("seed", RANDOM_SEEDS)
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+def test_sampler_all_random_seed_deterministic(seed: int, device: str):
+    set_random_seed(seed)
+    torch.set_default_device(device)
+    batch_size = random.randint(1, 256)
+    input_tensor, fake_logits, sampler, model_runner = _prepare_test(
+        batch_size)
+
+    sampling_params = SamplingParams(
+        temperature=1.0,
+        n=random.randint(1, 10),
+        seed=random.randint(0, 10000),
+    )
+    first_sampler_output = _do_sample(batch_size, input_tensor, sampler,
+                                      model_runner, sampling_params)
+
+    second_sampler_output = _do_sample(batch_size, input_tensor, sampler,
+                                       model_runner, sampling_params)
+
+    assert first_sampler_output == second_sampler_output
+
+    del model_runner
+
+
 @pytest.mark.parametrize("seed", RANDOM_SEEDS)
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 def test_sampler_all_beam(seed: int, device: str):
@@ -131,29 +181,13 @@ def test_sampler_all_beam(seed: int, device: str):
     batch_size = random.randint(1, 256)
     input_tensor, _, sampler, model_runner = _prepare_test(batch_size)
 
-    seq_group_metadata_list = []
-    prompt_lens = []
-    for i in range(batch_size):
-        seq_group_metadata_list.append(
-            SequenceGroupMetadata(
-                request_id=f"test_{i}",
-                is_prompt=True,
-                seq_data={0: SequenceData([1, 2, 3])},
-                sampling_params=SamplingParams(
-                    temperature=0,
-                    best_of=2,
-                    use_beam_search=True,
-                ),
-                block_tables={0: [1]},
-            ))
-        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
-
-    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
-                                                     prompt_lens,
-                                                     subquery_lens=prompt_lens)
-    sampler(embedding=None,
-            hidden_states=input_tensor,
-            sampling_metadata=sampling_metadata)
+    sampling_params = SamplingParams(
+        temperature=0,
+        best_of=2,
+        use_beam_search=True,
+    )
+    _do_sample(batch_size, input_tensor, sampler, model_runner,
+               sampling_params)
     # no assertion here as I am not sure how to determine whether
     # the outputs are expected - in other words, this just tests
     # whether there are no exceptions in the sampler
@@ -171,14 +205,15 @@ def test_sampler_mixed(seed: int, device: str):
         batch_size)
 
     seq_group_metadata_list = []
-    expected_tokens = []
+    expected_tokens: List[Optional[List[int]]] = []
     prompt_lens = []
     for i in range(batch_size):
-        n = 1
-        sampling_type = random.randint(0, 2)
+        expected: Optional[List[int]] = None
+        sampling_type = random.randint(0, 3)
         if sampling_type == 0:
             sampling_params = SamplingParams(temperature=0)
-        elif sampling_type == 1:
+            expected = [torch.argmax(fake_logits[i], dim=-1).item()]
+        elif sampling_type in (1, 2):
             n = random.randint(1, 10)
             sampling_params = SamplingParams(
                 temperature=random.random() + 0.1,
@@ -187,13 +222,17 @@ def test_sampler_mixed(seed: int, device: str):
                 n=n,
                 presence_penalty=random.randint(0, 1),
             )
+            if sampling_type == 2:
+                sampling_params.seed = random.randint(0, 10000)
+            else:
+                for idx in range(n):
+                    fake_logits[i, i + idx] = 1e2
+                expected = list(range(i, i + n))
         else:
             sampling_params = SamplingParams(temperature=0,
                                              use_beam_search=True,
                                              best_of=2)
-        for idx in range(n):
-            fake_logits[i, i + idx] = 1e2
-            expected_tokens.append(i + idx)
+        expected_tokens.append(expected)
         seq_group_metadata_list.append(
             SequenceGroupMetadata(
                 request_id=f"test_{i}",
@@ -204,17 +243,50 @@ def test_sampler_mixed(seed: int, device: str):
             ))
         prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
 
-    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
-                                                     prompt_lens,
-                                                     subquery_lens=prompt_lens)
-    sampler_output = sampler(embedding=None,
-                             hidden_states=input_tensor,
-                             sampling_metadata=sampling_metadata)
-    for i, sequence_output in enumerate(sampler_output):
-        if seq_group_metadata_list[i].sampling_params.use_beam_search:
-            continue
-        for nth_output in sequence_output.samples:
-            assert nth_output.output_token in expected_tokens
+    def test_sampling(model_runner: ModelRunner):
+        sampling_metadata = model_runner._prepare_sample(
+            seq_group_metadata_list, prompt_lens, subquery_lens=prompt_lens)
+        sampler_output = sampler(embedding=None,
+                                 hidden_states=input_tensor,
+                                 sampling_metadata=sampling_metadata)
+
+        for i, (sequence_output, metadata) in enumerate(
+                zip(sampler_output, seq_group_metadata_list)):
+            if metadata.sampling_params.use_beam_search:
+                continue
+
+            if metadata.sampling_params.seed is not None \
+                    and expected_tokens[i] is None:
+                # Record seeded random result to compare with results of second invocation
+                expected_tokens[i] = [
+                    nth_output.output_token
+                    for nth_output in sequence_output.samples
+                ]
+                continue
+
+            for n, nth_output in enumerate(sequence_output.samples):
+                if metadata.sampling_params.temperature == 0 or metadata.sampling_params.seed is not None:
+                    # Ensure exact matches for greedy or random with seed
+                    assert nth_output.output_token == expected_tokens[i][n]
+                else:
+                    # For non-seeded random check that one of the high-logit tokens were chosen
+                    assert nth_output.output_token in expected_tokens[i]
+
+    # Test batch
+    test_sampling(model_runner)
+
+    # Shuffle the batch and resample
+    target_index = list(range(batch_size))
+    for list_to_shuffle in (target_index, seq_group_metadata_list,
+                            expected_tokens, prompt_lens):
+        random.Random(seed).shuffle(list_to_shuffle)
+    target_index = torch.tensor(target_index)
+    input_tensor.data = input_tensor.index_select(0, target_index)
+    fake_logits.data = fake_logits.index_select(0, target_index)
+
+    # This time, results of seeded random samples will be compared with the corresponding
+    # sample in the pre-shuffled batch
+    test_sampling(model_runner)
 
     del model_runner
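Note: the reworked mixed test above records each seeded request's first-pass tokens in `expected_tokens`, then shuffles the whole batch and re-runs it, asserting that seeded requests reproduce their pre-shuffle outputs, i.e. that a per-request seed is independent of batch composition and position. Conceptually (an illustrative sketch in plain torch, not the test code):

import random
import torch

def draw(seed: int) -> torch.Tensor:
    # Each request owns its own generator, so its draw does not depend on
    # which other requests share the batch or where it sits in it.
    gen = torch.Generator().manual_seed(seed)
    return torch.rand(4, generator=gen)

seeds = [7, 11, 13, 17]
first = {s: draw(s) for s in seeds}
random.shuffle(seeds)  # reorder the "batch"
second = {s: draw(s) for s in seeds}
assert all(torch.equal(first[s], second[s]) for s in seeds)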
tests/samplers/test_seeded_generate.py (new file, 82 lines)
@@ -0,0 +1,82 @@
"""Verify that seeded random sampling is deterministic.
|
||||
|
||||
Run `pytest tests/samplers/test_seeded_generate.py --forked`.
|
||||
"""
|
||||
import copy
|
||||
import random
|
||||
from itertools import combinations
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.model_executor.utils import set_random_seed
|
||||
from vllm import SamplingParams
|
||||
|
||||
MODEL = "facebook/opt-125m"
|
||||
RANDOM_SEEDS = list(range(5))
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def vllm_model(vllm_runner):
|
||||
vllm_model = vllm_runner(MODEL, dtype="half")
|
||||
yield vllm_model
|
||||
del vllm_model
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
|
||||
def test_random_sample_with_seed(
|
||||
vllm_model,
|
||||
example_prompts,
|
||||
seed: int,
|
||||
) -> None:
|
||||
set_random_seed(seed)
|
||||
|
||||
sampling_params = SamplingParams(
|
||||
# Parameters to ensure sufficient randomness
|
||||
temperature=2.0,
|
||||
top_p=min(random.random() + 0.3, 1),
|
||||
top_k=random.randint(5, 20),
|
||||
n=random.randint(1, 10),
|
||||
presence_penalty=random.randint(0, 1),
|
||||
max_tokens=8,
|
||||
ignore_eos=True,
|
||||
)
|
||||
|
||||
sampling_params_seed_1 = copy.deepcopy(sampling_params)
|
||||
sampling_params_seed_1.seed = 100
|
||||
sampling_params_seed_2 = copy.deepcopy(sampling_params)
|
||||
sampling_params_seed_2.seed = 200
|
||||
|
||||
llm = vllm_model.model
|
||||
|
||||
for prompt in example_prompts:
|
||||
for params in (
|
||||
sampling_params,
|
||||
sampling_params_seed_1,
|
||||
sampling_params_seed_2,
|
||||
sampling_params,
|
||||
sampling_params_seed_1,
|
||||
sampling_params_seed_2,
|
||||
):
|
||||
llm._add_request(
|
||||
prompt=prompt,
|
||||
prompt_token_ids=None,
|
||||
sampling_params=params,
|
||||
)
|
||||
|
||||
results = llm._run_engine(use_tqdm=False)
|
||||
all_outputs = [[out.token_ids for out in output.outputs]
|
||||
for output in results]
|
||||
|
||||
for i in range(0, len(example_prompts), 6):
|
||||
outputs = all_outputs[i:i + 6]
|
||||
|
||||
# verify all non-seeded requests differ
|
||||
for output_a, output_b in combinations(
|
||||
(outputs[0], outputs[1], outputs[2], outputs[3]),
|
||||
2,
|
||||
):
|
||||
assert output_a != output_b
|
||||
|
||||
# verify requests with the same seed match
|
||||
assert outputs[1] == outputs[4]
|
||||
assert outputs[2] == outputs[5]
|
||||
@@ -387,6 +387,7 @@ class Scheduler:
                 block_tables=block_tables,
                 lora_request=seq_group.lora_request,
                 prefix=seq_group.prefix,
+                state=seq_group.state,
             )
             seq_group_metadata_list.append(seq_group_metadata)
         return seq_group_metadata_list, scheduler_outputs
@@ -173,7 +173,6 @@ class EngineArgs:
                             default=EngineArgs.block_size,
                             choices=[8, 16, 32],
                             help='token block size')
-        # TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
        parser.add_argument('--seed',
                            type=int,
                            default=EngineArgs.seed,
@@ -60,6 +60,7 @@ class ChatCompletionRequest(BaseModel):
     top_p: Optional[float] = 1.0
     n: Optional[int] = 1
     max_tokens: Optional[int] = None
+    seed: Optional[int] = None
     stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
     stream: Optional[bool] = False
     presence_penalty: Optional[float] = 0.0
@@ -90,6 +91,7 @@ class ChatCompletionRequest(BaseModel):
             temperature=self.temperature,
             top_p=self.top_p,
             min_p=self.min_p,
+            seed=self.seed,
             stop=self.stop,
             stop_token_ids=self.stop_token_ids,
             max_tokens=self.max_tokens,
@@ -117,6 +119,7 @@ class CompletionRequest(BaseModel):
     logprobs: Optional[int] = None
     echo: Optional[bool] = False
     stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
+    seed: Optional[int] = None
     presence_penalty: Optional[float] = 0.0
     frequency_penalty: Optional[float] = 0.0
     best_of: Optional[int] = None
@@ -147,6 +150,7 @@ class CompletionRequest(BaseModel):
             top_p=self.top_p,
             top_k=self.top_k,
             min_p=self.min_p,
+            seed=self.seed,
             stop=self.stop,
             stop_token_ids=self.stop_token_ids,
             ignore_eos=self.ignore_eos,
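Note: with the two schema fields above, a client of the OpenAI-compatible server can request reproducible completions as well. A hypothetical client-side sketch (not part of the diff; the server address and seed value are assumptions):

import requests

resp = requests.post(
    "http://localhost:8000/v1/completions",  # assumed local vLLM OpenAI-compatible server
    json={
        "model": "facebook/opt-125m",
        "prompt": "Hello, my name is",
        "temperature": 1.0,
        "max_tokens": 8,
        "seed": 1234,  # the new per-request seed field added above
    },
)
print(resp.json()["choices"][0]["text"])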
@@ -342,7 +342,9 @@ def _beam_search_sample(
 def _multinomial(
     probs: torch.Tensor,
     num_samples: int,
-):
+    seq_groups: Optional[List[Tuple[List[int], SamplingParams]]] = None,
+    generators: Optional[List[torch.Generator]] = None,
+) -> torch.Tensor:
     if num_samples > 1:
         # This is equivalent to torch.repeat_interleaved (which also
         # forces a GPU<->CPU sync).
@@ -352,7 +354,15 @@ def _multinomial(
         probs = probs[:, None, :].expand(probs.shape[0], num_samples,
                                          probs.shape[1]).contiguous().view(
                                              -1, probs.shape[1])
-    q = torch.empty_like(probs).exponential_(1)
+    q = torch.empty_like(probs)
+    if seq_groups is None:
+        q.exponential_()
+    else:
+        sample_idx = 0
+        for (seq_ids, _), generator in zip(seq_groups, generators):
+            next_sample_idx = sample_idx + len(seq_ids) * num_samples
+            q[sample_idx:next_sample_idx].exponential_(generator=generator)
+            sample_idx = next_sample_idx
     return probs.div_(q).argmax(dim=1).view(-1, num_samples)
 
 
@@ -370,6 +380,7 @@ def _sample(
 
     sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
     sample_metadata = {}
+    multinomial_samples = {}
 
     # Counterintiutively, having two loops here is actually faster.
     # The first loop can run without waiting on GPU<->CPU sync.
@@ -385,14 +396,18 @@ def _sample(
                                                      is_prompts, sample_indices)
         if sampling_type == SamplingType.GREEDY:
             greedy_samples = torch.argmax(logprobs[sample_indices], dim=-1)
-        elif sampling_type == SamplingType.RANDOM:
+        elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
             max_best_of = 1
             for seq_group, is_prompt in zip(seq_groups, is_prompts):
                 if is_prompt:
                     _, sampling_params = seq_group
                     max_best_of = max(max_best_of, sampling_params.best_of)
-            multinomial_samples = _multinomial(probs[sample_indices],
-                                               max_best_of)
+            seeded_args = {} if sampling_type == SamplingType.RANDOM else {
+                "seq_groups": seq_groups,
+                "generators": sampling_metadata.generators,
+            }
+            multinomial_samples[sampling_type] = _multinomial(
+                probs[sample_indices], max_best_of, **seeded_args)
         elif sampling_type == SamplingType.BEAM:
             beam_search_logprobs = logprobs[sample_indices]
         else:
@@ -407,9 +422,9 @@ def _sample(
                                                              sampling_type]
         if sampling_type == SamplingType.GREEDY:
             sample_results = _greedy_sample(seq_groups, greedy_samples)
-        elif sampling_type == SamplingType.RANDOM:
+        elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
             sample_results = _random_sample(seq_groups, is_prompts,
-                                            multinomial_samples)
+                                            multinomial_samples[sampling_type])
         elif sampling_type == SamplingType.BEAM:
             sample_results = _beam_search_sample(seq_groups, is_prompts,
                                                  sampling_metadata.seq_data,
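Note: `_multinomial` draws categories by dividing the probabilities by i.i.d. Exponential(1) noise and taking the argmax (an exponential-race, Gumbel-max style sample); the change above only ties that noise to per-request generators for RANDOM_SEED groups. A self-contained sketch of the same idea (illustrative only, not the vLLM code):

import torch

def seeded_multinomial_sketch(probs: torch.Tensor,
                              generator: torch.Generator) -> torch.Tensor:
    # argmax_i probs[i] / E_i with E_i ~ Exp(1) picks index i with probability
    # proportional to probs[i]; seeding the noise makes the draw reproducible.
    q = torch.empty_like(probs).exponential_(generator=generator)
    return probs.div(q).argmax(dim=-1)

probs = torch.tensor([[0.1, 0.2, 0.7], [0.6, 0.3, 0.1]])
gen_a = torch.Generator().manual_seed(1234)
gen_b = torch.Generator().manual_seed(1234)
assert torch.equal(seeded_multinomial_sketch(probs, gen_a),
                   seeded_multinomial_sketch(probs, gen_b))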
@@ -19,6 +19,7 @@ class SamplingMetadata:
         prompt_lens: Lengths of prompts.
         selected_token_indices: Token indices selected for sampling.
        categorized_sample_indices: SamplingType -> token indices to sample.
+        generators: List of torch.Generators to use for seeded sampling
         perform_sampling: Whether to perform sampling. This option is used to
             make the sampling only happens in the driver worker, and disable
             sampling in other worker processes.
@@ -31,6 +32,7 @@ class SamplingMetadata:
         prompt_lens: Optional[List[int]],
         selected_token_indices: torch.Tensor,
         categorized_sample_indices: Optional[Dict[SamplingType, torch.Tensor]],
+        generators: Optional[List[torch.Generator]] = None,
         perform_sampling: bool = True,
     ) -> None:
         self.seq_groups = seq_groups
@@ -38,6 +40,7 @@ class SamplingMetadata:
         self.prompt_lens = prompt_lens
         self.selected_token_indices = selected_token_indices
         self.categorized_sample_indices = categorized_sample_indices
+        self.generators = generators
         self.perform_sampling = perform_sampling
 
         self.num_prompts = len(prompt_lens) if prompt_lens is not None else 0
@@ -11,7 +11,8 @@ _SAMPLING_EPS = 1e-5
 class SamplingType(IntEnum):
     GREEDY = 0
     RANDOM = 1
-    BEAM = 2
+    RANDOM_SEED = 2
+    BEAM = 3
 
 
 LogitsProcessor = Callable[[List[int], torch.Tensor], torch.Tensor]
@@ -56,6 +57,7 @@ class SamplingParams:
         min_p: Float that represents the minimum probability for a token to be
             considered, relative to the probability of the most likely token.
             Must be in [0, 1]. Set to 0 to disable this.
+        seed: Random seed to use for the generation.
         use_beam_search: Whether to use beam search instead of sampling.
         length_penalty: Float that penalizes sequences based on their length.
             Used in beam search.
@@ -101,6 +103,7 @@ class SamplingParams:
         top_p: float = 1.0,
         top_k: int = -1,
         min_p: float = 0.0,
+        seed: Optional[int] = None,
         use_beam_search: bool = False,
         length_penalty: float = 1.0,
         early_stopping: Union[bool, str] = False,
@@ -124,6 +127,7 @@ class SamplingParams:
         self.top_p = top_p
         self.top_k = top_k
         self.min_p = min_p
+        self.seed = seed
         self.use_beam_search = use_beam_search
         self.length_penalty = length_penalty
         self.early_stopping = early_stopping
@@ -229,6 +233,8 @@ class SamplingParams:
             return SamplingType.BEAM
         if self.temperature < _SAMPLING_EPS:
             return SamplingType.GREEDY
+        if self.seed is not None:
+            return SamplingType.RANDOM_SEED
         return SamplingType.RANDOM
 
     def __repr__(self) -> str:
@@ -242,6 +248,7 @@ class SamplingParams:
                 f"top_p={self.top_p}, "
                 f"top_k={self.top_k}, "
                 f"min_p={self.min_p}, "
+                f"seed={self.seed}, "
                 f"use_beam_search={self.use_beam_search}, "
                 f"length_penalty={self.length_penalty}, "
                 f"early_stopping={self.early_stopping}, "
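Note: the effective sampling type is derived from the parameters rather than stored separately, and a seed only matters for non-greedy, non-beam sampling. A minimal standalone sketch of the resolution logic added above (it mirrors the diff, not imported from vLLM):

from enum import IntEnum
from typing import Optional

class SamplingType(IntEnum):
    GREEDY = 0
    RANDOM = 1
    RANDOM_SEED = 2
    BEAM = 3

def resolve_sampling_type(temperature: float,
                          seed: Optional[int],
                          use_beam_search: bool,
                          eps: float = 1e-5) -> SamplingType:
    # Same precedence as SamplingParams.sampling_type in the diff:
    # beam search first, then greedy, then seeded vs. plain random.
    if use_beam_search:
        return SamplingType.BEAM
    if temperature < eps:
        return SamplingType.GREEDY
    if seed is not None:
        return SamplingType.RANDOM_SEED
    return SamplingType.RANDOM

assert resolve_sampling_type(1.0, 1234, False) is SamplingType.RANDOM_SEED
assert resolve_sampling_type(1.0, None, False) is SamplingType.RANDOM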
@@ -248,6 +248,14 @@ class Sequence:
                 f"num_blocks={len(self.logical_token_blocks)})")
 
 
+@dataclass
+class SequenceGroupState:
+    """Mutable state tied to a specific sequence group"""
+
+    # torch.Generator used in seeded sampling
+    generator: Optional = None
+
+
 class SequenceGroup:
     """A group of sequences that are generated from the same prompt.
 
@@ -280,6 +288,7 @@ class SequenceGroup:
         self.lora_request = lora_request
         self.prefix: Optional[Prefix] = prefix
         self.prompt_logprobs: Optional[PromptLogprobs] = None
+        self.state = SequenceGroupState()
 
     @property
     def prompt(self) -> str:
@@ -397,6 +406,7 @@ class SequenceGroupMetadata:
         sampling_params: The sampling parameters used to generate the outputs.
         block_tables: The block tables. (Seq id -> list of physical block
             numbers)
+        state: Internal state tied to this sequence group.
         lora_request: LoRA request.
         prefix: The prefix of the prompt of the sequence group.
     """
@@ -410,6 +420,7 @@ class SequenceGroupMetadata:
         block_tables: Dict[int, List[int]],
         lora_request: Optional[LoRARequest] = None,
         prefix: Optional[Prefix] = None,
+        state: Optional[SequenceGroupState] = None,
     ) -> None:
         self.request_id = request_id
         self.is_prompt = is_prompt
@@ -418,6 +429,7 @@ class SequenceGroupMetadata:
         self.block_tables = block_tables
         self.lora_request = lora_request
         self.prefix = prefix
+        self.state = SequenceGroupState() if state is None else state
 
     @property
     def lora_int_id(self) -> int:
@@ -389,6 +389,7 @@ class ModelRunner:
     ) -> SamplingMetadata:
         seq_groups: List[Tuple[List[int], SamplingParams]] = []
         selected_token_indices: List[int] = []
+        generators: List[torch.Generator] = []
         selected_token_start_idx = 0
         categorized_sample_indices = {t: [] for t in SamplingType}
         categorized_sample_indices_start_idx = 0
@@ -419,6 +420,10 @@ class ModelRunner:
                 selected_token_indices.append(selected_token_start_idx +
                                               subquery_len - 1)
                 selected_token_start_idx += max_subquery_len
+
+                if sampling_params.seed is not None:
+                    seq_group_metadata.state.generator = torch.Generator(
+                        device="cuda").manual_seed(sampling_params.seed)
             else:
                 num_seqs = len(seq_ids)
                 selected_token_indices.extend(
@@ -432,6 +437,9 @@ class ModelRunner:
                     categorized_sample_indices_start_idx + num_seqs))
                 categorized_sample_indices_start_idx += num_seqs
 
+            if sampling_params.seed is not None:
+                generators.append(seq_group_metadata.state.generator)
+
         selected_token_indices = _async_h2d(selected_token_indices,
                                             dtype=torch.long,
                                             target_device=self.device,
@@ -454,6 +462,7 @@ class ModelRunner:
             prompt_lens=prompt_lens,
             selected_token_indices=selected_token_indices,
             categorized_sample_indices=categorized_sample_indices,
+            generators=generators,
         )
         return sampling_metadata
 
@@ -536,6 +545,7 @@ class ModelRunner:
                 prompt_lens=None,
                 selected_token_indices=metadata_dict["selected_token_indices"],
                 categorized_sample_indices=None,
+                generators=None,
                 perform_sampling=False,
             )
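Note: `_prepare_sample` creates one torch.Generator per seeded sequence group on its prompt step and keeps it on the group's SequenceGroupState (which the scheduler forwards via `state=seq_group.state`), so later decode steps reuse the same generator and consume its stream deterministically. A small illustration of why a freshly seeded generator reproduces the draws (plain torch; the diff uses device="cuda", "cpu" here keeps the sketch runnable without a GPU):

import torch

def make_request_generator(seed: int, device: str = "cpu") -> torch.Generator:
    # Mirrors the per-request wiring above, just on CPU for illustration.
    return torch.Generator(device=device).manual_seed(seed)

first = torch.rand(4, generator=make_request_generator(42))
second = torch.rand(4, generator=make_request_generator(42))
assert torch.equal(first, second)  # same seed, same stream, same draws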