[Tests] conftest: Extending VllmRunner and HfRunner to accept token_ids as input (#26295)

Signed-off-by: Yannick Schnider <yannick.schnider1@ibm.com>
Signed-off-by: Yannick Schnider <Yannick.Schnider1@ibm.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
parent 4727a8afa7
commit 6431be808f
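
This change lets the shared test runners take prompts directly as token ids. A minimal sketch of the usage it enables, modelled on the rewritten test at the end of this diff; the model name, prompt length, and token counts below are illustrative placeholders, not part of the commit:

# Sketch of a test using the extended runners; assumes it lives under tests/
# where the vllm_runner / hf_runner fixtures from tests/conftest.py are available.
from tests.conftest import HfRunner, VllmRunner


def test_token_id_prompts(
    vllm_runner: type[VllmRunner],
    hf_runner: type[HfRunner],
) -> None:
    prompt_len, max_tokens = 16, 8
    prompt_ids = [[43] * prompt_len]  # one prompt given directly as token ids

    # Placeholder model; any small causal LM works the same way here.
    with vllm_runner(model_name="facebook/opt-125m", max_num_seqs=1) as vllm_model:
        vllm_outputs = vllm_model.generate_greedy(prompt_ids, max_tokens)

    with hf_runner(model_name="facebook/opt-125m") as hf_model:
        hf_outputs = hf_model.generate_greedy(prompt_ids, max_tokens)

    # Both runners return prompt + generated tokens; check only the new tokens.
    assert len(vllm_outputs[0][0][prompt_len:]) == max_tokens
    assert len(hf_outputs[0][0][prompt_len:]) == max_tokens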
tests/conftest.py

@@ -57,7 +57,7 @@ from vllm.multimodal.utils import fetch_image
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import BeamSearchParams
 from vllm.transformers_utils.utils import maybe_model_redirect
-from vllm.utils import set_default_torch_num_threads
+from vllm.utils import is_list_of, set_default_torch_num_threads

 logger = init_logger(__name__)

@@ -406,11 +406,11 @@ class HfRunner:

     def get_inputs(
         self,
-        prompts: list[str],
+        prompts: Union[list[str], list[list[int]]],
         images: Optional[PromptImageInput] = None,
         videos: Optional[PromptVideoInput] = None,
         audios: Optional[PromptAudioInput] = None,
-    ) -> list[Union[BatchFeature, BatchEncoding]]:
+    ) -> list[Union[BatchFeature, BatchEncoding, dict[str, torch.Tensor]]]:
         if images is not None:
             assert len(prompts) == len(images)

@@ -420,31 +420,48 @@ class HfRunner:
         if audios is not None:
             assert len(prompts) == len(audios)

-        all_inputs: list[Union[BatchFeature, BatchEncoding]] = []
+        all_inputs: list[
+            Union[BatchFeature, BatchEncoding, dict[str, torch.Tensor]]
+        ] = []
         for i, prompt in enumerate(prompts):
-            processor_kwargs: dict[str, Any] = {
-                "text": prompt,
-                "return_tensors": "pt",
-            }
-            if images is not None and (image := images[i]) is not None:
-                processor_kwargs["images"] = image
-            if videos is not None and (video := videos[i]) is not None:
-                processor_kwargs["videos"] = video
-            if audios is not None and (audio_inputs := audios[i]) is not None:
-                # HACK - not all processors take sampling_rate; we should
-                # clean this up in the future.
-                if len(audio_inputs) == 2:
-                    audio, sr = audio_inputs
-                    processor_kwargs["audio"] = audio
-                    processor_kwargs["sampling_rate"] = sr
-                else:
-                    processor_kwargs["audio"] = audio_inputs
-
-            inputs = self.processor(**processor_kwargs)
-            if isinstance(inputs, BatchFeature):
-                inputs = inputs.to(dtype=self.dtype)
-
-            all_inputs.append(inputs)
+            if isinstance(prompt, str):
+                processor_kwargs: dict[str, Any] = {
+                    "text": prompt,
+                    "return_tensors": "pt",
+                }
+                if images is not None and (image := images[i]) is not None:
+                    processor_kwargs["images"] = image
+                if videos is not None and (video := videos[i]) is not None:
+                    processor_kwargs["videos"] = video
+                if audios is not None and (audio_inputs := audios[i]) is not None:
+                    # HACK - not all processors take sampling_rate; we should
+                    # clean this up in the future.
+                    if len(audio_inputs) == 2:
+                        audio, sr = audio_inputs
+                        processor_kwargs["audio"] = audio
+                        processor_kwargs["sampling_rate"] = sr
+                    else:
+                        processor_kwargs["audio"] = audio_inputs
+
+                inputs = self.processor(**processor_kwargs)
+                if isinstance(inputs, BatchFeature):
+                    inputs = inputs.to(dtype=self.dtype)
+
+                all_inputs.append(inputs)
+            else:
+                # check that prompt is (batched) list of integers (token ids)
+                if not is_list_of(prompt, typ=int, check="all"):
+                    raise ValueError(
+                        "Prompt must be a list of ints corresponding to the prompt token ids."
+                    )
+                # check that no multimodal input is provided
+                if images or videos or audios:
+                    raise ValueError(
+                        "When providing prompt token ids multimodal inputs are not supported."
+                    )
+                input_dict = {
+                    "input_ids": torch.tensor(prompt, dtype=torch.long).unsqueeze(0),
+                }
+                all_inputs.append(input_dict)

         return all_inputs

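For a token-id prompt, the new branch above skips the HF processor entirely and builds a plain dict. A tiny standalone illustration of what that branch produces per prompt (the token values here are arbitrary):

import torch

# Arbitrary token ids standing in for an already-tokenized prompt.
prompt = [43, 17, 9]

# Mirrors the new else-branch: a dict holding a batched (1, seq_len) LongTensor.
input_dict = {"input_ids": torch.tensor(prompt, dtype=torch.long).unsqueeze(0)}

assert input_dict["input_ids"].shape == (1, len(prompt))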
@@ -477,7 +494,7 @@ class HfRunner:

     def generate(
         self,
-        prompts: list[str],
+        prompts: Union[list[str], list[list[int]]],
         images: Optional[PromptImageInput] = None,
         videos: Optional[PromptVideoInput] = None,
         audios: Optional[PromptAudioInput] = None,
@@ -505,7 +522,7 @@ class HfRunner:

     def generate_greedy(
         self,
-        prompts: list[str],
+        prompts: Union[list[str], list[list[int]]],
         max_tokens: int,
         images: Optional[PromptImageInput] = None,
         videos: Optional[PromptVideoInput] = None,
@@ -807,7 +824,7 @@ class VllmRunner:

     def generate(
         self,
-        prompts: Union[list[str], list[torch.Tensor]],
+        prompts: Union[list[str], list[torch.Tensor], list[list[int]]],
         sampling_params: SamplingParams,
         images: Optional[PromptImageInput] = None,
         videos: Optional[PromptVideoInput] = None,
@@ -877,7 +894,7 @@ class VllmRunner:

     def generate_greedy(
         self,
-        prompts: Union[list[str], list[torch.Tensor]],
+        prompts: Union[list[str], list[torch.Tensor], list[list[int]]],
         max_tokens: int,
         images: Optional[PromptImageInput] = None,
         videos: Optional[PromptVideoInput] = None,
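On the vLLM side, the hunks above only widen the accepted prompt types; this diff does not show how VllmRunner converts token-id lists internally. One plausible path, mirroring what the old test (removed below) did directly against LLM, is to wrap each list in a TokensPrompt:

from vllm.inputs import TokensPrompt

# Wrap raw token ids so the engine treats them as an already-tokenized prompt.
prompt_ids = [[43] * 16]
token_prompts = [TokensPrompt(prompt_token_ids=ids) for ids in prompt_ids]
# llm.generate(token_prompts, sampling_params) would then decode from these ids.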
(second file: the max-context-length test)

@@ -23,13 +23,10 @@ the 1st will be sampled after the prefill and the 2nd after the first decode
 """

 import pytest
-import torch
-from transformers import AutoModelForCausalLM

+from tests.conftest import HfRunner, VllmRunner
 from tests.models.utils import check_outputs_equal
 from tests.utils import create_new_process_for_each_test
-from vllm import LLM, SamplingParams
-from vllm.inputs import TokensPrompt


 @create_new_process_for_each_test()
@@ -43,6 +40,8 @@ from vllm.inputs import TokensPrompt
 )
 def test_max_context_length(
     model: str,
+    vllm_runner: type[VllmRunner],
+    hf_runner: type[HfRunner],
     prompt_len: int,
     max_tokens: int,
 ) -> None:
@@ -57,42 +56,26 @@ def test_max_context_length(
     # Construct a prompt of size prompt_len
     prompt_ids = [[43] * prompt_len]

-    # Generate max_tokens new tokens deterministically.
-    sampling_params = [
-        SamplingParams(max_tokens=max_tokens, temperature=0.0, ignore_eos=True)
-    ]
-
     # --- vLLM generation ---
-    llm = LLM(
-        model=model,
-        tokenizer=model,
+    with vllm_runner(
+        model_name=model,
+        tokenizer_name=model,
         max_model_len=2048,
         max_num_seqs=1,
         tensor_parallel_size=1,
-    )
-    vllm_token_prompts = [TokensPrompt(prompt_token_ids=prompt_ids[0])]
-    vllm_results = llm.generate(vllm_token_prompts, sampling_params)
-
-    vllm_output_ids = vllm_results[0].outputs[0].token_ids
+    ) as vllm_model:
+        # Generate max_tokens new tokens deterministically.
+        vllm_outputs = vllm_model.generate_greedy(prompt_ids, max_tokens)

     # --- HuggingFace generation ---
-    with torch.no_grad():
-        hf_model = AutoModelForCausalLM.from_pretrained(model)
+    with hf_runner(
+        model_name=model,
+    ) as hf_model:
+        hf_outputs = hf_model.generate_greedy(prompt_ids, max_tokens)

-        # HF expects a tensor of input ids shaped (batch, seq_len).
-        hf_input_tokens = torch.tensor(prompt_ids[0]).unsqueeze(0)
-
-        # Generate max_tokens new tokens deterministically.
-        hf_generated = hf_model.generate(
-            hf_input_tokens,
-            do_sample=False,
-            min_new_tokens=max_tokens,
-            max_new_tokens=max_tokens,
-        )
-
-        # HF returns the prompt + generated tokens. Slice off the prompt.
-        hf_output_ids = hf_generated.cpu().tolist()[0][len(prompt_ids[0]) :]
+    # vLLM and HF runners return prompt + generated tokens. Slice off the prompt.
+    vllm_output_ids = vllm_outputs[0][0][prompt_len:]
+    hf_output_ids = hf_outputs[0][0][prompt_len:]

     # check that exactly max_tokens tokens were generated with vLLM and HF
     assert len(vllm_output_ids) == len(hf_output_ids) == max_tokens