mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-26 03:43:06 +08:00
Fix broken worker test (#1900)
This commit is contained in:
parent
9b294976a2
commit
cd3aa153a4
@ -2,18 +2,19 @@ import random
|
||||
import torch
|
||||
|
||||
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
|
||||
from vllm.worker.worker import Worker
|
||||
from vllm.worker.model_runner import ModelRunner
|
||||
|
||||
|
||||
def test_worker_prepare_inputs_for_prompt():
|
||||
worker = Worker(None, None, None)
|
||||
worker.block_size = 16
|
||||
def test_prepare_prompt():
|
||||
model_runner = ModelRunner(None, None, None)
|
||||
model_runner.set_block_size(16)
|
||||
|
||||
batch_size = random.randint(1, 256)
|
||||
prompt_lens = []
|
||||
seq_group_metadata_list = []
|
||||
for i in range(batch_size):
|
||||
# make sure all tokens fit into one block
|
||||
prompt_len = i % (worker.block_size - 1) + 1
|
||||
prompt_len = i % (model_runner.block_size - 1) + 1
|
||||
prompt_lens.append(prompt_len)
|
||||
seq_data = list(range(prompt_len))
|
||||
seq_group_metadata_list.append(
|
||||
@ -24,6 +25,7 @@ def test_worker_prepare_inputs_for_prompt():
|
||||
sampling_params=SamplingParams(temperature=0),
|
||||
block_tables={0: [1]},
|
||||
))
|
||||
|
||||
expected_selected_token_indices = []
|
||||
selected_token_start_idx = 0
|
||||
max_seq_len = max(prompt_lens)
|
||||
@ -31,12 +33,15 @@ def test_worker_prepare_inputs_for_prompt():
|
||||
expected_selected_token_indices.append(selected_token_start_idx +
|
||||
prompt_len - 1)
|
||||
selected_token_start_idx += max_seq_len
|
||||
input_tokens, input_positions, input_metadata = worker._prepare_inputs(
|
||||
input_tokens, input_positions, _ = model_runner._prepare_prompt(
|
||||
seq_group_metadata_list)
|
||||
assert input_tokens.shape == input_positions.shape == (batch_size,
|
||||
max_seq_len)
|
||||
sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
|
||||
prompt_lens)
|
||||
assert input_tokens.shape == (batch_size, max_seq_len)
|
||||
assert input_positions.shape == (batch_size, max_seq_len)
|
||||
torch.testing.assert_close(input_tokens, input_positions)
|
||||
actual = input_metadata.selected_token_indices
|
||||
|
||||
actual = sampling_metadata.selected_token_indices
|
||||
expected = torch.tensor(expected_selected_token_indices,
|
||||
device=actual.device,
|
||||
dtype=actual.dtype)
|
||||
Loading…
x
Reference in New Issue
Block a user