[Bugfix] Comprehensively test and fix LLaVA-NeXT feature size calculation (#11800)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Cyrus Leung 2025-01-07 18:25:02 +08:00 committed by GitHub
parent 8082ad7950
commit 8f37be38eb
6 changed files with 257 additions and 93 deletions

View File

@@ -13,6 +13,7 @@ einops # required for MPT, qwen-vl and Mamba
httpx
librosa # required for audio tests
peft
pqdm
ray[adag]==2.40.0
sentence-transformers # required for embedding tests
soundfile # required for audio tests
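pqdm is added as a test-only dependency: the reworked LLaVA-NeXT / LLaVA-Onevision processor tests below use it to validate many image sizes across worker threads with a progress bar. A minimal usage sketch of that pattern (the worker function and inputs here are made up for illustration, not taken from the commit):

from pqdm.threads import pqdm

def check(n: int) -> int:
    # stand-in for the per-image-size validation done in the tests below
    return n * n

# fan the checks out over 8 threads; extra kwargs such as `desc` are
# forwarded to the underlying tqdm progress bar
results = pqdm(range(10), check, n_jobs=8, desc="Validating")
print(results)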

View File

@@ -48,6 +48,8 @@ botocore==1.35.57
# awscli
# boto3
# s3transfer
bounded-pool-executor==0.0.3
# via pqdm
buildkite-test-collector==0.1.9
# via -r requirements-test.in
certifi==2024.8.30
@@ -342,6 +344,8 @@ pooch==1.8.2
# via librosa
portalocker==2.10.1
# via sacrebleu
pqdm==0.2.0
# via -r requirements-test.in
propcache==0.2.0
# via yarl
protobuf==5.28.3

View File

@@ -1,8 +1,13 @@
import itertools
from functools import partial
import pytest
from PIL import Image
from pqdm.threads import pqdm
from transformers import AutoTokenizer
from vllm.inputs import InputProcessingContext
from vllm.multimodal.parse import ImageSize
from ....utils import build_model_context
@@ -15,20 +20,69 @@ def processor_for_llava_next():
return LlavaNextMultiModalProcessor
def _validate_image_prompt_replacements_one(
processor,
num_imgs: int,
failed_size_excs: list[tuple[ImageSize, Exception]],
image_size: ImageSize,
) -> None:
prompt = "<image>" * num_imgs
image = Image.new("RGB", size=image_size)
mm_data = {"image": [image] * num_imgs}
try:
# The processor will throw an error if there is a mismatch
# in the prompt replacements
processed_inputs = processor.apply(prompt, mm_data, {})
image_placeholders = processed_inputs["mm_placeholders"]["image"]
assert len(image_placeholders) == num_imgs
first_placeholder = image_placeholders[0]
# NOTE: There is a BOS token
assert first_placeholder["offset"] == 1
assert first_placeholder["length"] == (
len(processed_inputs["prompt_token_ids"]) - 1) // num_imgs
except Exception as exc:
failed_size_excs.append((image_size, exc))
def _test_image_prompt_replacements(
processor,
*,
num_imgs: int,
image_sizes: list[ImageSize],
) -> None:
"""
Ensure LlavaNextMultiModalProcessor
handles prompt replacement properly for input images.
"""
failed_size_excs = list[tuple[ImageSize, Exception]]()
validate_one = partial(
_validate_image_prompt_replacements_one,
processor,
num_imgs,
failed_size_excs,
)
pqdm(image_sizes, validate_one, n_jobs=8, desc="Validating image sizes")
if failed_size_excs:
msg = "Found failing image sizes:" \
+ "\n========\n".join(f"[{size}]\n{exc}"
for size, exc in failed_size_excs)
raise AssertionError(msg)
@pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"])
@pytest.mark.parametrize("image_size", [(1669, 2560), (2560, 1669), (183, 488),
(488, 183), (198, 176), (176, 198),
(161, 184), (184, 161)])
@pytest.mark.parametrize("num_imgs", [1, 2])
def test_processor_prompt_replacements(
def test_processor_prompt_replacements_regression(
processor_for_llava_next,
model_id: str,
image_size: tuple[int, int],
num_imgs: int,
):
"""
Ensure LlavaNextMultiModalProcessor handles prompt replacement properly.
"""
ctx = build_model_context(
model_name=model_id,
tokenizer_name=model_id,
@@ -37,22 +91,55 @@ def test_processor_prompt_replacements(
)
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
ctx = InputProcessingContext(ctx.model_config, tokenizer)
# Build the image str / prompt based on the number of images we pass
prompt = "<image>" * num_imgs
mm_data = {"image": [Image.new("RGB", size=image_size)] * num_imgs}
# The processor will throw an error if there is a mismatch
# in the prompt replacements
processor = processor_for_llava_next(ctx)
processed_inputs = processor.apply(prompt, mm_data, {})
image_placeholders = processed_inputs["mm_placeholders"]["image"]
assert len(image_placeholders) == num_imgs
image_ratios = [(171, 152), (184, 161), (198, 176), (333, 296), (369, 328),
(488, 183), (2560, 1669)]
image_sizes = [
size for w, h in image_ratios
for size in [ImageSize(w, h), ImageSize(h, w)]
]
first_placeholder = image_placeholders[0]
_test_image_prompt_replacements(
processor,
num_imgs=num_imgs,
image_sizes=image_sizes,
)
# NOTE: There is a BOS token
assert first_placeholder["offset"] == 1
assert first_placeholder["length"] == (
len(processed_inputs["prompt_token_ids"]) - 1) // num_imgs
@pytest.mark.skip("This test takes around 2 hours to run. "
"Comment this out to run it manually.")
@pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"])
@pytest.mark.parametrize("num_imgs", [1])
def test_processor_prompt_replacements_all(
processor_for_llava_next,
model_id: str,
num_imgs: int,
):
ctx = build_model_context(
model_name=model_id,
tokenizer_name=model_id,
mm_processor_kwargs=None,
limit_mm_per_prompt={"image": num_imgs},
)
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
ctx = InputProcessingContext(ctx.model_config, tokenizer)
processor = processor_for_llava_next(ctx)
seen_aspect_ratios = set[float]()
image_sizes = list[ImageSize]()
# The aspect ratio of the grid layout is between 1 and 2
# NOTE: Assumes that feature size calculation is the same if we
# swap the width and height of the image
for w, h in itertools.product(range(64, 1024), repeat=2):
aspect_ratio = w / h
if 1 <= aspect_ratio <= 2 and aspect_ratio not in seen_aspect_ratios:
image_sizes.append(ImageSize(w, h))
seen_aspect_ratios.add(aspect_ratio)
_test_image_prompt_replacements(
processor,
num_imgs=num_imgs,
image_sizes=image_sizes,
)

View File

@@ -1,8 +1,13 @@
import itertools
from functools import partial
import pytest
from PIL import Image
from pqdm.threads import pqdm
from transformers import AutoTokenizer
from vllm.inputs import InputProcessingContext
from vllm.multimodal.parse import ImageSize
from ....utils import build_model_context
@@ -15,22 +20,68 @@ def processor_for_llava_onevision():
return LlavaOnevisionMultiModalProcessor
def _validate_image_prompt_replacements_one(
processor,
num_imgs: int,
failed_size_excs: list[tuple[ImageSize, Exception]],
image_size: ImageSize,
) -> None:
prompt = "<image>" * num_imgs
image = Image.new("RGB", size=image_size)
mm_data = {"image": [image] * num_imgs}
try:
# The processor will throw an error if there is a mismatch
# in the prompt replacements
processed_inputs = processor.apply(prompt, mm_data, {})
image_placeholders = processed_inputs["mm_placeholders"]["image"]
assert len(image_placeholders) == num_imgs
first_placeholder = image_placeholders[0]
assert first_placeholder["offset"] == 0
assert first_placeholder["length"] == len(
processed_inputs["prompt_token_ids"]) // num_imgs
except Exception as exc:
failed_size_excs.append((image_size, exc))
def _test_image_prompt_replacements(
processor,
*,
num_imgs: int,
image_sizes: list[ImageSize],
) -> None:
"""
Ensure LlavaOnevisionMultiModalProcessor
handles prompt replacement properly for input images.
"""
failed_size_excs = list[tuple[ImageSize, Exception]]()
validate_one = partial(
_validate_image_prompt_replacements_one,
processor,
num_imgs,
failed_size_excs,
)
pqdm(image_sizes, validate_one, n_jobs=8, desc="Validating image sizes")
if failed_size_excs:
msg = "Found failing image sizes:" \
+ "\n========\n".join(f"[{size}]\n{exc}"
for size, exc in failed_size_excs)
raise AssertionError(msg)
@pytest.mark.parametrize("model_id",
["llava-hf/llava-onevision-qwen2-0.5b-ov-hf"])
@pytest.mark.parametrize("image_size", [(1669, 2560), (2560, 1669), (183, 488),
(488, 183), (198, 176), (176, 198),
(161, 184), (184, 161)])
@pytest.mark.parametrize("num_imgs", [1, 2])
def test_processor_prompt_replacements(
def test_processor_prompt_replacements_regression(
processor_for_llava_onevision,
model_id: str,
image_size: tuple[int, int],
num_imgs: int,
):
"""
Ensure LlavaOnevisionMultiModalProcessor handles prompt replacement
properly.
"""
ctx = build_model_context(
model_name=model_id,
tokenizer_name=model_id,
@@ -39,22 +90,56 @@ def test_processor_prompt_replacements(
)
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
ctx = InputProcessingContext(ctx.model_config, tokenizer)
# Build the image str / prompt based on the number of images we pass
prompt = "<image>" * num_imgs
mm_data = {"image": [Image.new("RGB", size=image_size)] * num_imgs}
# The processor will throw an error if there is a mismatch
# in the prompt replacements
processor = processor_for_llava_onevision(ctx)
processed_inputs = processor.apply(prompt, mm_data, {})
image_placeholders = processed_inputs["mm_placeholders"]["image"]
assert len(image_placeholders) == num_imgs
image_ratios = [(171, 152), (184, 161), (198, 176), (333, 296), (369, 328),
(488, 183), (2560, 1669)]
image_sizes = [
size for w, h in image_ratios
for size in [ImageSize(w, h), ImageSize(h, w)]
]
first_placeholder = image_placeholders[0]
_test_image_prompt_replacements(
processor,
num_imgs=num_imgs,
image_sizes=image_sizes,
)
# NOTE: There is a BOS token
assert first_placeholder["offset"] == 0
assert first_placeholder["length"] == len(
processed_inputs["prompt_token_ids"]) // num_imgs
@pytest.mark.skip("This test takes around 2 hours to run. "
"Comment this out to run it manually.")
@pytest.mark.parametrize("model_id",
["llava-hf/llava-onevision-qwen2-0.5b-ov-hf"])
@pytest.mark.parametrize("num_imgs", [1])
def test_processor_prompt_replacements_all(
processor_for_llava_onevision,
model_id: str,
num_imgs: int,
):
ctx = build_model_context(
model_name=model_id,
tokenizer_name=model_id,
mm_processor_kwargs=None,
limit_mm_per_prompt={"image": num_imgs},
)
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
ctx = InputProcessingContext(ctx.model_config, tokenizer)
processor = processor_for_llava_onevision(ctx)
seen_aspect_ratios = set[float]()
image_sizes = list[ImageSize]()
# The aspect ratio of the grid layout is between 1 and 6
# NOTE: Assumes that feature size calculation is the same if we
# swap the width and height of the image
for w, h in itertools.product(range(64, 1024), repeat=2):
aspect_ratio = w / h
if 1 <= aspect_ratio <= 6 and aspect_ratio not in seen_aspect_ratios:
image_sizes.append(ImageSize(w, h))
seen_aspect_ratios.add(aspect_ratio)
_test_image_prompt_replacements(
processor,
num_imgs=num_imgs,
image_sizes=image_sizes,
)

View File

@@ -2,7 +2,6 @@ from functools import cached_property
from typing import (Final, Iterable, List, Literal, Mapping, Optional,
Protocol, Set, Tuple, TypedDict, Union)
import numpy as np
import torch
import torch.nn as nn
from transformers import BatchFeature, LlavaNextConfig, LlavaNextProcessor
@@ -74,7 +73,7 @@ class LlavaNextProcessingMixin(BaseLlavaProcessingMixin):
def _get_hf_processor(self):
return self.ctx.get_hf_processor(LlavaNextProcessor)
# Based on: https://github.com/huggingface/text-generation-inference/blob/v2.2.0/server/text_generation_server/models/vlm_causal_lm.py#L106
# Based on: https://github.com/huggingface/text-generation-inference/blob/v3.0.1/server/text_generation_server/models/vlm_causal_lm.py#L113
def _get_num_image_tokens(
self,
*,
@@ -111,7 +110,7 @@ class LlavaNextProcessingMixin(BaseLlavaProcessingMixin):
return unpadded_feature_size + newline_feature_size + base_feature_size
# Based on: https://github.com/huggingface/text-generation-inference/blob/v2.2.0/server/text_generation_server/models/vlm_causal_lm.py#L79
# Based on: https://github.com/huggingface/text-generation-inference/blob/v3.0.1/server/text_generation_server/models/vlm_causal_lm.py#L86
def _get_num_unpadded_features(
self,
*,
@@ -121,29 +120,23 @@ class LlavaNextProcessingMixin(BaseLlavaProcessingMixin):
num_patch_height: int,
num_patch_width: int,
) -> tuple[int, int]:
# NOTE: Use float32 to remain consistent with HF output
current_height_f = np.float32(npatches * num_patch_height)
current_width_f = np.float32(npatches * num_patch_width)
current_height = npatches * num_patch_height
current_width = npatches * num_patch_width
original_width_f = np.float32(original_width)
original_height_f = np.float32(original_height)
aspect_ratio = original_width / original_height
current_aspect_ratio = current_width / current_height
original_aspect_ratio = original_width_f / original_height_f
current_aspect_ratio = current_width_f / current_height_f
if original_aspect_ratio > current_aspect_ratio:
scale_factor = current_width_f / original_width_f
new_height = int(original_height_f * scale_factor)
padding = (current_height_f - new_height) // 2
current_height_f -= 2 * padding
if aspect_ratio > current_aspect_ratio:
new_height = (original_height * current_width) // original_width
padding = (current_height - new_height) // 2
current_height = current_height - (2 * padding)
else:
scale_factor = current_height_f / original_height_f
new_width = int(original_width_f * scale_factor)
padding = (current_width_f - new_width) // 2
current_width_f -= 2 * padding
new_width = (original_width * current_height) // original_height
padding = (current_width - new_width) // 2
current_width = current_width - (2 * padding)
unpadded_features = int(current_height_f * current_width_f)
newline_features = int(current_height_f)
unpadded_features = current_height * current_width
newline_features = current_height
return (unpadded_features, newline_features)
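To make the fix above concrete, here is a standalone sketch that mirrors the new float32-based calculation, plus an illustrative call. The parameters in the example call (npatches=24, i.e. a 336px ViT with 14px patches, and a grid 2 tiles wide by 1 tile tall) are assumptions for illustration only; which grid the HF preprocessor actually selects for a given image size is not shown here.

import numpy as np

def get_num_unpadded_features_fixed(original_height: int, original_width: int,
                                    npatches: int, num_patch_height: int,
                                    num_patch_width: int) -> tuple[int, int]:
    # NOTE: float32 throughout to remain consistent with the HF output
    current_height_f = np.float32(npatches * num_patch_height)
    current_width_f = np.float32(npatches * num_patch_width)

    original_width_f = np.float32(original_width)
    original_height_f = np.float32(original_height)

    original_aspect_ratio = original_width_f / original_height_f
    current_aspect_ratio = current_width_f / current_height_f

    if original_aspect_ratio > current_aspect_ratio:
        # image is wider than the tile grid: padding is on top/bottom
        scale_factor = current_width_f / original_width_f
        new_height = int(original_height_f * scale_factor)
        padding = (current_height_f - new_height) // 2
        current_height_f -= 2 * padding
    else:
        # image is taller than the tile grid: padding is on left/right
        scale_factor = current_height_f / original_height_f
        new_width = int(original_width_f * scale_factor)
        padding = (current_width_f - new_width) // 2
        current_width_f -= 2 * padding

    unpadded_features = int(current_height_f * current_width_f)
    newline_features = int(current_height_f)
    return unpadded_features, newline_features

# Illustrative call with assumed grid parameters: 183x488 image,
# 24 patches per tile, grid of 1 tile high x 2 tiles wide
print(get_num_unpadded_features_fixed(183, 488, 24, 1, 2))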

View File

@@ -3,7 +3,6 @@ from functools import cached_property
from typing import (Final, Iterable, List, Literal, Mapping, Optional,
Protocol, Set, Tuple, TypedDict, Union)
import numpy as np
import torch
import torch.nn as nn
from transformers import (BatchFeature, LlavaOnevisionConfig,
@@ -98,6 +97,8 @@ class LlavaOnevisionProcessingMixin(LlavaNextProcessingMixin):
def _get_hf_processor(self):
return self.ctx.get_hf_processor(LlavaOnevisionProcessor)
# Based on: https://github.com/huggingface/text-generation-inference/blob/v3.0.1/server/text_generation_server/models/vlm_causal_lm.py#L86
# with additional logic afterwards taken from LlavaOnevisionProcessor
def _get_num_unpadded_features(
self,
*,
@@ -107,35 +108,28 @@ class LlavaOnevisionProcessingMixin(LlavaNextProcessingMixin):
num_patch_height: int,
num_patch_width: int,
) -> tuple[int, int]:
# NOTE: Use float32 to remain consistent with HF output
current_height_f = np.float32(npatches * num_patch_height)
current_width_f = np.float32(npatches * num_patch_width)
current_height = npatches * num_patch_height
current_width = npatches * num_patch_width
original_width_f = np.float32(original_width)
original_height_f = np.float32(original_height)
aspect_ratio = original_width / original_height
current_aspect_ratio = current_width / current_height
original_aspect_ratio = original_width_f / original_height_f
current_aspect_ratio = current_width_f / current_height_f
if original_aspect_ratio > current_aspect_ratio:
scale_factor = current_width_f / original_width_f
new_height = int(original_height_f * scale_factor)
padding = (current_height_f - new_height) // 2
current_height_f -= 2 * padding
if aspect_ratio > current_aspect_ratio:
new_height = (original_height * current_width) // original_width
padding = (current_height - new_height) // 2
current_height = current_height - (2 * padding)
else:
scale_factor = current_height_f / original_height_f
new_width = int(original_width_f * scale_factor)
padding = (current_width_f - new_width) // 2
current_width_f -= 2 * padding
new_width = (original_width * current_height) // original_height
padding = (current_width - new_width) // 2
current_width = current_width - (2 * padding)
unpadded_features = int(current_height_f * current_width_f)
newline_features = int(current_height_f)
unpadded_features = current_height * current_width
newline_features = current_height
ratio = math.sqrt(current_height_f * current_width_f /
(9 * npatches**2))
ratio = math.sqrt(current_height * current_width / (9 * npatches**2))
if ratio > 1.1:
height_factor = int(current_height_f // ratio)
width_factor = int(current_width_f // ratio)
height_factor = int(current_height // ratio)
width_factor = int(current_width // ratio)
unpadded_features = height_factor * width_factor
newline_features = height_factor
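For reference, a minimal sketch (an assumed helper, not code from the commit) of the Onevision-specific tail of this method: when the unpadded token grid is more than about 10% larger than a 3x3 arrangement of base tiles, both dimensions are scaled down by the same ratio before counting features.

import math
import numpy as np

def apply_onevision_cap(current_height_f: np.float32,
                        current_width_f: np.float32,
                        npatches: int) -> tuple[int, int]:
    unpadded_features = int(current_height_f * current_width_f)
    newline_features = int(current_height_f)

    # Scale down when the grid exceeds ~1.1x the area of a 3x3 tile layout
    ratio = math.sqrt(current_height_f * current_width_f / (9 * npatches**2))
    if ratio > 1.1:
        height_factor = int(current_height_f // ratio)
        width_factor = int(current_width_f // ratio)
        unpadded_features = height_factor * width_factor
        newline_features = height_factor

    return unpadded_features, newline_features

# Illustrative call: a grid 6 tiles wide by 4 tiles tall at npatches=24
# gives a 96x144 (HxW) token grid, which exceeds the cap, so both factors
# are reduced
print(apply_onevision_cap(np.float32(96), np.float32(144), 24))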