Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-20 13:06:03 +08:00)
[Bugfix] Schedule failure due to wrong get_image_size_with_most_features (#29692)
parent 302b2c1eb9
commit f90319d5d1

tests/models/multimodal/processing/test_gemma3.py (new file, 42 lines)
@@ -0,0 +1,42 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import pytest
+
+from vllm.multimodal import MULTIMODAL_REGISTRY
+
+from ....conftest import ImageTestAssets
+from ...utils import build_model_context
+
+
+@pytest.mark.parametrize("model_id", ["google/gemma-3-4b-it"])
+def test_get_image_size_with_most_features(
+    image_assets: ImageTestAssets, model_id: str
+):
+    ctx = build_model_context(
+        model_id,
+        mm_processor_kwargs={"do_pan_and_scan": True},
+        limit_mm_per_prompt={"image": 1},
+    )
+    processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config)
+
+    hf_processor_mm_kwargs: dict[str, object] = {}
+    hf_processor = processor.info.get_hf_processor(**hf_processor_mm_kwargs)
+
+    max_image_size = processor.info.get_image_size_with_most_features()
+    max_tokens = processor.info.get_num_image_tokens(
+        image_width=max_image_size.width,
+        image_height=max_image_size.height,
+        processor=hf_processor,
+    )
+
+    prompt = "<start_of_image>"
+    image_seq_length = hf_processor.image_seq_length
+
+    for asset in image_assets:
+        mm_data = {"image": [asset.pil_image]}
+        processed_inputs = processor.apply(prompt, mm_data, hf_processor_mm_kwargs)
+        mm_kwargs_data = processed_inputs["mm_kwargs"].get_data()
+        num_patches_tensor = mm_kwargs_data["num_patches"]
+        tokens = int(num_patches_tensor.item()) * image_seq_length
+        assert tokens <= max_tokens
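The invariant this new test enforces reduces to a few lines of arithmetic. As a minimal sketch, assuming the Gemma3 processor defaults of 256 image tokens per crop (image_seq_length) and at most 4 pan-and-scan crops on top of the base image (both assumptions about the HF processor, not values stated in this diff):

# Sketch of the invariant checked by test_get_image_size_with_most_features.
image_seq_length = 256   # assumed hf_processor.image_seq_length for Gemma3
max_num_patches = 1 + 4  # assumed: base image + pan_and_scan_max_num_crops crops

# max_tokens is what get_num_image_tokens() reports for the "most features"
# size; it is the per-image budget the scheduler reserves.
max_tokens = max_num_patches * image_seq_length

# Every real image must fit inside that budget, whatever crop count it gets.
for num_patches in range(1, max_num_patches + 1):
    assert num_patches * image_seq_length <= max_tokens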
@@ -53,3 +53,38 @@ def test_processor_override(
     assert img_tok_count == expected_toks_per_img * num_imgs
     assert pixel_shape[0] == expected_pixels_shape[0] * num_imgs
     assert pixel_shape[1] == expected_pixels_shape[1]
+
+
+@pytest.mark.parametrize("model_id", ["Qwen/Qwen2-VL-2B-Instruct"])
+@pytest.mark.parametrize("max_pixels", [1280 * 28 * 28, 1283 * 28 * 28])
+def test_get_image_size_with_most_features(
+    image_assets: ImageTestAssets,
+    model_id: str,
+    max_pixels: int,
+):
+    ctx = build_model_context(
+        model_id,
+        mm_processor_kwargs={"max_pixels": max_pixels},
+        limit_mm_per_prompt={"image": 1},
+    )
+    processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config)
+
+    hf_processor_mm_kwargs: dict[str, object] = {}
+    hf_processor = processor.info.get_hf_processor(**hf_processor_mm_kwargs)
+    merge_size = processor.info.get_hf_config().vision_config.spatial_merge_size
+
+    max_image_size = processor.info.get_image_size_with_most_features()
+    max_tokens = processor.info.get_num_image_tokens(
+        image_width=max_image_size.width,
+        image_height=max_image_size.height,
+        image_processor=hf_processor.image_processor,
+    )
+
+    prompt = "<|vision_start|><|image_pad|><|vision_end|>"
+    for asset in image_assets:
+        mm_data = {"image": [asset.pil_image]}
+        processed_inputs = processor.apply(prompt, mm_data, hf_processor_mm_kwargs)
+        grid_thw = processed_inputs["mm_kwargs"].get_data()["image_grid_thw"].tolist()
+        t, h, w = grid_thw[0]
+        tokens = (t * h * w) // (merge_size**2)
+        assert tokens < max_tokens
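The two max_pixels parametrizations are deliberate. With patch_size * merge_size = 28, they correspond to budgets of 1280 and 1283 merged patches. 1280 factors evenly into 32 × 40, while 1283 is prime, so its only factor pair is 1 × 1283 (aspect ratio 1283 > 200), which forces the new implementation into its decrease-and-retry path. A quick check, mirroring the closest_factor_pair helper from the qwen2_vl.py hunk below:

import math

def closest_factor_pair(n: int) -> tuple[int, int]:
    # Returns (d, n // d) with d <= n // d, i.e. the squarest factor pair.
    for d in range(math.isqrt(n), 0, -1):
        if n % d == 0:
            return d, n // d
    return 1, n

assert closest_factor_pair(1280) == (32, 40)   # ratio 1.25 -> accepted immediately
assert closest_factor_pair(1283) == (1, 1283)  # prime -> ratio 1283 > 200, rejected
assert closest_factor_pair(1282) == (2, 641)   # ratio 320.5 > 200, rejected
assert closest_factor_pair(1281) == (21, 61)   # ratio ~2.9 -> accepted on retry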
@@ -237,8 +237,9 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
         )
         max_num_crops = images_kwargs["pan_and_scan_max_num_crops"]
 
-        # Result in the max possible feature size (h:w = max_num_crops:1)
-        return ImageSize(height=50 * max_num_crops, width=50)
+        vision_config = self.get_hf_config().vision_config
+        native_size = vision_config.image_size
+        return ImageSize(height=native_size * max_num_crops, width=native_size)
 
 
 class Gemma3DummyInputsBuilder(BaseDummyInputsBuilder[Gemma3ProcessingInfo]):
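The effect of this small change is easiest to see numerically. A minimal sketch, assuming vision_config.image_size = 896 for google/gemma-3-4b-it and the default pan_and_scan_max_num_crops = 4 (both assumed from the Gemma3 config, not stated in this diff):

# Assumed Gemma3 values, for illustration only.
max_num_crops = 4   # pan_and_scan_max_num_crops default
native_size = 896   # vision_config.image_size

# Old: hardcoded 50px edge -> advertised worst case of 200 x 50.
old_size = (50 * max_num_crops, 50)
# New: the vision tower's real input resolution -> 3584 x 896.
new_size = (native_size * max_num_crops, native_size)

# A 200x50 image is likely too small to trigger pan-and-scan crops, so the old
# value could understate the true per-image token maximum; a larger real image
# could then exceed the scheduler's per-image budget (the "schedule failure"
# in the commit title).
print(old_size, new_size)  # (200, 50) (3584, 896)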
@@ -25,6 +25,7 @@
 # limitations under the License.
 """Inference-only Qwen2-VL model compatible with HuggingFace weights."""
 
+import math
 from collections.abc import Callable, Iterable, Mapping, Sequence
 from functools import partial
 from typing import Annotated, Any, Literal, TypeAlias
@@ -959,13 +960,42 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo):
         return num_video_tokens
 
     def get_image_size_with_most_features(self) -> ImageSize:
-        max_image_size, _ = self._get_vision_info(
-            image_width=9999999,
-            image_height=9999999,
-            num_frames=1,
-            image_processor=None,
-        )
-        return max_image_size
+        # NOTE: Simply processing a huge size with _get_vision_info might not give a
+        # size that maximizes the number of features, i.e., the number of (merged)
+        # patches. This is because the number of patches limits the allowed aspect
+        # ratios. For example, suppose the maximum number of patches is 1280. A square
+        # image cannot be broken down into 1280 patches, so feeding a giant square image
+        # into _get_vision_info will not yield a size that maximizes the number of
+        # patches. Therefore, we directly factorize the maximum number of patches into
+        # height and width. The tricky part is to avoid extreme aspect ratios (>200 for
+        # qwen2-vl). If we can't find a suitable aspect ratio, we decrease the number of
+        # patches and retry. This is safe because the processor does not accept extreme
+        # aspect ratios, so there is no valid post-resize image with the number of
+        # patches that yields extreme aspect ratios.
+
+        hf_config = self.get_hf_config()
+        vision_config = hf_config.vision_config
+        patch_size = vision_config.patch_size
+        merge_size = vision_config.spatial_merge_size
+        image_processor = self.get_image_processor()
+        max_pixels = image_processor.max_pixels or image_processor.size["longest_edge"]
+        unit = patch_size * merge_size
+        max_seq_len = max_pixels // (unit * unit)
+
+        def closest_factor_pair(n: int) -> tuple[int, int]:
+            # left <= right
+            for d in range(math.isqrt(n), 0, -1):
+                if n % d == 0:
+                    return d, n // d
+            return 1, n
+
+        height_factor, width_factor = 1, max_seq_len
+        for seq_len in range(max_seq_len, 0, -1):
+            height_factor, width_factor = closest_factor_pair(seq_len)
+            if width_factor / height_factor <= 200:
+                break
+
+        return ImageSize(width=unit * width_factor, height=unit * height_factor)
 
     def get_max_image_tokens(self) -> int:
         target_width, target_height = self.get_image_size_with_most_features()
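Putting the new method together end to end: the removed code fed a huge square image into _get_vision_info, which for a budget of 1283 merged patches yields only roughly 35 × 35 ≈ 1225 patches, under-estimating the true maximum. A self-contained sketch of the factorization (same algorithm as above, with patch_size = 14 and spatial_merge_size = 2 assumed from the Qwen2-VL config) shows what the two test parametrizations produce:

import math

def closest_factor_pair(n: int) -> tuple[int, int]:
    for d in range(math.isqrt(n), 0, -1):
        if n % d == 0:
            return d, n // d
    return 1, n

def image_size_with_most_features(max_pixels: int, patch_size: int = 14,
                                  merge_size: int = 2) -> tuple[int, int]:
    # (width, height) maximizing merged patches under an aspect-ratio cap of 200.
    unit = patch_size * merge_size             # 28 px per merged-patch edge
    max_seq_len = max_pixels // (unit * unit)  # merged-patch budget
    height_factor, width_factor = 1, max_seq_len
    for seq_len in range(max_seq_len, 0, -1):
        height_factor, width_factor = closest_factor_pair(seq_len)
        if width_factor / height_factor <= 200:
            break
    return unit * width_factor, unit * height_factor

# 1280 merged patches: 1280 = 32 * 40 is accepted on the first try.
assert image_size_with_most_features(1280 * 28 * 28) == (28 * 40, 28 * 32)
# 1283 is prime and 1282 = 2 * 641 is still too skewed; 1281 = 21 * 61 works,
# so two patches of budget are sacrificed to keep the aspect ratio sane.
assert image_size_with_most_features(1283 * 28 * 28) == (28 * 61, 28 * 21)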