[Bugfix] Comprehensively test and fix LLaVA-NeXT feature size calculation (#11800)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 8082ad7950
commit 8f37be38eb
@@ -13,6 +13,7 @@ einops # required for MPT, qwen-vl and Mamba
 httpx
 librosa # required for audio tests
 peft
+pqdm
 ray[adag]==2.40.0
 sentence-transformers # required for embedding tests
 soundfile # required for audio tests
@@ -48,6 +48,8 @@ botocore==1.35.57
     #   awscli
     #   boto3
     #   s3transfer
+bounded-pool-executor==0.0.3
+    # via pqdm
 buildkite-test-collector==0.1.9
     # via -r requirements-test.in
 certifi==2024.8.30
@@ -342,6 +344,8 @@ pooch==1.8.2
     # via librosa
 portalocker==2.10.1
     # via sacrebleu
+pqdm==0.2.0
+    # via -r requirements-test.in
 propcache==0.2.0
     # via yarl
 protobuf==5.28.3
@@ -1,8 +1,13 @@
+import itertools
+from functools import partial
+
 import pytest
 from PIL import Image
+from pqdm.threads import pqdm
 from transformers import AutoTokenizer

 from vllm.inputs import InputProcessingContext
+from vllm.multimodal.parse import ImageSize

 from ....utils import build_model_context
@@ -15,20 +20,69 @@ def processor_for_llava_next():
     return LlavaNextMultiModalProcessor


+def _validate_image_prompt_replacements_one(
+    processor,
+    num_imgs: int,
+    failed_size_excs: list[tuple[ImageSize, Exception]],
+    image_size: ImageSize,
+) -> None:
+    prompt = "<image>" * num_imgs
+    image = Image.new("RGB", size=image_size)
+    mm_data = {"image": [image] * num_imgs}
+
+    try:
+        # The processor will throw an error if there is a mismatch
+        # in the prompt replacements
+        processed_inputs = processor.apply(prompt, mm_data, {})
+
+        image_placeholders = processed_inputs["mm_placeholders"]["image"]
+        assert len(image_placeholders) == num_imgs
+
+        first_placeholder = image_placeholders[0]
+
+        # NOTE: There is a BOS token
+        assert first_placeholder["offset"] == 1
+        assert first_placeholder["length"] == (
+            len(processed_inputs["prompt_token_ids"]) - 1) // num_imgs
+
+    except Exception as exc:
+        failed_size_excs.append((image_size, exc))
+
+
+def _test_image_prompt_replacements(
+    processor,
+    *,
+    num_imgs: int,
+    image_sizes: list[ImageSize],
+) -> None:
+    """
+    Ensure LlavaNextMultiModalProcessor
+    handles prompt replacement properly for input images.
+    """
+    failed_size_excs = list[tuple[ImageSize, Exception]]()
+
+    validate_one = partial(
+        _validate_image_prompt_replacements_one,
+        processor,
+        num_imgs,
+        failed_size_excs,
+    )
+    pqdm(image_sizes, validate_one, n_jobs=8, desc="Validating image sizes")
+
+    if failed_size_excs:
+        msg = "Found failing image sizes:" \
+            + "\n========\n".join(f"[{size}]\n{exc}"
+                                  for size, exc in failed_size_excs)
+        raise AssertionError(msg)
+
+
 @pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"])
-@pytest.mark.parametrize("image_size", [(1669, 2560), (2560, 1669), (183, 488),
-                                        (488, 183), (198, 176), (176, 198),
-                                        (161, 184), (184, 161)])
 @pytest.mark.parametrize("num_imgs", [1, 2])
-def test_processor_prompt_replacements(
+def test_processor_prompt_replacements_regression(
     processor_for_llava_next,
     model_id: str,
-    image_size: tuple[int, int],
     num_imgs: int,
 ):
-    """
-    Ensure LlavaNextMultiModalProcessor handles prompt replacement properly.
-    """
     ctx = build_model_context(
         model_name=model_id,
         tokenizer_name=model_id,
@@ -37,22 +91,55 @@ def test_processor_prompt_replacements(
     )
     tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
     ctx = InputProcessingContext(ctx.model_config, tokenizer)
-
-    # Build the image str / prompt based on the number of images we pass
-    prompt = "<image>" * num_imgs
-    mm_data = {"image": [Image.new("RGB", size=image_size)] * num_imgs}
-
-    # The processor will throw an error if there is a mismatch
-    # in the prompt replacements
     processor = processor_for_llava_next(ctx)
-    processed_inputs = processor.apply(prompt, mm_data, {})

-    image_placeholders = processed_inputs["mm_placeholders"]["image"]
-    assert len(image_placeholders) == num_imgs
+    image_ratios = [(171, 152), (184, 161), (198, 176), (333, 296), (369, 328),
+                    (488, 183), (2560, 1669)]
+    image_sizes = [
+        size for w, h in image_ratios
+        for size in [ImageSize(w, h), ImageSize(h, w)]
+    ]

-    first_placeholder = image_placeholders[0]
+    _test_image_prompt_replacements(
+        processor,
+        num_imgs=num_imgs,
+        image_sizes=image_sizes,
+    )

-    # NOTE: There is a BOS token
-    assert first_placeholder["offset"] == 1
-    assert first_placeholder["length"] == (
-        len(processed_inputs["prompt_token_ids"]) - 1) // num_imgs
+
+@pytest.mark.skip("This test takes around 2 hours to run. "
+                  "Comment this out to run it manually.")
+@pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"])
+@pytest.mark.parametrize("num_imgs", [1])
+def test_processor_prompt_replacements_all(
+    processor_for_llava_next,
+    model_id: str,
+    num_imgs: int,
+):
+    ctx = build_model_context(
+        model_name=model_id,
+        tokenizer_name=model_id,
+        mm_processor_kwargs=None,
+        limit_mm_per_prompt={"image": num_imgs},
+    )
+    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
+    ctx = InputProcessingContext(ctx.model_config, tokenizer)
+    processor = processor_for_llava_next(ctx)
+
+    seen_aspect_ratios = set[float]()
+    image_sizes = list[ImageSize]()
+
+    # The aspect ratio of the grid layout is between 1 and 2
+    # NOTE: Assumes that feature size calculation is the same if we
+    # swap the width and height of the image
+    for w, h in itertools.product(range(64, 1024), repeat=2):
+        aspect_ratio = w / h
+        if 1 <= aspect_ratio <= 2 and aspect_ratio not in seen_aspect_ratios:
+            image_sizes.append(ImageSize(w, h))
+            seen_aspect_ratios.add(aspect_ratio)
+
+    _test_image_prompt_replacements(
+        processor,
+        num_imgs=num_imgs,
+        image_sizes=image_sizes,
+    )
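Both new tests funnel every size through the pqdm-based helper, which runs the per-size checks across threads and collects failures instead of stopping at the first mismatch, so one run reports every failing size at once. A minimal standalone sketch of that collect-then-assert pattern (the check function here is a placeholder, not part of this commit):

    from functools import partial

    from pqdm.threads import pqdm

    def _check_one(failures: list, size: tuple[int, int]) -> None:
        try:
            w, h = size
            # Placeholder check; the real tests call processor.apply() here.
            assert w > 0 and h > 0
        except Exception as exc:
            # Record the failure instead of raising, so every size is visited.
            failures.append((size, exc))

    failures: list = []
    sizes = [(488, 183), (183, 488), (198, 176)]
    pqdm(sizes, partial(_check_one, failures), n_jobs=8, desc="Validating image sizes")
    assert not failures, f"Found failing image sizes: {failures}"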
@@ -1,8 +1,13 @@
+import itertools
+from functools import partial
+
 import pytest
 from PIL import Image
+from pqdm.threads import pqdm
 from transformers import AutoTokenizer

 from vllm.inputs import InputProcessingContext
+from vllm.multimodal.parse import ImageSize

 from ....utils import build_model_context
@@ -15,22 +20,68 @@ def processor_for_llava_onevision():
     return LlavaOnevisionMultiModalProcessor


+def _validate_image_prompt_replacements_one(
+    processor,
+    num_imgs: int,
+    failed_size_excs: list[tuple[ImageSize, Exception]],
+    image_size: ImageSize,
+) -> None:
+    prompt = "<image>" * num_imgs
+    image = Image.new("RGB", size=image_size)
+    mm_data = {"image": [image] * num_imgs}
+
+    try:
+        # The processor will throw an error if there is a mismatch
+        # in the prompt replacements
+        processed_inputs = processor.apply(prompt, mm_data, {})
+
+        image_placeholders = processed_inputs["mm_placeholders"]["image"]
+        assert len(image_placeholders) == num_imgs
+
+        first_placeholder = image_placeholders[0]
+
+        assert first_placeholder["offset"] == 0
+        assert first_placeholder["length"] == len(
+            processed_inputs["prompt_token_ids"]) // num_imgs
+    except Exception as exc:
+        failed_size_excs.append((image_size, exc))
+
+
+def _test_image_prompt_replacements(
+    processor,
+    *,
+    num_imgs: int,
+    image_sizes: list[ImageSize],
+) -> None:
+    """
+    Ensure LlavaOnevisionMultiModalProcessor
+    handles prompt replacement properly for input images.
+    """
+    failed_size_excs = list[tuple[ImageSize, Exception]]()
+
+    validate_one = partial(
+        _validate_image_prompt_replacements_one,
+        processor,
+        num_imgs,
+        failed_size_excs,
+    )
+    pqdm(image_sizes, validate_one, n_jobs=8, desc="Validating image sizes")
+
+    if failed_size_excs:
+        msg = "Found failing image sizes:" \
+            + "\n========\n".join(f"[{size}]\n{exc}"
+                                  for size, exc in failed_size_excs)
+        raise AssertionError(msg)
+
+
 @pytest.mark.parametrize("model_id",
                          ["llava-hf/llava-onevision-qwen2-0.5b-ov-hf"])
-@pytest.mark.parametrize("image_size", [(1669, 2560), (2560, 1669), (183, 488),
-                                        (488, 183), (198, 176), (176, 198),
-                                        (161, 184), (184, 161)])
 @pytest.mark.parametrize("num_imgs", [1, 2])
-def test_processor_prompt_replacements(
+def test_processor_prompt_replacements_regression(
     processor_for_llava_onevision,
     model_id: str,
-    image_size: tuple[int, int],
     num_imgs: int,
 ):
-    """
-    Ensure LlavaOnevisionMultiModalProcessor handles prompt replacement
-    properly.
-    """
     ctx = build_model_context(
         model_name=model_id,
         tokenizer_name=model_id,
@@ -39,22 +90,56 @@ def test_processor_prompt_replacements(
     )
     tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
     ctx = InputProcessingContext(ctx.model_config, tokenizer)
-
-    # Build the image str / prompt based on the number of images we pass
-    prompt = "<image>" * num_imgs
-    mm_data = {"image": [Image.new("RGB", size=image_size)] * num_imgs}
-
-    # The processor will throw an error if there is a mismatch
-    # in the prompt replacements
     processor = processor_for_llava_onevision(ctx)
-    processed_inputs = processor.apply(prompt, mm_data, {})

-    image_placeholders = processed_inputs["mm_placeholders"]["image"]
-    assert len(image_placeholders) == num_imgs
+    image_ratios = [(171, 152), (184, 161), (198, 176), (333, 296), (369, 328),
+                    (488, 183), (2560, 1669)]
+    image_sizes = [
+        size for w, h in image_ratios
+        for size in [ImageSize(w, h), ImageSize(h, w)]
+    ]

-    first_placeholder = image_placeholders[0]
+    _test_image_prompt_replacements(
+        processor,
+        num_imgs=num_imgs,
+        image_sizes=image_sizes,
+    )

-    # NOTE: There is a BOS token
-    assert first_placeholder["offset"] == 0
-    assert first_placeholder["length"] == len(
-        processed_inputs["prompt_token_ids"]) // num_imgs
+
+@pytest.mark.skip("This test takes around 2 hours to run. "
+                  "Comment this out to run it manually.")
+@pytest.mark.parametrize("model_id",
+                         ["llava-hf/llava-onevision-qwen2-0.5b-ov-hf"])
+@pytest.mark.parametrize("num_imgs", [1])
+def test_processor_prompt_replacements_all(
+    processor_for_llava_onevision,
+    model_id: str,
+    num_imgs: int,
+):
+    ctx = build_model_context(
+        model_name=model_id,
+        tokenizer_name=model_id,
+        mm_processor_kwargs=None,
+        limit_mm_per_prompt={"image": num_imgs},
+    )
+    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
+    ctx = InputProcessingContext(ctx.model_config, tokenizer)
+    processor = processor_for_llava_onevision(ctx)
+
+    seen_aspect_ratios = set[float]()
+    image_sizes = list[ImageSize]()
+
+    # The aspect ratio of the grid layout is between 1 and 6
+    # NOTE: Assumes that feature size calculation is the same if we
+    # swap the width and height of the image
+    for w, h in itertools.product(range(64, 1024), repeat=2):
+        aspect_ratio = w / h
+        if 1 <= aspect_ratio <= 6 and aspect_ratio not in seen_aspect_ratios:
+            image_sizes.append(ImageSize(w, h))
+            seen_aspect_ratios.add(aspect_ratio)
+
+    _test_image_prompt_replacements(
+        processor,
+        num_imgs=num_imgs,
+        image_sizes=image_sizes,
+    )
@@ -2,7 +2,6 @@ from functools import cached_property
 from typing import (Final, Iterable, List, Literal, Mapping, Optional,
                     Protocol, Set, Tuple, TypedDict, Union)

-import numpy as np
 import torch
 import torch.nn as nn
 from transformers import BatchFeature, LlavaNextConfig, LlavaNextProcessor
@@ -74,7 +73,7 @@ class LlavaNextProcessingMixin(BaseLlavaProcessingMixin):
     def _get_hf_processor(self):
         return self.ctx.get_hf_processor(LlavaNextProcessor)

-    # Based on: https://github.com/huggingface/text-generation-inference/blob/v2.2.0/server/text_generation_server/models/vlm_causal_lm.py#L106
+    # Based on: https://github.com/huggingface/text-generation-inference/blob/v3.0.1/server/text_generation_server/models/vlm_causal_lm.py#L113
     def _get_num_image_tokens(
         self,
         *,
@@ -111,7 +110,7 @@ class LlavaNextProcessingMixin(BaseLlavaProcessingMixin):

         return unpadded_feature_size + newline_feature_size + base_feature_size

-    # Based on: https://github.com/huggingface/text-generation-inference/blob/v2.2.0/server/text_generation_server/models/vlm_causal_lm.py#L79
+    # Based on: https://github.com/huggingface/text-generation-inference/blob/v3.0.1/server/text_generation_server/models/vlm_causal_lm.py#L86
     def _get_num_unpadded_features(
         self,
         *,
@@ -121,29 +120,23 @@ class LlavaNextProcessingMixin(BaseLlavaProcessingMixin):
         num_patch_height: int,
         num_patch_width: int,
     ) -> tuple[int, int]:
-        # NOTE: Use float32 to remain consistent with HF output
-        current_height_f = np.float32(npatches * num_patch_height)
-        current_width_f = np.float32(npatches * num_patch_width)
+        current_height = npatches * num_patch_height
+        current_width = npatches * num_patch_width

-        original_width_f = np.float32(original_width)
-        original_height_f = np.float32(original_height)
+        aspect_ratio = original_width / original_height
+        current_aspect_ratio = current_width / current_height

-        original_aspect_ratio = original_width_f / original_height_f
-        current_aspect_ratio = current_width_f / current_height_f
-
-        if original_aspect_ratio > current_aspect_ratio:
-            scale_factor = current_width_f / original_width_f
-            new_height = int(original_height_f * scale_factor)
-            padding = (current_height_f - new_height) // 2
-            current_height_f -= 2 * padding
+        if aspect_ratio > current_aspect_ratio:
+            new_height = (original_height * current_width) // original_width
+            padding = (current_height - new_height) // 2
+            current_height = current_height - (2 * padding)
         else:
-            scale_factor = current_height_f / original_height_f
-            new_width = int(original_width_f * scale_factor)
-            padding = (current_width_f - new_width) // 2
-            current_width_f -= 2 * padding
+            new_width = (original_width * current_height) // original_height
+            padding = (current_width - new_width) // 2
+            current_width = current_width - (2 * padding)

-        unpadded_features = int(current_height_f * current_width_f)
-        newline_features = int(current_height_f)
+        unpadded_features = current_height * current_width
+        newline_features = current_height

         return (unpadded_features, newline_features)
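The replacement above drops the float32 emulation and mirrors the integer arithmetic of text-generation-inference v3.0.1: the image is rescaled into the patch-feature grid, the symmetric padding rows or columns are stripped, and the remaining area plus one newline feature per kept row gives the feature count. A minimal standalone sketch of that arithmetic (free function and example numbers are illustrative; in vLLM this is a method on LlavaNextProcessingMixin):

    def get_num_unpadded_features(
        original_height: int,
        original_width: int,
        npatches: int,
        num_patch_height: int,
        num_patch_width: int,
    ) -> tuple[int, int]:
        # Feature-grid size before any padding is removed.
        current_height = npatches * num_patch_height
        current_width = npatches * num_patch_width

        aspect_ratio = original_width / original_height
        current_aspect_ratio = current_width / current_height

        if aspect_ratio > current_aspect_ratio:
            # Image is wider than the grid: padding rows sit at top and bottom.
            new_height = (original_height * current_width) // original_width
            padding = (current_height - new_height) // 2
            current_height = current_height - (2 * padding)
        else:
            # Image is taller than the grid: padding columns sit at the sides.
            new_width = (original_width * current_height) // original_height
            padding = (current_width - new_width) // 2
            current_width = current_width - (2 * padding)

        unpadded_features = current_height * current_width
        newline_features = current_height  # one newline feature per kept row
        return (unpadded_features, newline_features)

    # Illustrative call: a 488x183 image, 24 patches per tile side
    # (e.g. a 336px vision tower with 14px patches), hypothetical 2x1 tile grid.
    print(get_num_unpadded_features(183, 488, 24, 1, 2))  # -> (864, 18)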
@@ -3,7 +3,6 @@ from functools import cached_property
 from typing import (Final, Iterable, List, Literal, Mapping, Optional,
                     Protocol, Set, Tuple, TypedDict, Union)

-import numpy as np
 import torch
 import torch.nn as nn
 from transformers import (BatchFeature, LlavaOnevisionConfig,
@@ -98,6 +97,8 @@ class LlavaOnevisionProcessingMixin(LlavaNextProcessingMixin):
     def _get_hf_processor(self):
         return self.ctx.get_hf_processor(LlavaOnevisionProcessor)

+    # Based on: https://github.com/huggingface/text-generation-inference/blob/v3.0.1/server/text_generation_server/models/vlm_causal_lm.py#L86
+    # with additional logic afterwards taken from LlavaOnevisionProcessor
     def _get_num_unpadded_features(
         self,
         *,
@@ -107,35 +108,28 @@ class LlavaOnevisionProcessingMixin(LlavaNextProcessingMixin):
         num_patch_height: int,
         num_patch_width: int,
     ) -> tuple[int, int]:
-        # NOTE: Use float32 to remain consistent with HF output
-        current_height_f = np.float32(npatches * num_patch_height)
-        current_width_f = np.float32(npatches * num_patch_width)
+        current_height = npatches * num_patch_height
+        current_width = npatches * num_patch_width

-        original_width_f = np.float32(original_width)
-        original_height_f = np.float32(original_height)
+        aspect_ratio = original_width / original_height
+        current_aspect_ratio = current_width / current_height

-        original_aspect_ratio = original_width_f / original_height_f
-        current_aspect_ratio = current_width_f / current_height_f
-
-        if original_aspect_ratio > current_aspect_ratio:
-            scale_factor = current_width_f / original_width_f
-            new_height = int(original_height_f * scale_factor)
-            padding = (current_height_f - new_height) // 2
-            current_height_f -= 2 * padding
+        if aspect_ratio > current_aspect_ratio:
+            new_height = (original_height * current_width) // original_width
+            padding = (current_height - new_height) // 2
+            current_height = current_height - (2 * padding)
         else:
-            scale_factor = current_height_f / original_height_f
-            new_width = int(original_width_f * scale_factor)
-            padding = (current_width_f - new_width) // 2
-            current_width_f -= 2 * padding
+            new_width = (original_width * current_height) // original_height
+            padding = (current_width - new_width) // 2
+            current_width = current_width - (2 * padding)

-        unpadded_features = int(current_height_f * current_width_f)
-        newline_features = int(current_height_f)
+        unpadded_features = current_height * current_width
+        newline_features = current_height

-        ratio = math.sqrt(current_height_f * current_width_f /
-                          (9 * npatches**2))
+        ratio = math.sqrt(current_height * current_width / (9 * npatches**2))
         if ratio > 1.1:
-            height_factor = int(current_height_f // ratio)
-            width_factor = int(current_width_f // ratio)
+            height_factor = int(current_height // ratio)
+            width_factor = int(current_width // ratio)
             unpadded_features = height_factor * width_factor
             newline_features = height_factor
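LLaVA-OneVision reuses the same unpadding arithmetic and then, as in the hunk above, caps the grid following LlavaOnevisionProcessor: when the unpadded area exceeds roughly nine base tiles (square-root ratio above 1.1), both dimensions are divided by that ratio. A short sketch of just this capping step (helper name and numbers are illustrative, not part of this commit):

    import math

    def cap_features_to_nine_tiles(
        current_height: int,
        current_width: int,
        npatches: int,
    ) -> tuple[int, int]:
        unpadded_features = current_height * current_width
        newline_features = current_height

        # Downscale when the unpadded grid is larger than about 9 * npatches**2.
        ratio = math.sqrt(current_height * current_width / (9 * npatches**2))
        if ratio > 1.1:
            height_factor = int(current_height // ratio)
            width_factor = int(current_width // ratio)
            unpadded_features = height_factor * width_factor
            newline_features = height_factor

        return (unpadded_features, newline_features)

    # A 120x96 grid with npatches=24 exceeds the cap, so both factors shrink.
    print(cap_features_to_nine_tiles(120, 96, 24))  # -> (5120, 80)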