mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 11:26:15 +08:00
316 lines
10 KiB
Python
316 lines
10 KiB
Python
from typing import List, Optional, Tuple, Type, overload
|
|
|
|
import pytest
|
|
from transformers import (AutoConfig, AutoModelForVision2Seq, AutoTokenizer,
|
|
BatchEncoding)
|
|
|
|
from vllm.multimodal.utils import rescale_image_size
|
|
from vllm.sequence import SampleLogprobs
|
|
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
|
|
|
|
from ..conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner,
|
|
_ImageAssets)
|
|
from .utils import check_logprobs_close
|
|
|
|
# Apply the `vlm` pytest marker to every test in this module.
pytestmark = pytest.mark.vlm

# Forwarded to the vLLM runner as the per-prompt image limit.
_LIMIT_IMAGE_PER_PROMPT = 4

# HF-format prompts, one for each of the two image assets named below.
HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts(
    dict(
        stop_sign=(
            "USER: <image>\nWhat's the content of the image?\nASSISTANT:"),
        cherry_blossom="USER: <image>\nWhat is the season?\nASSISTANT:",
    ))

# Checkpoints the parametrized tests below are run against.
models = [
    "llava-hf/llava-1.5-7b-hf",
    # TODO: Get this model to produce meaningful output in vLLM
    # "TIGER-Lab/Mantis-8B-siglip-llama3",
]
|
|
|
|
|
|
def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
                                         Optional[SampleLogprobs]],
                      model: str):
    """Sanitize vllm output to be comparable with hf output.

    Args:
        vllm_output: ``(output token ids, output text, logprobs)`` triple
            for a single generated sequence.
        model: Model name or path; used to load the HF config (for the
            image token id) and the tokenizer (for the EOS token).

    Returns:
        ``(hf_output_ids, hf_output_str, out_logprobs)`` where every run of
        consecutive image tokens is collapsed to a single token, the leading
        space of the text is stripped, and the decoded EOS token is appended
        to the text when the last remaining token id is EOS. The logprobs
        are returned unchanged.
    """
    output_ids, output_str, out_logprobs = vllm_output

    config = AutoConfig.from_pretrained(model)
    image_token_id = config.image_token_index

    tokenizer = AutoTokenizer.from_pretrained(model)
    eos_token_id = tokenizer.eos_token_id

    # Keep only the first token of each run of image tokens. The explicit
    # `idx == 0` check is required: `output_ids[idx - 1]` would otherwise
    # wrap around to the *last* token and wrongly drop a leading image token
    # whenever the sequence also ends with an image token.
    hf_output_ids = [
        token_id for idx, token_id in enumerate(output_ids)
        if (token_id != image_token_id or idx == 0
            or output_ids[idx - 1] != image_token_id)
    ]

    # The text is expected to start with a space, which is stripped for the
    # comparison. `startswith` reports an empty string as a failed assertion
    # rather than an IndexError.
    assert output_str.startswith(" ")
    hf_output_str = output_str[1:]

    # Guard against an empty id list before peeking at the last token.
    if hf_output_ids and hf_output_ids[-1] == eos_token_id:
        hf_output_str = hf_output_str + tokenizer.decode(eos_token_id)

    return hf_output_ids, hf_output_str, out_logprobs
|
|
|
|
|
|
# Variant 1: every image is rescaled by each factor in `size_factors`.
@overload
def run_test(
    hf_runner: Type[HfRunner],
    vllm_runner: Type[VllmRunner],
    image_assets: _ImageAssets,
    model: str,
    *,
    size_factors: List[float],
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
    tensor_parallel_size: int,
    distributed_executor_backend: Optional[str] = None,
):
    ...


# Variant 2: every image is resized to each explicit size in `sizes`.
@overload
def run_test(
    hf_runner: Type[HfRunner],
    vllm_runner: Type[VllmRunner],
    image_assets: _ImageAssets,
    model: str,
    *,
    sizes: List[Tuple[int, int]],
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
    tensor_parallel_size: int,
    distributed_executor_backend: Optional[str] = None,
):
    ...


def run_test(
    hf_runner: Type[HfRunner],
    vllm_runner: Type[VllmRunner],
    image_assets: _ImageAssets,
    model: str,
    *,
    size_factors: Optional[List[float]] = None,
    sizes: Optional[List[Tuple[int, int]]] = None,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
    tensor_parallel_size: int,
    distributed_executor_backend: Optional[str] = None,
):
    """Build per-image inputs at several sizes and delegate to `_run_test`.

    Each image asset is paired with its prompt from `HF_IMAGE_PROMPTS` and
    expanded into one `(prompts, images)` batch: the prompt is repeated once
    per requested size and the image is resized accordingly.

    Args:
        hf_runner: HF runner class, forwarded to `_run_test`.
        vllm_runner: vLLM runner class, forwarded to `_run_test`.
        image_assets: Assets whose `pil_image` attributes are used.
        model: Model name, forwarded to `_run_test`.
        size_factors: Scale factors applied with `rescale_image_size`.
            Takes precedence over `sizes` when both are given.
        sizes: Explicit sizes applied with `image.resize`.
        dtype, max_tokens, num_logprobs, tensor_parallel_size,
        distributed_executor_backend: Forwarded to `_run_test` unchanged.

    Raises:
        ValueError: If neither `size_factors` nor `sizes` is provided.
    """
    images = [asset.pil_image for asset in image_assets]

    inputs_per_image = []
    if size_factors is not None:
        for image, prompt in zip(images, HF_IMAGE_PROMPTS):
            scaled = [
                rescale_image_size(image, factor) for factor in size_factors
            ]
            inputs_per_image.append(([prompt] * len(size_factors), scaled))
    elif sizes is not None:
        for image, prompt in zip(images, HF_IMAGE_PROMPTS):
            resized = [image.resize(size) for size in sizes]
            inputs_per_image.append(([prompt] * len(sizes), resized))
    else:
        raise ValueError("You must provide either `size_factors` or `sizes`")

    _run_test(
        hf_runner,
        vllm_runner,
        inputs_per_image,
        model,
        dtype=dtype,
        max_tokens=max_tokens,
        num_logprobs=num_logprobs,
        tensor_parallel_size=tensor_parallel_size,
        distributed_executor_backend=distributed_executor_backend,
    )
|
|
|
|
|
|
def _run_test(
    hf_runner: Type[HfRunner],
    vllm_runner: Type[VllmRunner],
    inputs: List[Tuple[List[str], PromptImageInput]],
    model: str,
    *,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
    tensor_parallel_size: int,
    distributed_executor_backend: Optional[str] = None,
):
    """Check that greedy outputs of vLLM and HF are close for `inputs`.

    Every `(prompts, images)` pair in `inputs` is run through the vLLM
    runner first and then through the HF runner. The vLLM outputs are
    sanitized with `vllm_to_hf_output` and compared against the HF outputs
    using `check_logprobs_close`.

    Args:
        hf_runner: HF runner class, used as a context manager.
        vllm_runner: vLLM runner class, used as a context manager.
        inputs: Batches of prompts with their matching image inputs.
        model: Model name or path to load in both runners.
        dtype: Dtype string passed to both runners.
        max_tokens: Maximum number of tokens to generate per prompt.
        num_logprobs: Number of logprobs requested per generated token.
        tensor_parallel_size: Passed to the vLLM runner.
        distributed_executor_backend: Passed to the vLLM runner.
    """
    # NOTE: For local use; this isn't tested in CI yet (see TODO above)
    if model.startswith("TIGER-Lab/Mantis"):
        from mantis.models.mllava import MLlavaProcessor

        torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]
        mantis_processor = MLlavaProcessor.from_pretrained(
            model, torch_dtype=torch_dtype)
        assert isinstance(mantis_processor, MLlavaProcessor)
    else:
        mantis_processor = None

    # NOTE: take care of the order. run vLLM first, and then run HF.
    # vLLM needs a fresh new process without cuda initialization.
    # if we run HF first, the cuda initialization will be done and it
    # will hurt multiprocessing backend with fork method (the default method).

    vllm_outputs_per_image = []
    # max_model_len should be greater than image_feature_size
    with vllm_runner(
            model,
            dtype=dtype,
            max_model_len=4096,
            tensor_parallel_size=tensor_parallel_size,
            distributed_executor_backend=distributed_executor_backend,
            enforce_eager=True,
            limit_mm_per_prompt={"image": _LIMIT_IMAGE_PER_PROMPT},
    ) as vllm_model:
        for prompts, images in inputs:
            vllm_outputs_per_image.append(
                vllm_model.generate_greedy_logprobs(
                    prompts,
                    max_tokens,
                    num_logprobs=num_logprobs,
                    images=images,
                ))

    def process(hf_inputs: BatchEncoding):
        # Only the Mantis path casts pixel values to the requested dtype;
        # for every other model the HF inputs pass through untouched.
        if mantis_processor is not None:
            pixel_values = hf_inputs["pixel_values"]
            hf_inputs["pixel_values"] = pixel_values.to(
                torch_dtype)  # type: ignore
        return hf_inputs

    hf_outputs_per_image = []
    with hf_runner(
            model,
            dtype=dtype,
            postprocess_inputs=process,
            auto_cls=AutoModelForVision2Seq,
    ) as hf_model:
        for prompts, images in inputs:
            hf_outputs_per_image.append(
                hf_model.generate_greedy_logprobs_limit(
                    prompts,
                    max_tokens,
                    num_logprobs=num_logprobs,
                    images=images,
                ))

    for hf_outputs, vllm_outputs in zip(hf_outputs_per_image,
                                        vllm_outputs_per_image):
        # TODO: Check whether using original CLIPVisionModel can improve
        # consistency against HF
        sanitized_vllm_outputs = [
            vllm_to_hf_output(vllm_output, model)
            for vllm_output in vllm_outputs
        ]
        check_logprobs_close(
            outputs_0_lst=hf_outputs,
            outputs_1_lst=sanitized_vllm_outputs,
            name_0="hf",
            name_1="vllm",
        )
|
|
|
|
|
|
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
    "size_factors",
    [
        # No image
        [],
        # Single-scale
        [1.0],
        # Single-scale, batched
        [1.0, 1.0, 1.0],
        # Multi-scale
        [0.25, 0.5, 1.0],
    ],
)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
                dtype: str, max_tokens: int, num_logprobs: int) -> None:
    """Compare HF and vLLM on single-image prompts at several scales.

    Delegates to `run_test` with the parametrized `size_factors` and a
    tensor parallel size of 1.
    """
    generation_kwargs = dict(
        dtype=dtype,
        max_tokens=max_tokens,
        num_logprobs=num_logprobs,
    )
    run_test(
        hf_runner,
        vllm_runner,
        image_assets,
        model,
        size_factors=size_factors,
        tensor_parallel_size=1,
        **generation_kwargs,
    )
|
|
|
|
|
|
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models_multiple_image_inputs(hf_runner, vllm_runner, image_assets,
                                      model, dtype, max_tokens,
                                      num_logprobs) -> None:
    """Compare HF and vLLM on prompts that carry one, two or four images.

    A single batch of four prompts is built by hand, each paired with its
    own image input, and handed to `_run_test` with a tensor parallel size
    of 1.
    """
    stop_sign = image_assets[0].pil_image
    cherry_blossom = image_assets[1].pil_image

    prompts = [
        "USER: <image><image>\nDescribe 2 images.\nASSISTANT:",
        "USER: <image><image>\nDescribe 2 images.\nASSISTANT:",
        "USER: <image><image><image><image>\nDescribe 4 images.\nASSISTANT:",  # noqa: E501
        "USER: <image>\nWhat is the season?\nASSISTANT:",
    ]

    two_images = [stop_sign, cherry_blossom]
    # Images with different sizes and aspect-ratios
    two_images_mixed_sizes = [
        rescale_image_size(stop_sign, 0.1),
        stop_sign,
    ]
    four_images = [
        stop_sign,
        rescale_image_size(stop_sign, 0.25),
        cherry_blossom.resize((183, 488)),
        cherry_blossom.resize((488, 183)),
    ]
    images_per_prompt = [
        two_images,
        two_images_mixed_sizes,
        four_images,
        # The last prompt takes a bare image rather than a list.
        cherry_blossom,
    ]

    inputs = [(prompts, images_per_prompt)]

    _run_test(
        hf_runner,
        vllm_runner,
        inputs,
        model,
        dtype=dtype,
        max_tokens=max_tokens,
        num_logprobs=num_logprobs,
        tensor_parallel_size=1,
    )
|
|
|
|
|
|
@pytest.mark.parametrize("model", models)
def test_context_length_too_short(vllm_runner, image_assets, model):
    """Expect a ValueError when `max_model_len` is too small for the prompt.

    The runner is created with `max_model_len=128` and asked to generate
    for the first image prompt; constructing or using it must raise a
    `ValueError` whose message matches "too long to fit into the model".
    """
    pil_images = [asset.pil_image for asset in image_assets]

    with pytest.raises(ValueError, match="too long to fit into the model"):
        runner = vllm_runner(
            model,
            max_model_len=128,  # LLaVA has a feature size of 576
            enforce_eager=True,
        )

        with runner:
            runner.generate_greedy(
                [HF_IMAGE_PROMPTS[0]],
                max_tokens=1,
                images=[pil_images[0]],
            )
|