Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-17 04:45:01 +08:00)
commit 5f09cbdb63 (parent 4cefa9b49b)

Fix broken sampler tests (#1896)

Co-authored-by: Antoni Baum <antoni.baum@protonmail.com>
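The sampler tests still drove sampling through `Worker` (`worker._prepare_inputs`, returning an `input_metadata`), which no longer matches the code: sampling preparation now lives in `ModelRunner`. This change ports the tests to `ModelRunner._prepare_sample(seq_group_metadata_list, prompt_lens)`, tracks the per-sequence `prompt_lens` that call needs, and passes the resulting `sampling_metadata` to the `Sampler` in place of `input_metadata`. `ModelRunner.__init__` also gains a guard so the tests can construct `ModelRunner(None, None, None)` without a model config.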
diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py
--- a/tests/samplers/test_sampler.py
+++ b/tests/samplers/test_sampler.py
@@ -8,7 +8,7 @@ import torch
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.utils import set_random_seed
 from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
-from vllm.worker.worker import Worker
+from vllm.worker.model_runner import ModelRunner


 class MockLogitsSampler(Sampler):
@@ -27,7 +27,7 @@ class MockLogitsSampler(Sampler):

 def _prepare_test(
     batch_size: int
-) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsSampler, Worker]:
+) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsSampler, ModelRunner]:
     vocab_size = 32000
     input_tensor = torch.rand((batch_size, 1024),
                               device="cuda",
@@ -37,9 +37,8 @@ def _prepare_test(
                              device=input_tensor.device,
                              dtype=input_tensor.dtype)
     sampler = MockLogitsSampler(32000, fake_logits)
-    worker = Worker(None, None, None)
-    worker.block_size = 16
-    return input_tensor, fake_logits, sampler, worker
+    model_runner = ModelRunner(None, None, None)
+    return input_tensor, fake_logits, sampler, model_runner


 RANDOM_SEEDS = list(range(128))
@@ -49,9 +48,11 @@ RANDOM_SEEDS = list(range(128))
 def test_sampler_all_greedy(seed: int):
     set_random_seed(seed)
     batch_size = random.randint(1, 256)
-    input_tensor, fake_logits, sampler, worker = _prepare_test(batch_size)
+    input_tensor, fake_logits, sampler, model_runner = _prepare_test(
+        batch_size)

     seq_group_metadata_list = []
+    prompt_lens = []
     for i in range(batch_size):
         seq_group_metadata_list.append(
             SequenceGroupMetadata(
@@ -61,11 +62,13 @@ def test_sampler_all_greedy(seed: int):
                 sampling_params=SamplingParams(temperature=0, ),
                 block_tables={0: [1]},
             ))
+        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

-    _, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
+    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
+                                                     prompt_lens)
     sampler_output = sampler(embedding=None,
                              hidden_states=input_tensor,
-                             input_metadata=input_metadata)
+                             sampling_metadata=sampling_metadata)
     expected = torch.argmax(fake_logits, dim=-1)
     for i, sequence_output in enumerate(sampler_output):
         for nth_output in sequence_output.samples:
@@ -76,12 +79,14 @@ def test_sampler_all_greedy(seed: int):
 def test_sampler_all_random(seed: int):
     set_random_seed(seed)
     batch_size = random.randint(1, 256)
-    input_tensor, fake_logits, sampler, worker = _prepare_test(batch_size)
+    input_tensor, fake_logits, sampler, model_runner = _prepare_test(
+        batch_size)

     for i in range(batch_size):
         fake_logits[i, i] = 1e2

     seq_group_metadata_list = []
+    prompt_lens = []
     for i in range(batch_size):
         seq_group_metadata_list.append(
             SequenceGroupMetadata(
@@ -94,11 +99,13 @@ def test_sampler_all_random(seed: int):
                 ),
                 block_tables={0: [1]},
             ))
+        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

-    _, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
+    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
+                                                     prompt_lens)
     sampler_output = sampler(embedding=None,
                              hidden_states=input_tensor,
-                             input_metadata=input_metadata)
+                             sampling_metadata=sampling_metadata)
     for i, sequence_output in enumerate(sampler_output):
         for nth_output in sequence_output.samples:
             assert nth_output.output_token == i
@@ -108,9 +115,10 @@ def test_sampler_all_random(seed: int):
 def test_sampler_all_beam(seed: int):
     set_random_seed(seed)
     batch_size = random.randint(1, 256)
-    input_tensor, _, sampler, worker = _prepare_test(batch_size)
+    input_tensor, _, sampler, model_runner = _prepare_test(batch_size)

     seq_group_metadata_list = []
+    prompt_lens = []
     for i in range(batch_size):
         seq_group_metadata_list.append(
             SequenceGroupMetadata(
@@ -124,11 +132,13 @@ def test_sampler_all_beam(seed: int):
                 ),
                 block_tables={0: [1]},
             ))
+        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

-    _, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
+    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
+                                                     prompt_lens)
     sampler(embedding=None,
             hidden_states=input_tensor,
-            input_metadata=input_metadata)
+            sampling_metadata=sampling_metadata)
     # no assertion here as I am not sure how to determine whether
     # the outputs are expected - in other words, this just tests
     # whether there are no exceptions in the sampler
@@ -139,10 +149,12 @@ def test_sampler_all_beam(seed: int):
 def test_sampler_mixed(seed: int):
     set_random_seed(seed)
     batch_size = random.randint(1, 256)
-    input_tensor, fake_logits, sampler, worker = _prepare_test(batch_size)
+    input_tensor, fake_logits, sampler, model_runner = _prepare_test(
+        batch_size)

     seq_group_metadata_list = []
     expected_tokens = []
+    prompt_lens = []
     for i in range(batch_size):
         n = 1
         sampling_type = random.randint(0, 2)
@@ -172,11 +184,13 @@ def test_sampler_mixed(seed: int):
                 sampling_params=sampling_params,
                 block_tables={0: [1]},
             ))
+        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

-    _, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
+    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
+                                                     prompt_lens)
     sampler_output = sampler(embedding=None,
                              hidden_states=input_tensor,
-                             input_metadata=input_metadata)
+                             sampling_metadata=sampling_metadata)
     for i, sequence_output in enumerate(sampler_output):
         if seq_group_metadata_list[i].sampling_params.use_beam_search:
             continue
@@ -188,7 +202,7 @@ def test_sampler_mixed(seed: int):
 def test_sampler_logits_processors(seed: int):
     set_random_seed(seed)
     batch_size = random.randint(1, 256)
-    input_tensor, _, sampler, worker = _prepare_test(batch_size)
+    input_tensor, _, sampler, model_runner = _prepare_test(batch_size)

     # This sample logits processor gives infinite score to the i-th token,
     # where i is the length of the input sequence.
@@ -198,6 +212,7 @@ def test_sampler_logits_processors(seed: int):
         return logits

     seq_group_metadata_list = []
+    prompt_lens = []
     for i in range(batch_size):
         seq_group_metadata_list.append(
             SequenceGroupMetadata(
@@ -208,11 +223,13 @@ def test_sampler_logits_processors(seed: int):
                                                logits_processors=[pick_ith]),
                 block_tables={0: [1]},
             ))
+        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

-    _, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
+    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
+                                                     prompt_lens)
     sampler_output = sampler(embedding=None,
                              hidden_states=input_tensor,
-                             input_metadata=input_metadata)
+                             sampling_metadata=sampling_metadata)
     for _, sequence_output in enumerate(sampler_output):
         for idx, nth_output in enumerate(sequence_output.samples):
             assert nth_output.output_token == idx
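Taken together, the test-side migration follows one pattern: collect `prompt_lens` while building the `SequenceGroupMetadata`, build sampling metadata through `ModelRunner._prepare_sample`, and pass it to the `Sampler` as `sampling_metadata`. The sketch below assembles that pattern end to end. It is a minimal sketch, not the committed test: it assumes vLLM at this commit plus a CUDA device, imports the `_prepare_test` helper from the file above, and the `request_id`/`is_prompt`/`seq_data` constructor fields and the final assertion are reconstructed from surrounding context rather than shown in the hunks.

import random

import torch

from vllm.model_executor.utils import set_random_seed
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata

from tests.samplers.test_sampler import _prepare_test  # helper shown above


def run_greedy_once(seed: int) -> None:
    set_random_seed(seed)
    batch_size = random.randint(1, 256)
    input_tensor, fake_logits, sampler, model_runner = _prepare_test(
        batch_size)

    seq_group_metadata_list = []
    prompt_lens = []  # new in this commit: needed by _prepare_sample
    for i in range(batch_size):
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{i}",  # reconstructed, not in the hunks
                is_prompt=True,  # reconstructed, not in the hunks
                seq_data={0: SequenceData([1, 2, 3])},  # reconstructed
                sampling_params=SamplingParams(temperature=0),
                block_tables={0: [1]},
            ))
        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

    # Worker._prepare_inputs is gone from this path; sampling state now
    # comes from ModelRunner._prepare_sample, keyed by prompt lengths.
    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
                                                     prompt_lens)
    sampler_output = sampler(embedding=None,
                             hidden_states=input_tensor,
                             sampling_metadata=sampling_metadata)

    # Greedy sampling must pick the argmax of the mocked logits.
    expected = torch.argmax(fake_logits, dim=-1)
    for i, sequence_output in enumerate(sampler_output):
        for nth_output in sequence_output.samples:
            assert nth_output.output_token == expected[i].item()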
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -25,7 +25,10 @@ class ModelRunner:
         self.parallel_config = parallel_config
         self.scheduler_config = scheduler_config

-        self.sliding_window = model_config.get_sliding_window()
+        # model_config can be None in tests/samplers/test_sampler.py.
+        # FIXME(woosuk): This is a hack to make the tests work. Refactor this.
+        self.sliding_window = (model_config.get_sliding_window()
+                               if model_config is not None else None)
         self.model = None
         self.block_size = None  # Set after initial profiling.

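The `ModelRunner` change is a plain conditional-expression guard, which is what makes `ModelRunner(None, None, None)` constructible in the tests. A self-contained sketch of the same pattern, with hypothetical `FakeModelConfig` and `FakeRunner` classes standing in for vLLM's real types:

from typing import Optional


class FakeModelConfig:
    """Hypothetical stand-in for vLLM's model config; only the pattern matters."""

    def get_sliding_window(self) -> Optional[int]:
        return 4096


class FakeRunner:

    def __init__(self, model_config: Optional[FakeModelConfig]) -> None:
        # Mirrors the diff: tolerate model_config=None so tests can build a
        # runner without a full config.
        self.sliding_window = (model_config.get_sliding_window()
                               if model_config is not None else None)


assert FakeRunner(FakeModelConfig()).sliding_window == 4096
assert FakeRunner(None).sliding_window is None  # the test-only path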