[CI/Build] Update vision tests (#5307)

Cyrus Leung 2024-06-06 18:17:18 +08:00 committed by GitHub
parent 7b0a0dfb22
commit 89c920785f
6 changed files with 90 additions and 88 deletions

View File

@@ -93,14 +93,13 @@ steps:
- label: Models Test
#mirror_hardwares: [amd]
commands:
- bash ../.buildkite/download-images.sh
- pytest -v -s models --ignore=models/test_llava.py
- pytest -v -s models -m \"not llava\"
- label: Llava Test
mirror_hardwares: [amd]
commands:
- bash ../.buildkite/download-images.sh
- pytest -v -s models/test_llava.py
- pytest -v -s models -m llava
- label: Prefix Caching Test
mirror_hardwares: [amd]

View File

@@ -71,4 +71,5 @@ markers = [
"skip_global_cleanup",
"llm: run tests for vLLM API only",
"openai: run tests for OpenAI API only",
"llava: run tests for LLaVA models only",
]
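
The new `llava` marker, registered in the pytest `markers` list above, is what lets the pipeline split the models suite without `--ignore` flags. A minimal, self-contained sketch of how the pieces fit together (the module and test names here are illustrative):

import pytest

# A module-level pytestmark tags every test collected from this file with
# "llava", just as tests/models/test_llava.py does in the diff further below.
pytestmark = pytest.mark.llava

def test_example():
    assert True

# Selection then happens on the command line, matching the updated pipeline:
#   pytest -v -s models -m llava          -> only llava-marked tests
#   pytest -v -s models -m "not llava"    -> the rest of the models suite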

View File

@@ -29,24 +29,19 @@ _LONG_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "summary.txt")]
# Multi modal related
# You can use `.buildkite/download-images.sh` to download the assets
_PIXEL_VALUES_FILES = [
PIXEL_VALUES_FILES = [
os.path.join(_TEST_DIR, "images", filename) for filename in
["stop_sign_pixel_values.pt", "cherry_blossom_pixel_values.pt"]
]
_IMAGE_FEATURES_FILES = [
IMAGE_FEATURES_FILES = [
os.path.join(_TEST_DIR, "images", filename) for filename in
["stop_sign_image_features.pt", "cherry_blossom_image_features.pt"]
]
_IMAGE_FILES = [
IMAGE_FILES = [
os.path.join(_TEST_DIR, "images", filename)
for filename in ["stop_sign.jpg", "cherry_blossom.jpg"]
]
_IMAGE_PROMPTS = [
"<image>\nUSER: What's the content of the image?\nASSISTANT:",
"<image>\nUSER: What is the season?\nASSISTANT:"
]
assert len(_PIXEL_VALUES_FILES) == len(_IMAGE_FEATURES_FILES) == len(
_IMAGE_FILES) == len(_IMAGE_PROMPTS)
assert len(PIXEL_VALUES_FILES) == len(IMAGE_FEATURES_FILES) == len(IMAGE_FILES)
def _read_prompts(filename: str) -> List[str]:
@@ -84,14 +79,9 @@ def cleanup_fixture(should_do_global_cleanup_after_test: bool):
cleanup()
@pytest.fixture(scope="session")
def hf_image_prompts() -> List[str]:
return _IMAGE_PROMPTS
@pytest.fixture(scope="session")
def hf_images() -> List[Image.Image]:
return [Image.open(filename) for filename in _IMAGE_FILES]
return [Image.open(filename) for filename in IMAGE_FILES]
@pytest.fixture()
@@ -101,26 +91,17 @@ def vllm_images(request) -> List[MultiModalData]:
VisionLanguageConfig.ImageInputType.IMAGE_FEATURES):
return [
ImageFeatureData(torch.load(filename))
for filename in _IMAGE_FEATURES_FILES
for filename in IMAGE_FEATURES_FILES
]
else:
return [
ImagePixelData(Image.open(filename)) for filename in _IMAGE_FILES
ImagePixelData(Image.open(filename)) for filename in IMAGE_FILES
]
@pytest.fixture()
def vllm_image_tensors(request) -> List[torch.Tensor]:
return [torch.load(filename) for filename in _PIXEL_VALUES_FILES]
@pytest.fixture()
def vllm_image_prompts(request) -> List[str]:
vision_language_config = request.getfixturevalue("model_and_config")[1]
return [
"<image>" * (vision_language_config.image_feature_size - 1) + p
for p in _IMAGE_PROMPTS
]
return [torch.load(filename) for filename in PIXEL_VALUES_FILES]
@pytest.fixture

View File

@@ -1,14 +1,22 @@
import gc
from dataclasses import fields
from enum import Enum
from typing import Any, Dict, List, Tuple
from typing import List, Tuple
import pytest
import torch
from transformers import AutoTokenizer
from vllm.config import VisionLanguageConfig
from ..conftest import IMAGE_FILES
pytestmark = pytest.mark.llava
# The image token is placed before "user" on purpose so that the test can pass
HF_IMAGE_PROMPTS = [
"<image>\nUSER: What's the content of the image?\nASSISTANT:",
"<image>\nUSER: What is the season?\nASSISTANT:",
]
assert len(HF_IMAGE_PROMPTS) == len(IMAGE_FILES)
def iter_llava_configs(model_name: str):
image_hw_to_feature_size = {
@@ -36,53 +44,35 @@ model_and_vl_config = [
]
def as_dict(vlm_config: VisionLanguageConfig) -> Dict[str, Any]:
"""Flatten vision language config to pure args.
Compatible with what llm entrypoint expects.
"""
result = {}
for field in fields(vlm_config):
value = getattr(vlm_config, field.name)
if isinstance(value, Enum):
result[field.name] = value.name.lower()
elif isinstance(value, tuple):
result[field.name] = ",".join([str(item) for item in value])
else:
result[field.name] = value
result["disable_image_processor"] = vlm_config.image_processor is None
return result
def sanitize_vllm_output(vllm_output: Tuple[List[int], str],
vision_language_config: VisionLanguageConfig,
model_id: str):
def vllm_to_hf_output(vllm_output: Tuple[List[int], str],
vlm_config: VisionLanguageConfig, model_id: str):
"""Sanitize vllm output to be comparable with hf output.
The function reduces `input_ids` from 1, 32000, 32000, ..., 32000,
x1, x2, x3 ... to 1, 32000, x1, x2, x3 ...
It also reduces `output_str` from "<image><image>bla" to "bla".
"""
tokenizer = AutoTokenizer.from_pretrained(model_id)
image_token_str = tokenizer.decode(vision_language_config.image_token_id)
image_token_str_len = len(image_token_str)
input_ids, output_str = vllm_output
sanitized_input_ids = input_ids[0:2] + input_ids[2 + vision_language_config
.image_feature_size - 1:]
sanitzied_output_str = output_str[vision_language_config.
image_feature_size *
image_token_str_len:]
return sanitized_input_ids, sanitzied_output_str
image_token_id = vlm_config.image_token_id
tokenizer = AutoTokenizer.from_pretrained(model_id)
image_token_str = tokenizer.decode(image_token_id)
hf_input_ids = [
input_id for idx, input_id in enumerate(input_ids)
if input_id != image_token_id or input_ids[idx - 1] != image_token_id
]
hf_output_str = output_str \
.replace(image_token_str * vlm_config.image_feature_size, "")
return hf_input_ids, hf_output_str
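
For intuition, the comprehension above collapses every run of consecutive image tokens into a single token, which is how HF tokenizes a prompt containing one <image> placeholder, while the string replace strips the expanded placeholder run from the text. A standalone sketch of the id-collapsing rule with made-up values (32000 stands in for vlm_config.image_token_id):

input_ids = [1, 32000, 32000, 32000, 319, 1799]
hf_input_ids = [
    tok for idx, tok in enumerate(input_ids)
    if tok != 32000 or input_ids[idx - 1] != 32000
]
assert hf_input_ids == [1, 32000, 319, 1799]  # only the first image token survives
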
@pytest.mark.parametrize("worker_use_ray", [False])
# TODO: Add test for `tensor_parallel_size` [ref: PR #3883]
@pytest.mark.parametrize("model_and_config", model_and_vl_config)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
def test_models(hf_runner, vllm_runner, hf_image_prompts, hf_images,
vllm_image_prompts, vllm_images, model_and_config, dtype: str,
max_tokens: int, worker_use_ray: bool) -> None:
def test_models(hf_runner, vllm_runner, hf_images, vllm_images,
model_and_config, dtype: str, max_tokens: int) -> None:
"""Inference result should be the same between hf and vllm.
All the image fixtures for the test is under tests/images.
@@ -92,36 +82,33 @@ def test_models(hf_runner, vllm_runner, hf_image_prompts, hf_images,
Note, the text input is also adjusted to abide by vllm contract.
The text output is sanitized to be able to compare with hf.
"""
model_id, vision_language_config = model_and_config
model_id, vlm_config = model_and_config
hf_model = hf_runner(model_id, dtype=dtype, is_vision_model=True)
hf_outputs = hf_model.generate_greedy(hf_image_prompts,
hf_outputs = hf_model.generate_greedy(HF_IMAGE_PROMPTS,
max_tokens,
images=hf_images)
del hf_model
vllm_image_prompts = [
p.replace("<image>", "<image>" * vlm_config.image_feature_size)
for p in HF_IMAGE_PROMPTS
]
vllm_model = vllm_runner(model_id,
dtype=dtype,
worker_use_ray=worker_use_ray,
enforce_eager=True,
**as_dict(vision_language_config))
**vlm_config.as_cli_args_dict())
vllm_outputs = vllm_model.generate_greedy(vllm_image_prompts,
max_tokens,
images=vllm_images)
del vllm_model
gc.collect()
torch.cuda.empty_cache()
for i in range(len(hf_image_prompts)):
for i in range(len(HF_IMAGE_PROMPTS)):
hf_output_ids, hf_output_str = hf_outputs[i]
vllm_output_ids, vllm_output_str = sanitize_vllm_output(
vllm_outputs[i], vision_language_config, model_id)
vllm_output_ids, vllm_output_str = vllm_to_hf_output(
vllm_outputs[i], vlm_config, model_id)
assert hf_output_str == vllm_output_str, (
f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
assert hf_output_ids == vllm_output_ids, (
f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
# TODO: Add test for `tensor_parallel_size` [ref: PR #3883]
# (Requires multiple GPUs)
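
The inline vllm_image_prompts expansion above goes in the opposite direction to vllm_to_hf_output: each single <image> placeholder is widened to image_feature_size copies before the prompt is handed to vLLM. A quick illustration, assuming a feature size of 576 (the value a 336x336 LLaVA-1.5 configuration would use; an assumption here, not taken from the diff):

feature_size = 576  # assumed LLaVA-1.5 value for 336x336 inputs
hf_prompt = "<image>\nUSER: What's the content of the image?\nASSISTANT:"
vllm_prompt = hf_prompt.replace("<image>", "<image>" * feature_size)
assert vllm_prompt.count("<image>") == feature_size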

View File

@@ -1,7 +1,8 @@
import enum
import json
from dataclasses import dataclass, field, fields
from typing import TYPE_CHECKING, ClassVar, List, Optional, Tuple, Union
from typing import (TYPE_CHECKING, Any, ClassVar, Dict, List, Optional, Tuple,
Union)
import torch
from transformers import PretrainedConfig
@@ -1114,6 +1115,25 @@ class VisionLanguageConfig:
f"Expecting to choose from "
f"{[x.name for x in cls.ImageInputType]}.") from e
def as_cli_args_dict(self) -> Dict[str, Any]:
"""Flatten vision language config to pure args.
Compatible with what llm entrypoint expects.
"""
result: Dict[str, Any] = {}
for f in fields(self):
value = getattr(self, f.name)
if isinstance(value, enum.Enum):
result[f.name] = value.name.lower()
elif isinstance(value, tuple):
result[f.name] = ",".join([str(item) for item in value])
else:
result[f.name] = value
result["disable_image_processor"] = self.image_processor is None
return result
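
To make the two conversion rules concrete, here is a self-contained sketch with a toy config; the field names mirror the real ones, but the reduced field set and the values are assumptions for illustration only:

import enum
from dataclasses import dataclass, fields

class ImageInputType(enum.Enum):
    PIXEL_VALUES = enum.auto()

@dataclass
class ToyVLMConfig:
    image_input_type: ImageInputType
    image_input_shape: tuple
    image_token_id: int

cfg = ToyVLMConfig(ImageInputType.PIXEL_VALUES, (1, 3, 336, 336), 32000)
flat = {}
for f in fields(cfg):
    value = getattr(cfg, f.name)
    if isinstance(value, enum.Enum):
        flat[f.name] = value.name.lower()         # Enum  -> "pixel_values"
    elif isinstance(value, tuple):
        flat[f.name] = ",".join(map(str, value))  # tuple -> "1,3,336,336"
    else:
        flat[f.name] = value                      # everything else passes through
assert flat == {
    "image_input_type": "pixel_values",
    "image_input_shape": "1,3,336,336",
    "image_token_id": 32000,
}

The test above then simply splats the flattened dict into the runner via **vlm_config.as_cli_args_dict().
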
_STR_DTYPE_TO_TORCH_DTYPE = {
"half": torch.float16,

View File

@@ -75,6 +75,14 @@ class ImagePixelData(MultiModalData):
self.image = image
def __repr__(self) -> str:
image = self.image
if isinstance(image, Image.Image):
return f"{type(self).__name__}(image={image})"
return (f"{type(self).__name__}(image=torch.Tensor(shape="
f"{image.shape}, dtype={image.dtype}))")
class ImagePixelPlugin(MultiModalPlugin[ImagePixelData]):
@@ -96,10 +104,10 @@ class ImagePixelPlugin(MultiModalPlugin[ImagePixelData]):
self, data: ImagePixelData, model_config: ModelConfig,
vlm_config: VisionLanguageConfig) -> Dict[str, torch.Tensor]:
image = data.image
image_processor = self._get_hf_image_processor(model_config,
vlm_config)
if isinstance(image, Image.Image):
image_processor = self._get_hf_image_processor(
model_config, vlm_config)
if image_processor is None:
raise RuntimeError("No HuggingFace processor is available"
"to process the image object")
@@ -127,6 +135,12 @@ class ImageFeatureData(MultiModalData):
def __init__(self, image_features: torch.Tensor) -> None:
self.image_features = image_features
def __repr__(self) -> str:
image_features = self.image_features
return (f"{type(self).__name__}(image_features=torch.Tensor(shape="
f"{image_features.shape}, dtype={image_features.dtype}))")
class ImageFeaturePlugin(MultiModalPlugin[ImageFeatureData]):
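
The two new __repr__ methods mainly help when a test fails and the multi-modal inputs end up in an assertion message or log. A small usage sketch, assuming these classes are importable from vllm.multimodal.image in this revision (sizes and shapes are made up):

import torch
from PIL import Image
from vllm.multimodal.image import ImageFeatureData, ImagePixelData

print(ImagePixelData(Image.new("RGB", (336, 336))))
# e.g. ImagePixelData(image=<PIL.Image.Image image mode=RGB size=336x336 at 0x...>)
print(ImageFeatureData(torch.zeros(576, 1024)))
# e.g. ImageFeatureData(image_features=torch.Tensor(shape=torch.Size([576, 1024]), dtype=torch.float32))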