[CI/Build] Update vision tests (#5307)

Cyrus Leung 2024-06-06 18:17:18 +08:00 committed by GitHub
parent 7b0a0dfb22
commit 89c920785f
6 changed files with 90 additions and 88 deletions


@@ -93,14 +93,13 @@ steps:
 - label: Models Test
   #mirror_hardwares: [amd]
   commands:
-    - bash ../.buildkite/download-images.sh
-    - pytest -v -s models --ignore=models/test_llava.py
+    - pytest -v -s models -m \"not llava\"

 - label: Llava Test
   mirror_hardwares: [amd]
   commands:
     - bash ../.buildkite/download-images.sh
-    - pytest -v -s models/test_llava.py
+    - pytest -v -s models -m llava

 - label: Prefix Caching Test
   mirror_hardwares: [amd]


@@ -71,4 +71,5 @@ markers = [
     "skip_global_cleanup",
     "llm: run tests for vLLM API only",
     "openai: run tests for OpenAI API only",
+    "llava: run tests for LLaVA models only",
 ]
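The `llava` marker registered above is what the updated Buildkite commands select on. As a rough, hypothetical illustration (the module and test names below are not part of this commit): a module-level `pytestmark` applies the marker to every test in the file, so `pytest -m llava` collects them and `pytest -m "not llava"` skips them.

```python
# Hypothetical test module illustrating the registered "llava" marker.
import pytest

# Applies @pytest.mark.llava to every test collected from this module,
# mirroring how tests/models/test_llava.py uses it in this commit.
pytestmark = pytest.mark.llava


def test_example_llava_only():
    # Selected by `pytest -m llava`, deselected by `pytest -m "not llava"`.
    assert True
```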


@@ -29,24 +29,19 @@ _LONG_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "summary.txt")]

 # Multi modal related
 # You can use `.buildkite/download-images.sh` to download the assets
-_PIXEL_VALUES_FILES = [
+PIXEL_VALUES_FILES = [
     os.path.join(_TEST_DIR, "images", filename) for filename in
     ["stop_sign_pixel_values.pt", "cherry_blossom_pixel_values.pt"]
 ]
-_IMAGE_FEATURES_FILES = [
+IMAGE_FEATURES_FILES = [
     os.path.join(_TEST_DIR, "images", filename) for filename in
     ["stop_sign_image_features.pt", "cherry_blossom_image_features.pt"]
 ]
-_IMAGE_FILES = [
+IMAGE_FILES = [
     os.path.join(_TEST_DIR, "images", filename)
     for filename in ["stop_sign.jpg", "cherry_blossom.jpg"]
 ]
-_IMAGE_PROMPTS = [
-    "<image>\nUSER: What's the content of the image?\nASSISTANT:",
-    "<image>\nUSER: What is the season?\nASSISTANT:"
-]
-assert len(_PIXEL_VALUES_FILES) == len(_IMAGE_FEATURES_FILES) == len(
-    _IMAGE_FILES) == len(_IMAGE_PROMPTS)
+assert len(PIXEL_VALUES_FILES) == len(IMAGE_FEATURES_FILES) == len(IMAGE_FILES)


 def _read_prompts(filename: str) -> List[str]:
@@ -84,14 +79,9 @@ def cleanup_fixture(should_do_global_cleanup_after_test: bool):
     cleanup()


-@pytest.fixture(scope="session")
-def hf_image_prompts() -> List[str]:
-    return _IMAGE_PROMPTS
-
-
 @pytest.fixture(scope="session")
 def hf_images() -> List[Image.Image]:
-    return [Image.open(filename) for filename in _IMAGE_FILES]
+    return [Image.open(filename) for filename in IMAGE_FILES]


 @pytest.fixture()
@@ -101,26 +91,17 @@ def vllm_images(request) -> List[MultiModalData]:
             VisionLanguageConfig.ImageInputType.IMAGE_FEATURES):
         return [
             ImageFeatureData(torch.load(filename))
-            for filename in _IMAGE_FEATURES_FILES
+            for filename in IMAGE_FEATURES_FILES
         ]
     else:
         return [
-            ImagePixelData(Image.open(filename)) for filename in _IMAGE_FILES
+            ImagePixelData(Image.open(filename)) for filename in IMAGE_FILES
         ]


 @pytest.fixture()
 def vllm_image_tensors(request) -> List[torch.Tensor]:
-    return [torch.load(filename) for filename in _PIXEL_VALUES_FILES]
-
-
-@pytest.fixture()
-def vllm_image_prompts(request) -> List[str]:
-    vision_language_config = request.getfixturevalue("model_and_config")[1]
-    return [
-        "<image>" * (vision_language_config.image_feature_size - 1) + p
-        for p in _IMAGE_PROMPTS
-    ]
+    return [torch.load(filename) for filename in PIXEL_VALUES_FILES]


 @pytest.fixture
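Dropping the leading underscores makes the asset lists part of the conftest's public surface, so sibling test packages can import them. A minimal sketch of the intended usage (the relative import assumes a module one package below tests/, as tests/models/test_llava.py is):

```python
# Sketch of importing the now-public asset lists from tests/conftest.py;
# the relative import only works from a package nested under tests/.
from ..conftest import IMAGE_FEATURES_FILES, IMAGE_FILES, PIXEL_VALUES_FILES

# One entry per test image (stop sign, cherry blossom), kept in lockstep.
assert len(PIXEL_VALUES_FILES) == len(IMAGE_FEATURES_FILES) == len(IMAGE_FILES)
```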


@@ -1,14 +1,22 @@
-import gc
-from dataclasses import fields
-from enum import Enum
-from typing import Any, Dict, List, Tuple
+from typing import List, Tuple

 import pytest
-import torch
 from transformers import AutoTokenizer

 from vllm.config import VisionLanguageConfig

+from ..conftest import IMAGE_FILES
+
+pytestmark = pytest.mark.llava
+
+# The image token is placed before "user" on purpose so that the test can pass
+HF_IMAGE_PROMPTS = [
+    "<image>\nUSER: What's the content of the image?\nASSISTANT:",
+    "<image>\nUSER: What is the season?\nASSISTANT:",
+]
+
+assert len(HF_IMAGE_PROMPTS) == len(IMAGE_FILES)
+

 def iter_llava_configs(model_name: str):
     image_hw_to_feature_size = {
@@ -36,53 +44,35 @@ model_and_vl_config = [
 ]


-def as_dict(vlm_config: VisionLanguageConfig) -> Dict[str, Any]:
-    """Flatten vision language config to pure args.
-
-    Compatible with what llm entrypoint expects.
-    """
-    result = {}
-    for field in fields(vlm_config):
-        value = getattr(vlm_config, field.name)
-        if isinstance(value, Enum):
-            result[field.name] = value.name.lower()
-        elif isinstance(value, tuple):
-            result[field.name] = ",".join([str(item) for item in value])
-        else:
-            result[field.name] = value
-
-    result["disable_image_processor"] = vlm_config.image_processor is None
-
-    return result
-
-
-def sanitize_vllm_output(vllm_output: Tuple[List[int], str],
-                         vision_language_config: VisionLanguageConfig,
-                         model_id: str):
+def vllm_to_hf_output(vllm_output: Tuple[List[int], str],
+                      vlm_config: VisionLanguageConfig, model_id: str):
     """Sanitize vllm output to be comparable with hf output.
     The function reduces `input_ids` from 1, 32000, 32000, ..., 32000,
     x1, x2, x3 ... to 1, 32000, x1, x2, x3 ...
     It also reduces `output_str` from "<image><image>bla" to "bla".
     """
-    tokenizer = AutoTokenizer.from_pretrained(model_id)
-    image_token_str = tokenizer.decode(vision_language_config.image_token_id)
-    image_token_str_len = len(image_token_str)
     input_ids, output_str = vllm_output
-    sanitized_input_ids = input_ids[0:2] + input_ids[2 + vision_language_config
-                                                     .image_feature_size - 1:]
-    sanitzied_output_str = output_str[vision_language_config.
-                                      image_feature_size *
-                                      image_token_str_len:]
-    return sanitized_input_ids, sanitzied_output_str
+    image_token_id = vlm_config.image_token_id
+
+    tokenizer = AutoTokenizer.from_pretrained(model_id)
+    image_token_str = tokenizer.decode(image_token_id)
+
+    hf_input_ids = [
+        input_id for idx, input_id in enumerate(input_ids)
+        if input_id != image_token_id or input_ids[idx - 1] != image_token_id
+    ]
+    hf_output_str = output_str \
+        .replace(image_token_str * vlm_config.image_feature_size, "")
+
+    return hf_input_ids, hf_output_str


-@pytest.mark.parametrize("worker_use_ray", [False])
+# TODO: Add test for `tensor_parallel_size` [ref: PR #3883]
 @pytest.mark.parametrize("model_and_config", model_and_vl_config)
 @pytest.mark.parametrize("dtype", ["half"])
 @pytest.mark.parametrize("max_tokens", [128])
-def test_models(hf_runner, vllm_runner, hf_image_prompts, hf_images,
-                vllm_image_prompts, vllm_images, model_and_config, dtype: str,
-                max_tokens: int, worker_use_ray: bool) -> None:
+def test_models(hf_runner, vllm_runner, hf_images, vllm_images,
                model_and_config, dtype: str, max_tokens: int) -> None:
     """Inference result should be the same between hf and vllm.

     All the image fixtures for the test is under tests/images.
@@ -92,36 +82,33 @@ def test_models(hf_runner, vllm_runner, hf_image_prompts, hf_images,
     Note, the text input is also adjusted to abide by vllm contract.
     The text output is sanitized to be able to compare with hf.
     """
-    model_id, vision_language_config = model_and_config
+    model_id, vlm_config = model_and_config

     hf_model = hf_runner(model_id, dtype=dtype, is_vision_model=True)
-    hf_outputs = hf_model.generate_greedy(hf_image_prompts,
+    hf_outputs = hf_model.generate_greedy(HF_IMAGE_PROMPTS,
                                           max_tokens,
                                           images=hf_images)
     del hf_model

+    vllm_image_prompts = [
+        p.replace("<image>", "<image>" * vlm_config.image_feature_size)
+        for p in HF_IMAGE_PROMPTS
+    ]
+
     vllm_model = vllm_runner(model_id,
                              dtype=dtype,
-                             worker_use_ray=worker_use_ray,
                              enforce_eager=True,
-                             **as_dict(vision_language_config))
+                             **vlm_config.as_cli_args_dict())
     vllm_outputs = vllm_model.generate_greedy(vllm_image_prompts,
                                               max_tokens,
                                               images=vllm_images)
     del vllm_model
-    gc.collect()
-    torch.cuda.empty_cache()

-    for i in range(len(hf_image_prompts)):
+    for i in range(len(HF_IMAGE_PROMPTS)):
         hf_output_ids, hf_output_str = hf_outputs[i]
-        vllm_output_ids, vllm_output_str = sanitize_vllm_output(
-            vllm_outputs[i], vision_language_config, model_id)
+        vllm_output_ids, vllm_output_str = vllm_to_hf_output(
+            vllm_outputs[i], vlm_config, model_id)
         assert hf_output_str == vllm_output_str, (
             f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
         assert hf_output_ids == vllm_output_ids, (
             f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
-
-# TODO: Add test for `tensor_parallel_size` [ref: PR #3883]
-# (Requires multiple GPUs)
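The filtering condition in `vllm_to_hf_output` is easiest to see on a concrete token list. Below is a self-contained sketch of just that rule; the helper name and sample ids are illustrative, and 32000 is the image token id cited in the docstring.

```python
# Standalone sketch of the de-duplication rule used in vllm_to_hf_output:
# keep a token unless it is an image token whose predecessor was also an
# image token, so a run of repeated image tokens collapses to one.
from typing import List

IMAGE_TOKEN_ID = 32000  # per the docstring: 1, 32000, 32000, ..., x1, x2, ...


def collapse_image_tokens(input_ids: List[int],
                          image_token_id: int = IMAGE_TOKEN_ID) -> List[int]:
    return [
        tok for idx, tok in enumerate(input_ids)
        if tok != image_token_id or input_ids[idx - 1] != image_token_id
    ]


# [1, 32000, 32000, 32000, 5, 6] -> [1, 32000, 5, 6]
print(collapse_image_tokens([1, 32000, 32000, 32000, 5, 6]))
```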


@@ -1,7 +1,8 @@
 import enum
 import json
 from dataclasses import dataclass, field, fields
-from typing import TYPE_CHECKING, ClassVar, List, Optional, Tuple, Union
+from typing import (TYPE_CHECKING, Any, ClassVar, Dict, List, Optional, Tuple,
+                    Union)

 import torch
 from transformers import PretrainedConfig
@@ -1114,6 +1115,25 @@ class VisionLanguageConfig:
                 f"Expecting to choose from "
                 f"{[x.name for x in cls.ImageInputType]}.") from e

+    def as_cli_args_dict(self) -> Dict[str, Any]:
+        """Flatten vision language config to pure args.
+
+        Compatible with what llm entrypoint expects.
+        """
+        result: Dict[str, Any] = {}
+        for f in fields(self):
+            value = getattr(self, f.name)
+            if isinstance(value, enum.Enum):
+                result[f.name] = value.name.lower()
+            elif isinstance(value, tuple):
+                result[f.name] = ",".join([str(item) for item in value])
+            else:
+                result[f.name] = value
+
+        result["disable_image_processor"] = self.image_processor is None
+
+        return result
+

 _STR_DTYPE_TO_TORCH_DTYPE = {
     "half": torch.float16,


@@ -75,6 +75,14 @@ class ImagePixelData(MultiModalData):

         self.image = image

+    def __repr__(self) -> str:
+        image = self.image
+        if isinstance(image, Image.Image):
+            return f"{type(self).__name__}(image={image})"
+
+        return (f"{type(self).__name__}(image=torch.Tensor(shape="
+                f"{image.shape}, dtype={image.dtype}))")
+

 class ImagePixelPlugin(MultiModalPlugin[ImagePixelData]):
@@ -96,10 +104,10 @@ class ImagePixelPlugin(MultiModalPlugin[ImagePixelData]):
             self, data: ImagePixelData, model_config: ModelConfig,
             vlm_config: VisionLanguageConfig) -> Dict[str, torch.Tensor]:
         image = data.image
+        image_processor = self._get_hf_image_processor(model_config,
+                                                       vlm_config)
+
         if isinstance(image, Image.Image):
-            image_processor = self._get_hf_image_processor(
-                model_config, vlm_config)
             if image_processor is None:
                 raise RuntimeError("No HuggingFace processor is available"
                                    "to process the image object")
@@ -127,6 +135,12 @@ class ImageFeatureData(MultiModalData):
     def __init__(self, image_features: torch.Tensor) -> None:
         self.image_features = image_features

+    def __repr__(self) -> str:
+        image_features = self.image_features
+
+        return (f"{type(self).__name__}(image_features=torch.Tensor(shape="
+                f"{image_features.shape}, dtype={image_features.dtype}))")
+

 class ImageFeaturePlugin(MultiModalPlugin[ImageFeatureData]):
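Roughly what the new `__repr__` methods produce; the module path and the exact output text below are assumptions based on the diff above, not verified against this revision.

```python
# Sketch assuming ImagePixelData / ImageFeatureData live in
# vllm.multimodal.image at this revision; printed output is approximate.
import torch
from PIL import Image

from vllm.multimodal.image import ImageFeatureData, ImagePixelData

print(repr(ImagePixelData(Image.new("RGB", (336, 336)))))
# ImagePixelData(image=<PIL.Image.Image image mode=RGB size=336x336 at 0x...>)

print(repr(ImageFeatureData(torch.zeros(576, 1024))))
# ImageFeatureData(image_features=torch.Tensor(shape=torch.Size([576, 1024]),
# dtype=torch.float32))
```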