[Bugfix] Schedule failure due to wrong get_image_size_with_most_features (#29692)

Jaehwang Jung 2025-12-12 19:27:20 +09:00 committed by GitHub
parent 302b2c1eb9
commit f90319d5d1
4 changed files with 117 additions and 9 deletions
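
The failure mode, in outline: at profiling time vLLM sizes the multimodal encoder budget from get_image_size_with_most_features(); if that size under-reports the true worst case, a real request can decode to more image tokens than the reserved budget and can never be scheduled. A minimal sketch of the invariant this commit restores (illustrative names and numbers, not vLLM's actual scheduler API):

# Illustrative sketch only: the budget check the size estimate must satisfy.
def schedule(actual_tokens: int, profiled_max_tokens: int) -> None:
    # profiled_max_tokens is derived from get_image_size_with_most_features()
    # at startup; a request exceeding it can never fit the encoder budget.
    if actual_tokens > profiled_max_tokens:
        raise RuntimeError("schedule failure: request exceeds profiled budget")

schedule(actual_tokens=256, profiled_max_tokens=1280)   # fits
schedule(actual_tokens=1280, profiled_max_tokens=256)   # the reported failure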

@@ -0,0 +1,42 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import pytest
+
+from vllm.multimodal import MULTIMODAL_REGISTRY
+
+from ....conftest import ImageTestAssets
+from ...utils import build_model_context
+
+
+@pytest.mark.parametrize("model_id", ["google/gemma-3-4b-it"])
+def test_get_image_size_with_most_features(
+    image_assets: ImageTestAssets, model_id: str
+):
+    ctx = build_model_context(
+        model_id,
+        mm_processor_kwargs={"do_pan_and_scan": True},
+        limit_mm_per_prompt={"image": 1},
+    )
+    processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config)
+    hf_processor_mm_kwargs: dict[str, object] = {}
+    hf_processor = processor.info.get_hf_processor(**hf_processor_mm_kwargs)
+
+    max_image_size = processor.info.get_image_size_with_most_features()
+    max_tokens = processor.info.get_num_image_tokens(
+        image_width=max_image_size.width,
+        image_height=max_image_size.height,
+        processor=hf_processor,
+    )
+
+    prompt = "<start_of_image>"
+    image_seq_length = hf_processor.image_seq_length
+    for asset in image_assets:
+        mm_data = {"image": [asset.pil_image]}
+        processed_inputs = processor.apply(prompt, mm_data, hf_processor_mm_kwargs)
+        mm_kwargs_data = processed_inputs["mm_kwargs"].get_data()
+        num_patches_tensor = mm_kwargs_data["num_patches"]
+        tokens = int(num_patches_tensor.item()) * image_seq_length
+        assert tokens <= max_tokens

@@ -53,3 +53,38 @@ def test_processor_override(
     assert img_tok_count == expected_toks_per_img * num_imgs
     assert pixel_shape[0] == expected_pixels_shape[0] * num_imgs
     assert pixel_shape[1] == expected_pixels_shape[1]
+
+
+@pytest.mark.parametrize("model_id", ["Qwen/Qwen2-VL-2B-Instruct"])
+@pytest.mark.parametrize("max_pixels", [1280 * 28 * 28, 1283 * 28 * 28])
+def test_get_image_size_with_most_features(
+    image_assets: ImageTestAssets,
+    model_id: str,
+    max_pixels: int,
+):
+    ctx = build_model_context(
+        model_id,
+        mm_processor_kwargs={"max_pixels": max_pixels},
+        limit_mm_per_prompt={"image": 1},
+    )
+    processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config)
+    hf_processor_mm_kwargs: dict[str, object] = {}
+    hf_processor = processor.info.get_hf_processor(**hf_processor_mm_kwargs)
+    merge_size = processor.info.get_hf_config().vision_config.spatial_merge_size
+
+    max_image_size = processor.info.get_image_size_with_most_features()
+    max_tokens = processor.info.get_num_image_tokens(
+        image_width=max_image_size.width,
+        image_height=max_image_size.height,
+        image_processor=hf_processor.image_processor,
+    )
+
+    prompt = "<|vision_start|><|image_pad|><|vision_end|>"
+    for asset in image_assets:
+        mm_data = {"image": [asset.pil_image]}
+        processed_inputs = processor.apply(prompt, mm_data, hf_processor_mm_kwargs)
+        grid_thw = processed_inputs["mm_kwargs"].get_data()["image_grid_thw"].tolist()
+        t, h, w = grid_thw[0]
+        tokens = (t * h * w) // (merge_size**2)
+        assert tokens < max_tokens

@@ -237,8 +237,9 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
         )
         max_num_crops = images_kwargs["pan_and_scan_max_num_crops"]
-        # Result in the max possible feature size (h:w = max_num_crops:1)
-        return ImageSize(height=50 * max_num_crops, width=50)
+        vision_config = self.get_hf_config().vision_config
+        native_size = vision_config.image_size
+        return ImageSize(height=native_size * max_num_crops, width=native_size)


 class Gemma3DummyInputsBuilder(BaseDummyInputsBuilder[Gemma3ProcessingInfo]):
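
Why the old constant was wrong: each pan-and-scan crop is resized to the vision tower's native resolution, so the crop count, not the dummy image's pixel size, drives the feature count; worse, a 50 px wide dummy image can fall below HF's pan_and_scan_min_crop_size, producing zero crops and a far-too-small token maximum. A rough sketch under those assumptions (the 256 px minimum crop size, 4 max crops, 896 native size, and 256 tokens per image match google/gemma-3-4b-it's defaults but are hard-coded here; the real HF cropping rules are more involved):

# Rough sketch, not vLLM/HF code: token maximum implied by the old vs. new
# "image size with most features", assuming Gemma 3 defaults.
MIN_CROP_SIZE = 256      # assumed pan_and_scan_min_crop_size
MAX_NUM_CROPS = 4        # assumed pan_and_scan_max_num_crops
TOKENS_PER_IMAGE = 256   # assumed Gemma 3 image_seq_length

def rough_num_crops(width: int, height: int) -> int:
    # Pan-and-scan only splits when each crop keeps the minimum size
    # and the aspect ratio warrants more than one crop.
    if min(width, height) < MIN_CROP_SIZE:
        return 0
    ratio = round(max(width, height) / min(width, height))
    if ratio < 2:
        return 0
    return min(ratio, MAX_NUM_CROPS)

def rough_max_tokens(width: int, height: int) -> int:
    # One global image plus one image per crop.
    return (1 + rough_num_crops(width, height)) * TOKENS_PER_IMAGE

print(rough_max_tokens(50, 50 * MAX_NUM_CROPS))    # old estimate: 256 tokens
print(rough_max_tokens(896, 896 * MAX_NUM_CROPS))  # new estimate: 1280 tokens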

@@ -25,6 +25,7 @@
 # limitations under the License.
 """Inference-only Qwen2-VL model compatible with HuggingFace weights."""

+import math
 from collections.abc import Callable, Iterable, Mapping, Sequence
 from functools import partial
 from typing import Annotated, Any, Literal, TypeAlias
@@ -959,13 +960,42 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo):
         return num_video_tokens

     def get_image_size_with_most_features(self) -> ImageSize:
-        max_image_size, _ = self._get_vision_info(
-            image_width=9999999,
-            image_height=9999999,
-            num_frames=1,
-            image_processor=None,
-        )
-        return max_image_size
+        # NOTE: Simply processing a huge size with _get_vision_info might not give
+        # a size that maximizes the number of features, i.e., the number of
+        # (merged) patches. This is because the number of patches limits the
+        # allowed aspect ratios. For example, suppose the maximum number of
+        # patches is 1280. A square image cannot be broken down into 1280
+        # patches, so feeding a giant square image into _get_vision_info will not
+        # yield a size that maximizes the number of patches. Therefore, we
+        # directly factorize the maximum number of patches into height and width.
+        # The tricky part is to avoid extreme aspect ratios (>200 for qwen2-vl).
+        # If we can't find a suitable aspect ratio, we decrease the number of
+        # patches and retry. This is safe because the processor does not accept
+        # extreme aspect ratios, so there is no valid post-resize image whose
+        # patch count would force an extreme aspect ratio.
+        hf_config = self.get_hf_config()
+        vision_config = hf_config.vision_config
+        patch_size = vision_config.patch_size
+        merge_size = vision_config.spatial_merge_size
+        image_processor = self.get_image_processor()
+        max_pixels = image_processor.max_pixels or image_processor.size["longest_edge"]
+        unit = patch_size * merge_size
+        max_seq_len = max_pixels // (unit * unit)
+
+        def closest_factor_pair(n: int) -> tuple[int, int]:
+            # Returns (d, n // d) with d <= n // d and d as large as possible.
+            for d in range(math.isqrt(n), 0, -1):
+                if n % d == 0:
+                    return d, n // d
+            return 1, n
+
+        height_factor, width_factor = 1, max_seq_len
+        for seq_len in range(max_seq_len, 0, -1):
+            height_factor, width_factor = closest_factor_pair(seq_len)
+            if width_factor / height_factor <= 200:
+                break
+        return ImageSize(width=unit * width_factor, height=unit * height_factor)

     def get_max_image_tokens(self) -> int:
         target_width, target_height = self.get_image_size_with_most_features()
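
As a sanity check of the search above, here is a standalone sketch run with the two max_pixels values from the new Qwen2-VL test (patch_size=14 and spatial_merge_size=2 are Qwen2-VL-2B's config values, assumed and hard-coded here):

# Standalone sketch of the factorization search, not the vLLM method itself.
import math

def closest_factor_pair(n: int) -> tuple[int, int]:
    for d in range(math.isqrt(n), 0, -1):
        if n % d == 0:
            return d, n // d
    return 1, n

def size_with_most_features(max_pixels: int, patch_size: int = 14,
                            merge_size: int = 2) -> tuple[int, int]:
    unit = patch_size * merge_size            # 28 px per merged patch
    max_seq_len = max_pixels // (unit * unit)
    for seq_len in range(max_seq_len, 0, -1):
        h, w = closest_factor_pair(seq_len)
        if w / h <= 200:                      # reject extreme aspect ratios
            return unit * w, unit * h         # (width, height)
    return unit, unit

# 1280 merged patches factor as 32 x 40 -> 1120 x 896
print(size_with_most_features(1280 * 28 * 28))
# 1283 is prime (1:1283 is too extreme), as is 641 in 1282 = 2 x 641;
# the search settles on 1281 = 21 x 61 -> 1708 x 588
print(size_with_most_features(1283 * 28 * 28))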