Support per-request seed (#2514)

parent dc903e70ac
commit 7d2dcce175

tests/samplers/test_sampler.py

@@ -1,10 +1,11 @@
 import random
-from typing import Tuple
+from typing import Tuple, List
 from unittest.mock import patch

 import pytest
 import torch
 from transformers import GenerationConfig, GenerationMixin
+from typing import Optional

 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.utils import set_random_seed
@@ -46,6 +47,34 @@ CUDA_DEVICES = [
 ]


+def _do_sample(
+    batch_size: int,
+    input_tensor: torch.Tensor,
+    sampler: MockLogitsSampler,
+    model_runner: ModelRunner,
+    sampling_params: SamplingParams,
+):
+    seq_group_metadata_list = []
+    prompt_lens = []
+    for i in range(batch_size):
+        seq_group_metadata_list.append(
+            SequenceGroupMetadata(
+                request_id=f"test_{i}",
+                is_prompt=True,
+                seq_data={0: SequenceData([1, 2, 3])},
+                sampling_params=sampling_params,
+                block_tables={0: [1]},
+            ))
+        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
+
+    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
+                                                     prompt_lens,
+                                                     subquery_lens=prompt_lens)
+    return sampler(embedding=None,
+                   hidden_states=input_tensor,
+                   sampling_metadata=sampling_metadata)
+
+
 @pytest.mark.parametrize("seed", RANDOM_SEEDS)
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 def test_sampler_all_greedy(seed: int, device: str):
@@ -55,25 +84,9 @@ def test_sampler_all_greedy(seed: int, device: str):
     input_tensor, fake_logits, sampler, model_runner = _prepare_test(
         batch_size)

-    seq_group_metadata_list = []
-    prompt_lens = []
-    for i in range(batch_size):
-        seq_group_metadata_list.append(
-            SequenceGroupMetadata(
-                request_id=f"test_{i}",
-                is_prompt=True,
-                seq_data={0: SequenceData([1, 2, 3])},
-                sampling_params=SamplingParams(temperature=0, ),
-                block_tables={0: [1]},
-            ))
-        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
-
-    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
-                                                     prompt_lens,
-                                                     subquery_lens=prompt_lens)
-    sampler_output = sampler(embedding=None,
-                             hidden_states=input_tensor,
-                             sampling_metadata=sampling_metadata)
+    sampling_params = SamplingParams(temperature=0)
+    sampler_output = _do_sample(batch_size, input_tensor, sampler,
+                                model_runner, sampling_params)
     expected = torch.argmax(fake_logits, dim=-1)
     for i, sequence_output in enumerate(sampler_output):
         for nth_output in sequence_output.samples:
@@ -94,28 +107,13 @@ def test_sampler_all_random(seed: int, device: str):
     for i in range(batch_size):
         fake_logits[i, i] = 1e2

-    seq_group_metadata_list = []
-    prompt_lens = []
-    for i in range(batch_size):
-        seq_group_metadata_list.append(
-            SequenceGroupMetadata(
-                request_id=f"test_{i}",
-                is_prompt=True,
-                seq_data={0: SequenceData([1, 2, 3])},
-                sampling_params=SamplingParams(
-                    temperature=1.0,
-                    n=random.randint(1, 10),
-                ),
-                block_tables={0: [1]},
-            ))
-        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
-
-    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
-                                                     prompt_lens,
-                                                     subquery_lens=prompt_lens)
-    sampler_output = sampler(embedding=None,
-                             hidden_states=input_tensor,
-                             sampling_metadata=sampling_metadata)
+    sampling_params = SamplingParams(
+        temperature=1.0,
+        n=random.randint(1, 10),
+    )
+    sampler_output = _do_sample(batch_size, input_tensor, sampler,
+                                model_runner, sampling_params)
+
     for i, sequence_output in enumerate(sampler_output):
         for nth_output in sequence_output.samples:
             assert nth_output.output_token == i
@@ -123,6 +121,58 @@ def test_sampler_all_random(seed: int, device: str):
     del model_runner


+@pytest.mark.parametrize("seed", RANDOM_SEEDS)
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+def test_sampler_all_random_seed(seed: int, device: str):
+    set_random_seed(seed)
+    torch.set_default_device(device)
+    batch_size = random.randint(1, 256)
+    input_tensor, fake_logits, sampler, model_runner = _prepare_test(
+        batch_size)
+
+    for i in range(batch_size):
+        fake_logits[i, i] = 1e2
+
+    sampling_params = SamplingParams(
+        temperature=1.0,
+        n=random.randint(1, 10),
+        seed=random.randint(0, 10000),
+    )
+    sampler_output = _do_sample(batch_size, input_tensor, sampler,
+                                model_runner, sampling_params)
+
+    for i, sequence_output in enumerate(sampler_output):
+        for nth_output in sequence_output.samples:
+            assert nth_output.output_token == i
+
+    del model_runner
+
+
+@pytest.mark.parametrize("seed", RANDOM_SEEDS)
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+def test_sampler_all_random_seed_deterministic(seed: int, device: str):
+    set_random_seed(seed)
+    torch.set_default_device(device)
+    batch_size = random.randint(1, 256)
+    input_tensor, fake_logits, sampler, model_runner = _prepare_test(
+        batch_size)
+
+    sampling_params = SamplingParams(
+        temperature=1.0,
+        n=random.randint(1, 10),
+        seed=random.randint(0, 10000),
+    )
+    first_sampler_output = _do_sample(batch_size, input_tensor, sampler,
+                                      model_runner, sampling_params)
+
+    second_sampler_output = _do_sample(batch_size, input_tensor, sampler,
+                                       model_runner, sampling_params)
+
+    assert first_sampler_output == second_sampler_output
+
+    del model_runner
+
+
 @pytest.mark.parametrize("seed", RANDOM_SEEDS)
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 def test_sampler_all_beam(seed: int, device: str):
@@ -131,29 +181,13 @@ def test_sampler_all_beam(seed: int, device: str):
     batch_size = random.randint(1, 256)
     input_tensor, _, sampler, model_runner = _prepare_test(batch_size)

-    seq_group_metadata_list = []
-    prompt_lens = []
-    for i in range(batch_size):
-        seq_group_metadata_list.append(
-            SequenceGroupMetadata(
-                request_id=f"test_{i}",
-                is_prompt=True,
-                seq_data={0: SequenceData([1, 2, 3])},
-                sampling_params=SamplingParams(
-                    temperature=0,
-                    best_of=2,
-                    use_beam_search=True,
-                ),
-                block_tables={0: [1]},
-            ))
-        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
-
-    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
-                                                     prompt_lens,
-                                                     subquery_lens=prompt_lens)
-    sampler(embedding=None,
-            hidden_states=input_tensor,
-            sampling_metadata=sampling_metadata)
+    sampling_params = SamplingParams(
+        temperature=0,
+        best_of=2,
+        use_beam_search=True,
+    )
+    _do_sample(batch_size, input_tensor, sampler, model_runner,
+               sampling_params)
     # no assertion here as I am not sure how to determine whether
     # the outputs are expected - in other words, this just tests
     # whether there are no exceptions in the sampler
@@ -171,14 +205,15 @@ def test_sampler_mixed(seed: int, device: str):
         batch_size)

     seq_group_metadata_list = []
-    expected_tokens = []
+    expected_tokens: List[Optional[List[int]]] = []
     prompt_lens = []
     for i in range(batch_size):
-        n = 1
-        sampling_type = random.randint(0, 2)
+        expected: Optional[List[int]] = None
+        sampling_type = random.randint(0, 3)
         if sampling_type == 0:
             sampling_params = SamplingParams(temperature=0)
-        elif sampling_type == 1:
+            expected = [torch.argmax(fake_logits[i], dim=-1).item()]
+        elif sampling_type in (1, 2):
             n = random.randint(1, 10)
             sampling_params = SamplingParams(
                 temperature=random.random() + 0.1,
@@ -187,13 +222,17 @@
                 n=n,
                 presence_penalty=random.randint(0, 1),
             )
+            if sampling_type == 2:
+                sampling_params.seed = random.randint(0, 10000)
+            else:
+                for idx in range(n):
+                    fake_logits[i, i + idx] = 1e2
+                expected = list(range(i, i + n))
         else:
             sampling_params = SamplingParams(temperature=0,
                                              use_beam_search=True,
                                              best_of=2)
-        for idx in range(n):
-            fake_logits[i, i + idx] = 1e2
-            expected_tokens.append(i + idx)
+        expected_tokens.append(expected)
         seq_group_metadata_list.append(
             SequenceGroupMetadata(
                 request_id=f"test_{i}",
@@ -204,17 +243,50 @@
             ))
         prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

-    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
-                                                     prompt_lens,
-                                                     subquery_lens=prompt_lens)
-    sampler_output = sampler(embedding=None,
-                             hidden_states=input_tensor,
-                             sampling_metadata=sampling_metadata)
-    for i, sequence_output in enumerate(sampler_output):
-        if seq_group_metadata_list[i].sampling_params.use_beam_search:
-            continue
-        for nth_output in sequence_output.samples:
-            assert nth_output.output_token in expected_tokens
+    def test_sampling(model_runner: ModelRunner):
+        sampling_metadata = model_runner._prepare_sample(
+            seq_group_metadata_list, prompt_lens, subquery_lens=prompt_lens)
+        sampler_output = sampler(embedding=None,
+                                 hidden_states=input_tensor,
+                                 sampling_metadata=sampling_metadata)
+
+        for i, (sequence_output, metadata) in enumerate(
+                zip(sampler_output, seq_group_metadata_list)):
+            if metadata.sampling_params.use_beam_search:
+                continue
+
+            if metadata.sampling_params.seed is not None \
+                    and expected_tokens[i] is None:
+                # Record seeded random result to compare with results of second invocation
+                expected_tokens[i] = [
+                    nth_output.output_token
+                    for nth_output in sequence_output.samples
+                ]
+                continue
+
+            for n, nth_output in enumerate(sequence_output.samples):
+                if metadata.sampling_params.temperature == 0 or metadata.sampling_params.seed is not None:
+                    # Ensure exact matches for greedy or random with seed
+                    assert nth_output.output_token == expected_tokens[i][n]
+                else:
+                    # For non-seeded random check that one of the high-logit tokens were chosen
+                    assert nth_output.output_token in expected_tokens[i]
+
+    # Test batch
+    test_sampling(model_runner)
+
+    # Shuffle the batch and resample
+    target_index = list(range(batch_size))
+    for list_to_shuffle in (target_index, seq_group_metadata_list,
+                            expected_tokens, prompt_lens):
+        random.Random(seed).shuffle(list_to_shuffle)
+    target_index = torch.tensor(target_index)
+    input_tensor.data = input_tensor.index_select(0, target_index)
+    fake_logits.data = fake_logits.index_select(0, target_index)
+
+    # This time, results of seeded random samples will be compared with the corresponding
+    # sample in the pre-shuffled batch
+    test_sampling(model_runner)

     del model_runner

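The new sampler tests lean on one property of torch: a torch.Generator seeded with a fixed value replays the same random stream regardless of what the global RNG does. A minimal standalone sketch of that property (not part of the diff; shapes and values are arbitrary):

import torch

# Two generators seeded identically produce identical draws.
probs = torch.rand(4, 32)
probs = probs / probs.sum(dim=-1, keepdim=True)

gen_a = torch.Generator().manual_seed(1234)
gen_b = torch.Generator().manual_seed(1234)

samples_a = torch.multinomial(probs, num_samples=3, generator=gen_a)
samples_b = torch.multinomial(probs, num_samples=3, generator=gen_b)
assert torch.equal(samples_a, samples_b)  # same seed -> same samples

# Drawing from the default (unseeded) RNG does not disturb either generator.
_ = torch.multinomial(probs, num_samples=3)
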
tests/samplers/test_seeded_generate.py (new file)

@@ -0,0 +1,82 @@
+"""Verify that seeded random sampling is deterministic.
+
+Run `pytest tests/samplers/test_seeded_generate.py --forked`.
+"""
+import copy
+import random
+from itertools import combinations
+
+import pytest
+
+from vllm.model_executor.utils import set_random_seed
+from vllm import SamplingParams
+
+MODEL = "facebook/opt-125m"
+RANDOM_SEEDS = list(range(5))
+
+
+@pytest.fixture
+def vllm_model(vllm_runner):
+    vllm_model = vllm_runner(MODEL, dtype="half")
+    yield vllm_model
+    del vllm_model
+
+
+@pytest.mark.parametrize("seed", RANDOM_SEEDS)
+def test_random_sample_with_seed(
+    vllm_model,
+    example_prompts,
+    seed: int,
+) -> None:
+    set_random_seed(seed)
+
+    sampling_params = SamplingParams(
+        # Parameters to ensure sufficient randomness
+        temperature=2.0,
+        top_p=min(random.random() + 0.3, 1),
+        top_k=random.randint(5, 20),
+        n=random.randint(1, 10),
+        presence_penalty=random.randint(0, 1),
+        max_tokens=8,
+        ignore_eos=True,
+    )
+
+    sampling_params_seed_1 = copy.deepcopy(sampling_params)
+    sampling_params_seed_1.seed = 100
+    sampling_params_seed_2 = copy.deepcopy(sampling_params)
+    sampling_params_seed_2.seed = 200
+
+    llm = vllm_model.model
+
+    for prompt in example_prompts:
+        for params in (
+                sampling_params,
+                sampling_params_seed_1,
+                sampling_params_seed_2,
+                sampling_params,
+                sampling_params_seed_1,
+                sampling_params_seed_2,
+        ):
+            llm._add_request(
+                prompt=prompt,
+                prompt_token_ids=None,
+                sampling_params=params,
+            )
+
+    results = llm._run_engine(use_tqdm=False)
+    all_outputs = [[out.token_ids for out in output.outputs]
+                   for output in results]
+
+    for i in range(0, len(example_prompts), 6):
+        outputs = all_outputs[i:i + 6]
+
+        # verify all non-seeded requests differ
+        for output_a, output_b in combinations(
+            (outputs[0], outputs[1], outputs[2], outputs[3]),
+                2,
+        ):
+            assert output_a != output_b
+
+        # verify requests with the same seed match
+        assert outputs[1] == outputs[4]
+        assert outputs[2] == outputs[5]

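For reference, a usage sketch of the feature this new test exercises, written against the offline LLM API; the model name, prompt, and parameter values are placeholders rather than values taken from the diff:

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")

seeded = SamplingParams(temperature=0.8, top_p=0.95, seed=1234, max_tokens=16)

# Two runs of the same seeded request should return the same token ids,
# while an unseeded request with temperature > 0 is free to vary.
out_a = llm.generate(["The capital of France is"], seeded)
out_b = llm.generate(["The capital of France is"], seeded)
assert out_a[0].outputs[0].token_ids == out_b[0].outputs[0].token_ids
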
vllm/core/scheduler.py

@@ -387,6 +387,7 @@ class Scheduler:
                 block_tables=block_tables,
                 lora_request=seq_group.lora_request,
                 prefix=seq_group.prefix,
+                state=seq_group.state,
             )
             seq_group_metadata_list.append(seq_group_metadata)
         return seq_group_metadata_list, scheduler_outputs

vllm/engine/arg_utils.py

@@ -173,7 +173,6 @@ class EngineArgs:
                             default=EngineArgs.block_size,
                             choices=[8, 16, 32],
                             help='token block size')
-        # TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
        parser.add_argument('--seed',
                             type=int,
                             default=EngineArgs.seed,

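The removed TODO is what this PR implements: the engine-level --seed stays as a global setting that seeds the process-wide RNGs at startup, while the new per-request seed rides along in SamplingParams. A brief sketch of the two scopes, with placeholder values:

from vllm import LLM, SamplingParams

# Engine-wide seed: applied once when the engine is built.
llm = LLM(model="facebook/opt-125m", seed=0)

# Request-scoped seed: only affects sampling for this one request.
params = SamplingParams(temperature=1.0, seed=42)
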
vllm/entrypoints/openai/protocol.py

@@ -60,6 +60,7 @@ class ChatCompletionRequest(BaseModel):
     top_p: Optional[float] = 1.0
     n: Optional[int] = 1
     max_tokens: Optional[int] = None
+    seed: Optional[int] = None
     stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
     stream: Optional[bool] = False
     presence_penalty: Optional[float] = 0.0
@@ -90,6 +91,7 @@ class ChatCompletionRequest(BaseModel):
             temperature=self.temperature,
             top_p=self.top_p,
             min_p=self.min_p,
+            seed=self.seed,
             stop=self.stop,
             stop_token_ids=self.stop_token_ids,
             max_tokens=self.max_tokens,
@@ -117,6 +119,7 @@ class CompletionRequest(BaseModel):
     logprobs: Optional[int] = None
     echo: Optional[bool] = False
     stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
+    seed: Optional[int] = None
     presence_penalty: Optional[float] = 0.0
     frequency_penalty: Optional[float] = 0.0
     best_of: Optional[int] = None
@@ -147,6 +150,7 @@ class CompletionRequest(BaseModel):
             top_p=self.top_p,
             top_k=self.top_k,
             min_p=self.min_p,
+            seed=self.seed,
             stop=self.stop,
             stop_token_ids=self.stop_token_ids,
             ignore_eos=self.ignore_eos,

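Both OpenAI-style request models now accept a seed field and forward it into SamplingParams. A client-side sketch against the completions endpoint; the host, port, model name, and prompt are assumptions for illustration only:

import requests

payload = {
    "model": "facebook/opt-125m",
    "prompt": "San Francisco is a",
    "temperature": 0.9,
    "max_tokens": 16,
    "seed": 1234,  # forwarded to SamplingParams.seed by CompletionRequest
}
resp = requests.post("http://localhost:8000/v1/completions", json=payload)
print(resp.json()["choices"][0]["text"])
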
vllm/model_executor/layers/sampler.py

@@ -342,7 +342,9 @@ def _beam_search_sample(
 def _multinomial(
     probs: torch.Tensor,
     num_samples: int,
-):
+    seq_groups: Optional[List[Tuple[List[int], SamplingParams]]] = None,
+    generators: Optional[List[torch.Generator]] = None,
+) -> torch.Tensor:
     if num_samples > 1:
         # This is equivalent to torch.repeat_interleaved (which also
         # forces a GPU<->CPU sync).
@@ -352,7 +354,15 @@ def _multinomial(
         probs = probs[:, None, :].expand(probs.shape[0], num_samples,
                                          probs.shape[1]).contiguous().view(
                                              -1, probs.shape[1])
-    q = torch.empty_like(probs).exponential_(1)
+    q = torch.empty_like(probs)
+    if seq_groups is None:
+        q.exponential_()
+    else:
+        sample_idx = 0
+        for (seq_ids, _), generator in zip(seq_groups, generators):
+            next_sample_idx = sample_idx + len(seq_ids) * num_samples
+            q[sample_idx:next_sample_idx].exponential_(generator=generator)
+            sample_idx = next_sample_idx
     return probs.div_(q).argmax(dim=1).view(-1, num_samples)


@@ -370,6 +380,7 @@ def _sample(

     sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
     sample_metadata = {}
+    multinomial_samples = {}

     # Counterintiutively, having two loops here is actually faster.
     # The first loop can run without waiting on GPU<->CPU sync.
@@ -385,14 +396,18 @@ def _sample(
                                          is_prompts, sample_indices)
         if sampling_type == SamplingType.GREEDY:
             greedy_samples = torch.argmax(logprobs[sample_indices], dim=-1)
-        elif sampling_type == SamplingType.RANDOM:
+        elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
             max_best_of = 1
             for seq_group, is_prompt in zip(seq_groups, is_prompts):
                 if is_prompt:
                     _, sampling_params = seq_group
                     max_best_of = max(max_best_of, sampling_params.best_of)
-            multinomial_samples = _multinomial(probs[sample_indices],
-                                               max_best_of)
+            seeded_args = {} if sampling_type == SamplingType.RANDOM else {
+                "seq_groups": seq_groups,
+                "generators": sampling_metadata.generators,
+            }
+            multinomial_samples[sampling_type] = _multinomial(
+                probs[sample_indices], max_best_of, **seeded_args)
         elif sampling_type == SamplingType.BEAM:
             beam_search_logprobs = logprobs[sample_indices]
         else:
@@ -407,9 +422,9 @@ def _sample(
             sampling_type]
         if sampling_type == SamplingType.GREEDY:
             sample_results = _greedy_sample(seq_groups, greedy_samples)
-        elif sampling_type == SamplingType.RANDOM:
+        elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
             sample_results = _random_sample(seq_groups, is_prompts,
-                                            multinomial_samples)
+                                            multinomial_samples[sampling_type])
         elif sampling_type == SamplingType.BEAM:
             sample_results = _beam_search_sample(seq_groups, is_prompts,
                                                  sampling_metadata.seq_data,

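_multinomial draws categorical samples with an exponential race: for probabilities p and independent Exp(1) noise q, argmax(p / q) lands on index i with probability p_i, which is what the existing probs.div_(q).argmax(...) line exploits. The change above only swaps where the noise comes from, filling the rows that belong to seeded requests from their own torch.Generator. A standalone sketch of the trick (names and shapes are illustrative, not the library's API):

import torch

def race_sample(probs: torch.Tensor, generator: torch.Generator) -> torch.Tensor:
    # argmax(p_i / E_i) with E_i ~ Exp(1) picks index i with probability p_i.
    q = torch.empty_like(probs)
    q.exponential_(generator=generator)
    return probs.div(q).argmax(dim=-1)

probs = torch.softmax(torch.randn(2, 8), dim=-1)  # one row per sequence
gen_a = torch.Generator().manual_seed(7)
gen_b = torch.Generator().manual_seed(7)
assert torch.equal(race_sample(probs, gen_a), race_sample(probs, gen_b))
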
vllm/model_executor/sampling_metadata.py

@@ -19,6 +19,7 @@ class SamplingMetadata:
         prompt_lens: Lengths of prompts.
         selected_token_indices: Token indices selected for sampling.
         categorized_sample_indices: SamplingType -> token indices to sample.
+        generators: List of torch.Generators to use for seeded sampling
         perform_sampling: Whether to perform sampling. This option is used to
             make the sampling only happens in the driver worker, and disable
             sampling in other worker processes.
@@ -31,6 +32,7 @@ class SamplingMetadata:
         prompt_lens: Optional[List[int]],
         selected_token_indices: torch.Tensor,
         categorized_sample_indices: Optional[Dict[SamplingType, torch.Tensor]],
+        generators: Optional[List[torch.Generator]] = None,
         perform_sampling: bool = True,
     ) -> None:
         self.seq_groups = seq_groups
@@ -38,6 +40,7 @@ class SamplingMetadata:
         self.prompt_lens = prompt_lens
         self.selected_token_indices = selected_token_indices
         self.categorized_sample_indices = categorized_sample_indices
+        self.generators = generators
         self.perform_sampling = perform_sampling

         self.num_prompts = len(prompt_lens) if prompt_lens is not None else 0

vllm/sampling_params.py

@@ -11,7 +11,8 @@ _SAMPLING_EPS = 1e-5
 class SamplingType(IntEnum):
     GREEDY = 0
     RANDOM = 1
-    BEAM = 2
+    RANDOM_SEED = 2
+    BEAM = 3


 LogitsProcessor = Callable[[List[int], torch.Tensor], torch.Tensor]
@@ -56,6 +57,7 @@ class SamplingParams:
         min_p: Float that represents the minimum probability for a token to be
             considered, relative to the probability of the most likely token.
             Must be in [0, 1]. Set to 0 to disable this.
+        seed: Random seed to use for the generation.
         use_beam_search: Whether to use beam search instead of sampling.
         length_penalty: Float that penalizes sequences based on their length.
             Used in beam search.
@@ -101,6 +103,7 @@ class SamplingParams:
         top_p: float = 1.0,
         top_k: int = -1,
         min_p: float = 0.0,
+        seed: Optional[int] = None,
         use_beam_search: bool = False,
         length_penalty: float = 1.0,
         early_stopping: Union[bool, str] = False,
@@ -124,6 +127,7 @@ class SamplingParams:
         self.top_p = top_p
         self.top_k = top_k
         self.min_p = min_p
+        self.seed = seed
         self.use_beam_search = use_beam_search
         self.length_penalty = length_penalty
         self.early_stopping = early_stopping
@@ -229,6 +233,8 @@ class SamplingParams:
             return SamplingType.BEAM
         if self.temperature < _SAMPLING_EPS:
             return SamplingType.GREEDY
+        if self.seed is not None:
+            return SamplingType.RANDOM_SEED
         return SamplingType.RANDOM

     def __repr__(self) -> str:
@@ -242,6 +248,7 @@ class SamplingParams:
                 f"top_p={self.top_p}, "
                 f"top_k={self.top_k}, "
                 f"min_p={self.min_p}, "
+                f"seed={self.seed}, "
                 f"use_beam_search={self.use_beam_search}, "
                 f"length_penalty={self.length_penalty}, "
                 f"early_stopping={self.early_stopping}, "

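With the new enum member, a request's sampling type resolves roughly as: beam search first, then greedy for near-zero temperature, then seeded random when a seed is given, otherwise plain random. A sketch of the expected routing, assuming the enum-returning property in the hunk above is the SamplingParams.sampling_type property used elsewhere in the codebase:

from vllm import SamplingParams
from vllm.sampling_params import SamplingType

assert SamplingParams(temperature=0.0).sampling_type == SamplingType.GREEDY
assert SamplingParams(temperature=1.0).sampling_type == SamplingType.RANDOM
assert SamplingParams(temperature=1.0,
                      seed=7).sampling_type == SamplingType.RANDOM_SEED
assert SamplingParams(temperature=0.0, use_beam_search=True,
                      best_of=2).sampling_type == SamplingType.BEAM
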
vllm/sequence.py

@@ -248,6 +248,14 @@ class Sequence:
                 f"num_blocks={len(self.logical_token_blocks)})")


+@dataclass
+class SequenceGroupState:
+    """Mutable state tied to a specific sequence group"""
+
+    # torch.Generator used in seeded sampling
+    generator: Optional = None
+
+
 class SequenceGroup:
     """A group of sequences that are generated from the same prompt.

@@ -280,6 +288,7 @@ class SequenceGroup:
         self.lora_request = lora_request
         self.prefix: Optional[Prefix] = prefix
         self.prompt_logprobs: Optional[PromptLogprobs] = None
+        self.state = SequenceGroupState()

     @property
     def prompt(self) -> str:
@@ -397,6 +406,7 @@ class SequenceGroupMetadata:
         sampling_params: The sampling parameters used to generate the outputs.
         block_tables: The block tables. (Seq id -> list of physical block
             numbers)
+        state: Internal state tied to this sequence group.
         lora_request: LoRA request.
         prefix: The prefix of the prompt of the sequence group.
     """
@@ -410,6 +420,7 @@ class SequenceGroupMetadata:
         block_tables: Dict[int, List[int]],
         lora_request: Optional[LoRARequest] = None,
         prefix: Optional[Prefix] = None,
+        state: Optional[SequenceGroupState] = None,
     ) -> None:
         self.request_id = request_id
         self.is_prompt = is_prompt
@@ -418,6 +429,7 @@ class SequenceGroupMetadata:
         self.block_tables = block_tables
         self.lora_request = lora_request
         self.prefix = prefix
+        self.state = SequenceGroupState() if state is None else state

     @property
     def lora_int_id(self) -> int:

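SequenceGroupState is shared by reference: the scheduler passes the SequenceGroup's own state object into each step's SequenceGroupMetadata, so the generator created during the prompt run is the same, already-advanced object on every later decode step. An illustrative sketch of that aliasing, with plain dicts standing in for the metadata objects:

from dataclasses import dataclass
from typing import Optional
import torch

@dataclass
class GroupState:  # stand-in for SequenceGroupState
    generator: Optional[torch.Generator] = None

state = GroupState()          # lives on the sequence group
step0 = {"state": state}      # metadata built for the prompt step
step1 = {"state": state}      # metadata built for a later decode step

step0["state"].generator = torch.Generator().manual_seed(42)
assert step1["state"].generator is state.generator  # same object, shared state
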
vllm/worker/model_runner.py

@@ -389,6 +389,7 @@ class ModelRunner:
     ) -> SamplingMetadata:
         seq_groups: List[Tuple[List[int], SamplingParams]] = []
         selected_token_indices: List[int] = []
+        generators: List[torch.Generator] = []
         selected_token_start_idx = 0
         categorized_sample_indices = {t: [] for t in SamplingType}
         categorized_sample_indices_start_idx = 0
@@ -419,6 +420,10 @@ class ModelRunner:
                 selected_token_indices.append(selected_token_start_idx +
                                               subquery_len - 1)
                 selected_token_start_idx += max_subquery_len
+
+                if sampling_params.seed is not None:
+                    seq_group_metadata.state.generator = torch.Generator(
+                        device="cuda").manual_seed(sampling_params.seed)
             else:
                 num_seqs = len(seq_ids)
                 selected_token_indices.extend(
@@ -432,6 +437,9 @@ class ModelRunner:
                         categorized_sample_indices_start_idx + num_seqs))
                 categorized_sample_indices_start_idx += num_seqs

+            if sampling_params.seed is not None:
+                generators.append(seq_group_metadata.state.generator)
+
         selected_token_indices = _async_h2d(selected_token_indices,
                                             dtype=torch.long,
                                             target_device=self.device,
@@ -454,6 +462,7 @@ class ModelRunner:
             prompt_lens=prompt_lens,
             selected_token_indices=selected_token_indices,
             categorized_sample_indices=categorized_sample_indices,
+            generators=generators,
         )
         return sampling_metadata

@@ -536,6 +545,7 @@ class ModelRunner:
             prompt_lens=None,
             selected_token_indices=metadata_dict["selected_token_indices"],
             categorized_sample_indices=None,
+            generators=None,
             perform_sampling=False,
         )

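Because the generator is created once per seeded request (on the prompt run) and then reused step after step, its internal state keeps advancing; that is what makes the whole generated sequence reproducible, not just the first sampled token. A toy sketch of the idea (the 8-way draw stands in for a vocabulary; names are illustrative):

import torch

def seeded_stream(seed: int, steps: int = 4) -> list:
    gen = torch.Generator().manual_seed(seed)  # one generator per request
    draws = []
    for _ in range(steps):  # one draw per decode step, reusing the generator
        noise = torch.empty(8).exponential_(generator=gen)
        draws.append(int(noise.argmax()))
    return draws

assert seeded_stream(1234) == seeded_stream(1234)  # same seed, same trajectory
print(seeded_stream(1234), seeded_stream(4321))    # different seeds will generally diverge
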