Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-10 07:04:53 +08:00)

[CI/Build] Update vision tests (#5307)

Parent: 7b0a0dfb22
Commit: 89c920785f
@@ -93,14 +93,13 @@ steps:
- label: Models Test
  #mirror_hardwares: [amd]
  commands:
    - bash ../.buildkite/download-images.sh
    - pytest -v -s models --ignore=models/test_llava.py
    - pytest -v -s models -m \"not llava\"

- label: Llava Test
  mirror_hardwares: [amd]
  commands:
    - bash ../.buildkite/download-images.sh
    - pytest -v -s models/test_llava.py
    - pytest -v -s models -m llava

- label: Prefix Caching Test
  mirror_hardwares: [amd]
@@ -71,4 +71,5 @@ markers = [
    "skip_global_cleanup",
    "llm: run tests for vLLM API only",
    "openai: run tests for OpenAI API only",
    "llava: run tests for LLaVA models only",
]
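Note on the marker split: registering `llava` here is what lets the Buildkite steps above select tests with `-m llava` / `-m \"not llava\"` instead of `--ignore`-ing a file. A minimal sketch of how module-level and per-test marking behaves (the test functions below are hypothetical, not part of this commit):

import pytest

# Applying the marker at module level (as tests/models/test_llava.py does
# with `pytestmark` further down) tags every test in the module.
pytestmark = pytest.mark.llava


def test_tagged_by_module_marker():
    # Collected by `pytest -m llava`, skipped by `pytest -m "not llava"`.
    assert True


@pytest.mark.llava
def test_tagged_individually():
    # Function-level marking works the same way for a single test.
    assert True

Registering the marker in pyproject.toml keeps pytest from warning about (or, under `--strict-markers`, rejecting) the otherwise unknown mark.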
@@ -29,24 +29,19 @@ _LONG_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "summary.txt")]

# Multi modal related
# You can use `.buildkite/download-images.sh` to download the assets
_PIXEL_VALUES_FILES = [
PIXEL_VALUES_FILES = [
    os.path.join(_TEST_DIR, "images", filename) for filename in
    ["stop_sign_pixel_values.pt", "cherry_blossom_pixel_values.pt"]
]
_IMAGE_FEATURES_FILES = [
IMAGE_FEATURES_FILES = [
    os.path.join(_TEST_DIR, "images", filename) for filename in
    ["stop_sign_image_features.pt", "cherry_blossom_image_features.pt"]
]
_IMAGE_FILES = [
IMAGE_FILES = [
    os.path.join(_TEST_DIR, "images", filename)
    for filename in ["stop_sign.jpg", "cherry_blossom.jpg"]
]
_IMAGE_PROMPTS = [
    "<image>\nUSER: What's the content of the image?\nASSISTANT:",
    "<image>\nUSER: What is the season?\nASSISTANT:"
]
assert len(_PIXEL_VALUES_FILES) == len(_IMAGE_FEATURES_FILES) == len(
    _IMAGE_FILES) == len(_IMAGE_PROMPTS)
assert len(PIXEL_VALUES_FILES) == len(IMAGE_FEATURES_FILES) == len(IMAGE_FILES)


def _read_prompts(filename: str) -> List[str]:
@@ -84,14 +79,9 @@ def cleanup_fixture(should_do_global_cleanup_after_test: bool):
    cleanup()


@pytest.fixture(scope="session")
def hf_image_prompts() -> List[str]:
    return _IMAGE_PROMPTS


@pytest.fixture(scope="session")
def hf_images() -> List[Image.Image]:
    return [Image.open(filename) for filename in _IMAGE_FILES]
    return [Image.open(filename) for filename in IMAGE_FILES]


@pytest.fixture()
@@ -101,26 +91,17 @@ def vllm_images(request) -> List[MultiModalData]:
            VisionLanguageConfig.ImageInputType.IMAGE_FEATURES):
        return [
            ImageFeatureData(torch.load(filename))
            for filename in _IMAGE_FEATURES_FILES
            for filename in IMAGE_FEATURES_FILES
        ]
    else:
        return [
            ImagePixelData(Image.open(filename)) for filename in _IMAGE_FILES
            ImagePixelData(Image.open(filename)) for filename in IMAGE_FILES
        ]


@pytest.fixture()
def vllm_image_tensors(request) -> List[torch.Tensor]:
    return [torch.load(filename) for filename in _PIXEL_VALUES_FILES]


@pytest.fixture()
def vllm_image_prompts(request) -> List[str]:
    vision_language_config = request.getfixturevalue("model_and_config")[1]
    return [
        "<image>" * (vision_language_config.image_feature_size - 1) + p
        for p in _IMAGE_PROMPTS
    ]
    return [torch.load(filename) for filename in PIXEL_VALUES_FILES]


@pytest.fixture
@@ -1,14 +1,22 @@
import gc
from dataclasses import fields
from enum import Enum
from typing import Any, Dict, List, Tuple
from typing import List, Tuple

import pytest
import torch
from transformers import AutoTokenizer

from vllm.config import VisionLanguageConfig

from ..conftest import IMAGE_FILES

pytestmark = pytest.mark.llava

# The image token is placed before "user" on purpose so that the test can pass
HF_IMAGE_PROMPTS = [
    "<image>\nUSER: What's the content of the image?\nASSISTANT:",
    "<image>\nUSER: What is the season?\nASSISTANT:",
]

assert len(HF_IMAGE_PROMPTS) == len(IMAGE_FILES)


def iter_llava_configs(model_name: str):
    image_hw_to_feature_size = {
@@ -36,53 +44,35 @@ model_and_vl_config = [
]


def as_dict(vlm_config: VisionLanguageConfig) -> Dict[str, Any]:
    """Flatten vision language config to pure args.

    Compatible with what llm entrypoint expects.
    """
    result = {}
    for field in fields(vlm_config):
        value = getattr(vlm_config, field.name)
        if isinstance(value, Enum):
            result[field.name] = value.name.lower()
        elif isinstance(value, tuple):
            result[field.name] = ",".join([str(item) for item in value])
        else:
            result[field.name] = value

    result["disable_image_processor"] = vlm_config.image_processor is None

    return result


def sanitize_vllm_output(vllm_output: Tuple[List[int], str],
                         vision_language_config: VisionLanguageConfig,
                         model_id: str):
def vllm_to_hf_output(vllm_output: Tuple[List[int], str],
                      vlm_config: VisionLanguageConfig, model_id: str):
    """Sanitize vllm output to be comparable with hf output.
    The function reduces `input_ids` from 1, 32000, 32000, ..., 32000,
    x1, x2, x3 ... to 1, 32000, x1, x2, x3 ...
    It also reduces `output_str` from "<image><image>bla" to "bla".
    """
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    image_token_str = tokenizer.decode(vision_language_config.image_token_id)
    image_token_str_len = len(image_token_str)
    input_ids, output_str = vllm_output
    sanitized_input_ids = input_ids[0:2] + input_ids[2 + vision_language_config
                                                     .image_feature_size - 1:]
    sanitzied_output_str = output_str[vision_language_config.
                                      image_feature_size *
                                      image_token_str_len:]
    return sanitized_input_ids, sanitzied_output_str
    image_token_id = vlm_config.image_token_id

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    image_token_str = tokenizer.decode(image_token_id)

    hf_input_ids = [
        input_id for idx, input_id in enumerate(input_ids)
        if input_id != image_token_id or input_ids[idx - 1] != image_token_id
    ]
    hf_output_str = output_str \
        .replace(image_token_str * vlm_config.image_feature_size, "")

    return hf_input_ids, hf_output_str
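For clarity, the reduction the new helper performs can be traced on toy values; the token ids, feature size, and decoded image-token string below are made up for illustration and are not taken from this commit:

# Standalone sketch of the filtering done by vllm_to_hf_output.
image_token_id = 32000
image_feature_size = 3
image_token_str = "<image>"

input_ids = [1, 32000, 32000, 32000, 450, 3815]
output_str = "<image><image><image>A stop sign"

# Keep a token unless it is an image token that directly follows another
# image token, i.e. collapse each run of image tokens to a single one.
hf_input_ids = [
    tok for idx, tok in enumerate(input_ids)
    if tok != image_token_id or input_ids[idx - 1] != image_token_id
]
# Drop the expanded run of image tokens from the decoded string.
hf_output_str = output_str.replace(image_token_str * image_feature_size, "")

assert hf_input_ids == [1, 32000, 450, 3815]
assert hf_output_str == "A stop sign"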
@pytest.mark.parametrize("worker_use_ray", [False])
# TODO: Add test for `tensor_parallel_size` [ref: PR #3883]
@pytest.mark.parametrize("model_and_config", model_and_vl_config)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
def test_models(hf_runner, vllm_runner, hf_image_prompts, hf_images,
                vllm_image_prompts, vllm_images, model_and_config, dtype: str,
                max_tokens: int, worker_use_ray: bool) -> None:
def test_models(hf_runner, vllm_runner, hf_images, vllm_images,
                model_and_config, dtype: str, max_tokens: int) -> None:
    """Inference result should be the same between hf and vllm.

    All the image fixtures for the test is under tests/images.
@@ -92,36 +82,33 @@ def test_models(hf_runner, vllm_runner, hf_image_prompts, hf_images,
    Note, the text input is also adjusted to abide by vllm contract.
    The text output is sanitized to be able to compare with hf.
    """
    model_id, vision_language_config = model_and_config
    model_id, vlm_config = model_and_config

    hf_model = hf_runner(model_id, dtype=dtype, is_vision_model=True)
    hf_outputs = hf_model.generate_greedy(hf_image_prompts,
    hf_outputs = hf_model.generate_greedy(HF_IMAGE_PROMPTS,
                                          max_tokens,
                                          images=hf_images)
    del hf_model

    vllm_image_prompts = [
        p.replace("<image>", "<image>" * vlm_config.image_feature_size)
        for p in HF_IMAGE_PROMPTS
    ]

    vllm_model = vllm_runner(model_id,
                             dtype=dtype,
                             worker_use_ray=worker_use_ray,
                             enforce_eager=True,
                             **as_dict(vision_language_config))
                             **vlm_config.as_cli_args_dict())
    vllm_outputs = vllm_model.generate_greedy(vllm_image_prompts,
                                              max_tokens,
                                              images=vllm_images)
    del vllm_model

    gc.collect()
    torch.cuda.empty_cache()

    for i in range(len(hf_image_prompts)):
    for i in range(len(HF_IMAGE_PROMPTS)):
        hf_output_ids, hf_output_str = hf_outputs[i]
        vllm_output_ids, vllm_output_str = sanitize_vllm_output(
            vllm_outputs[i], vision_language_config, model_id)
        vllm_output_ids, vllm_output_str = vllm_to_hf_output(
            vllm_outputs[i], vlm_config, model_id)
        assert hf_output_str == vllm_output_str, (
            f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
        assert hf_output_ids == vllm_output_ids, (
            f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")


# TODO: Add test for `tensor_parallel_size` [ref: PR #3883]
# (Requires multiple GPUs)
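A quick note on the prompt handling moved into the test body: expanding the `<image>` placeholder in place produces the same string the removed `vllm_image_prompts` fixture built by prepending `image_feature_size - 1` extra tokens, as long as the placeholder appears exactly once at the start of the prompt. The feature size below is made up for illustration:

# Illustration only; not part of the commit.
image_feature_size = 4
hf_prompt = "<image>\nUSER: What's the content of the image?\nASSISTANT:"

# Old fixture: prepend (feature_size - 1) extra tokens to a prompt that
# already starts with one "<image>".
old_style = "<image>" * (image_feature_size - 1) + hf_prompt
# New in-test expansion: replace the single placeholder in place.
new_style = hf_prompt.replace("<image>", "<image>" * image_feature_size)

assert old_style == new_style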
@@ -1,7 +1,8 @@
import enum
import json
from dataclasses import dataclass, field, fields
from typing import TYPE_CHECKING, ClassVar, List, Optional, Tuple, Union
from typing import (TYPE_CHECKING, Any, ClassVar, Dict, List, Optional, Tuple,
                    Union)

import torch
from transformers import PretrainedConfig
@@ -1114,6 +1115,25 @@ class VisionLanguageConfig:
                f"Expecting to choose from "
                f"{[x.name for x in cls.ImageInputType]}.") from e

    def as_cli_args_dict(self) -> Dict[str, Any]:
        """Flatten vision language config to pure args.

        Compatible with what llm entrypoint expects.
        """
        result: Dict[str, Any] = {}
        for f in fields(self):
            value = getattr(self, f.name)
            if isinstance(value, enum.Enum):
                result[f.name] = value.name.lower()
            elif isinstance(value, tuple):
                result[f.name] = ",".join([str(item) for item in value])
            else:
                result[f.name] = value

        result["disable_image_processor"] = self.image_processor is None

        return result
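The flattening rules of `as_cli_args_dict` (enum values become lower-cased names, tuples become comma-joined strings, and `disable_image_processor` is derived from `image_processor`) can be sketched on a toy dataclass. This sketch mimics the method rather than calling vLLM; the class and field values below are invented for illustration:

import enum
from dataclasses import dataclass, fields
from typing import Any, Dict, Optional, Tuple


class ImageInputType(enum.Enum):
    PIXEL_VALUES = enum.auto()
    IMAGE_FEATURES = enum.auto()


@dataclass
class ToyVLConfig:
    image_input_type: ImageInputType
    image_input_shape: Tuple[int, ...]
    image_feature_size: int
    image_processor: Optional[str]


def as_cli_args_dict(cfg: ToyVLConfig) -> Dict[str, Any]:
    # Same flattening rules as the new VisionLanguageConfig.as_cli_args_dict.
    result: Dict[str, Any] = {}
    for f in fields(cfg):
        value = getattr(cfg, f.name)
        if isinstance(value, enum.Enum):
            result[f.name] = value.name.lower()
        elif isinstance(value, tuple):
            result[f.name] = ",".join(str(item) for item in value)
        else:
            result[f.name] = value
    result["disable_image_processor"] = cfg.image_processor is None
    return result


cfg = ToyVLConfig(ImageInputType.PIXEL_VALUES, (1, 3, 336, 336), 576, None)
print(as_cli_args_dict(cfg))
# Expected contents:
#   image_input_type -> "pixel_values"
#   image_input_shape -> "1,3,336,336"
#   image_feature_size -> 576
#   image_processor -> None
#   disable_image_processor -> True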


_STR_DTYPE_TO_TORCH_DTYPE = {
    "half": torch.float16,
@@ -75,6 +75,14 @@ class ImagePixelData(MultiModalData):

        self.image = image

    def __repr__(self) -> str:
        image = self.image
        if isinstance(image, Image.Image):
            return f"{type(self).__name__}(image={image})"

        return (f"{type(self).__name__}(image=torch.Tensor(shape="
                f"{image.shape}, dtype={image.dtype}))")


class ImagePixelPlugin(MultiModalPlugin[ImagePixelData]):
@@ -96,10 +104,10 @@ class ImagePixelPlugin(MultiModalPlugin[ImagePixelData]):
            self, data: ImagePixelData, model_config: ModelConfig,
            vlm_config: VisionLanguageConfig) -> Dict[str, torch.Tensor]:
        image = data.image
        image_processor = self._get_hf_image_processor(model_config,
                                                       vlm_config)

        if isinstance(image, Image.Image):
            image_processor = self._get_hf_image_processor(
                model_config, vlm_config)
            if image_processor is None:
                raise RuntimeError("No HuggingFace processor is available"
                                   "to process the image object")
@@ -127,6 +135,12 @@ class ImageFeatureData(MultiModalData):
    def __init__(self, image_features: torch.Tensor) -> None:
        self.image_features = image_features

    def __repr__(self) -> str:
        image_features = self.image_features

        return (f"{type(self).__name__}(image_features=torch.Tensor(shape="
                f"{image_features.shape}, dtype={image_features.dtype}))")


class ImageFeaturePlugin(MultiModalPlugin[ImageFeatureData]):