diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py
index 1df75cd7be14..3ad2d4608fbd 100644
--- a/tests/samplers/test_sampler.py
+++ b/tests/samplers/test_sampler.py
@@ -8,7 +8,7 @@ import torch
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.utils import set_random_seed
 from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
-from vllm.worker.worker import Worker
+from vllm.worker.model_runner import ModelRunner
 
 
 class MockLogitsSampler(Sampler):
@@ -27,7 +27,7 @@ class MockLogitsSampler(Sampler):
 
 def _prepare_test(
     batch_size: int
-) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsSampler, Worker]:
+) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsSampler, ModelRunner]:
     vocab_size = 32000
     input_tensor = torch.rand((batch_size, 1024),
                               device="cuda",
@@ -37,9 +37,8 @@ def _prepare_test(
                              device=input_tensor.device,
                              dtype=input_tensor.dtype)
     sampler = MockLogitsSampler(32000, fake_logits)
-    worker = Worker(None, None, None)
-    worker.block_size = 16
-    return input_tensor, fake_logits, sampler, worker
+    model_runner = ModelRunner(None, None, None)
+    return input_tensor, fake_logits, sampler, model_runner
 
 
 RANDOM_SEEDS = list(range(128))
@@ -49,9 +48,11 @@ RANDOM_SEEDS = list(range(128))
 def test_sampler_all_greedy(seed: int):
     set_random_seed(seed)
     batch_size = random.randint(1, 256)
-    input_tensor, fake_logits, sampler, worker = _prepare_test(batch_size)
+    input_tensor, fake_logits, sampler, model_runner = _prepare_test(
+        batch_size)
 
     seq_group_metadata_list = []
+    prompt_lens = []
     for i in range(batch_size):
         seq_group_metadata_list.append(
             SequenceGroupMetadata(
@@ -61,11 +62,13 @@ def test_sampler_all_greedy(seed: int):
                 sampling_params=SamplingParams(temperature=0, ),
                 block_tables={0: [1]},
             ))
+        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
 
-    _, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
+    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
+                                                     prompt_lens)
     sampler_output = sampler(embedding=None,
                              hidden_states=input_tensor,
-                             input_metadata=input_metadata)
+                             sampling_metadata=sampling_metadata)
     expected = torch.argmax(fake_logits, dim=-1)
     for i, sequence_output in enumerate(sampler_output):
         for nth_output in sequence_output.samples:
@@ -76,12 +79,14 @@ def test_sampler_all_random(seed: int):
     set_random_seed(seed)
     batch_size = random.randint(1, 256)
-    input_tensor, fake_logits, sampler, worker = _prepare_test(batch_size)
+    input_tensor, fake_logits, sampler, model_runner = _prepare_test(
+        batch_size)
 
     for i in range(batch_size):
         fake_logits[i, i] = 1e2
 
     seq_group_metadata_list = []
+    prompt_lens = []
     for i in range(batch_size):
         seq_group_metadata_list.append(
             SequenceGroupMetadata(
@@ -94,11 +99,13 @@ def test_sampler_all_random(seed: int):
                 ),
                 block_tables={0: [1]},
             ))
+        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
 
-    _, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
+    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
+                                                     prompt_lens)
     sampler_output = sampler(embedding=None,
                              hidden_states=input_tensor,
-                             input_metadata=input_metadata)
+                             sampling_metadata=sampling_metadata)
     for i, sequence_output in enumerate(sampler_output):
         for nth_output in sequence_output.samples:
             assert nth_output.output_token == i
@@ -108,9 +115,10 @@ def test_sampler_all_beam(seed: int):
     set_random_seed(seed)
     batch_size = random.randint(1, 256)
-    input_tensor, _, sampler, worker = _prepare_test(batch_size)
+    input_tensor, _, sampler, model_runner = _prepare_test(batch_size)
 
     seq_group_metadata_list = []
+    prompt_lens = []
     for i in range(batch_size):
         seq_group_metadata_list.append(
             SequenceGroupMetadata(
@@ -124,11 +132,13 @@ def test_sampler_all_beam(seed: int):
                 ),
                 block_tables={0: [1]},
             ))
+        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
 
-    _, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
+    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
+                                                     prompt_lens)
     sampler(embedding=None,
             hidden_states=input_tensor,
-            input_metadata=input_metadata)
+            sampling_metadata=sampling_metadata)
     # no assertion here as I am not sure how to determine whether
     # the outputs are expected - in other words, this just tests
     # whether there are no exceptions in the sampler
@@ -139,10 +149,12 @@ def test_sampler_mixed(seed: int):
     set_random_seed(seed)
     batch_size = random.randint(1, 256)
-    input_tensor, fake_logits, sampler, worker = _prepare_test(batch_size)
+    input_tensor, fake_logits, sampler, model_runner = _prepare_test(
+        batch_size)
 
     seq_group_metadata_list = []
     expected_tokens = []
+    prompt_lens = []
     for i in range(batch_size):
         n = 1
         sampling_type = random.randint(0, 2)
@@ -172,11 +184,13 @@ def test_sampler_mixed(seed: int):
                 sampling_params=sampling_params,
                 block_tables={0: [1]},
             ))
+        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
 
-    _, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
+    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
+                                                     prompt_lens)
     sampler_output = sampler(embedding=None,
                              hidden_states=input_tensor,
-                             input_metadata=input_metadata)
+                             sampling_metadata=sampling_metadata)
     for i, sequence_output in enumerate(sampler_output):
         if seq_group_metadata_list[i].sampling_params.use_beam_search:
             continue
@@ -188,7 +202,7 @@ def test_sampler_logits_processors(seed: int):
     set_random_seed(seed)
     batch_size = random.randint(1, 256)
-    input_tensor, _, sampler, worker = _prepare_test(batch_size)
+    input_tensor, _, sampler, model_runner = _prepare_test(batch_size)
 
     # This sample logits processor gives infinite score to the i-th token,
     # where i is the length of the input sequence.
@@ -198,6 +212,7 @@ def test_sampler_logits_processors(seed: int):
         return logits
 
     seq_group_metadata_list = []
+    prompt_lens = []
     for i in range(batch_size):
         seq_group_metadata_list.append(
             SequenceGroupMetadata(
@@ -208,11 +223,13 @@ def test_sampler_logits_processors(seed: int):
                                                logits_processors=[pick_ith]),
                 block_tables={0: [1]},
             ))
+        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
 
-    _, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
+    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
+                                                     prompt_lens)
     sampler_output = sampler(embedding=None,
                              hidden_states=input_tensor,
-                             input_metadata=input_metadata)
+                             sampling_metadata=sampling_metadata)
     for _, sequence_output in enumerate(sampler_output):
         for idx, nth_output in enumerate(sequence_output.samples):
             assert nth_output.output_token == idx
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index e0e381b369e4..2209c994e2b8 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -25,7 +25,10 @@ class ModelRunner:
         self.parallel_config = parallel_config
         self.scheduler_config = scheduler_config
 
-        self.sliding_window = model_config.get_sliding_window()
+        # model_config can be None in tests/samplers/test_sampler.py.
+        # FIXME(woosuk): This is a hack to make the tests work. Refactor this.
+        self.sliding_window = (model_config.get_sliding_window()
+                               if model_config is not None else None)
 
         self.model = None
         self.block_size = None  # Set after initial profiling.
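For orientation, every test in this patch converges on the same flow: collect one prompt length per sequence group, build sampling metadata with `ModelRunner._prepare_sample`, and pass it to the sampler where the `InputMetadata` from the removed `Worker._prepare_inputs` call used to go. The sketch below is not part of the patch; it reuses the file's own `_prepare_test` helper and the signatures visible in the hunks above, and the `request_id`, `is_prompt`, and `SequenceData([1, 2, 3])` values are illustrative assumptions rather than lines from the diff.

```python
# Minimal sketch of the post-patch test flow (assumptions noted inline).
batch_size = 4
input_tensor, fake_logits, sampler, model_runner = _prepare_test(batch_size)

seq_group_metadata_list = []
prompt_lens = []
for i in range(batch_size):
    seq_group_metadata_list.append(
        SequenceGroupMetadata(
            request_id=f"test_{i}",  # illustrative id
            is_prompt=True,  # illustrative; prompt-phase group
            seq_data={0: SequenceData([1, 2, 3])},  # illustrative tokens
            sampling_params=SamplingParams(temperature=0),
            block_tables={0: [1]},
        ))
    # One prompt length per group, consumed by _prepare_sample below.
    prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

# The sampling metadata now carries what the sampler previously read
# from the InputMetadata built by Worker._prepare_inputs.
sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
                                                 prompt_lens)
sampler_output = sampler(embedding=None,
                         hidden_states=input_tensor,
                         sampling_metadata=sampling_metadata)
```

This split is also why `ModelRunner(None, None, None)` suffices in `_prepare_test`: `_prepare_sample` does not touch the model or block tables, so the only constructor dependency the tests hit is the `sliding_window` lookup, which the second file guards against a `None` `model_config`.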