[VLM] Support pan-and-scan for Gemma3 multi-modal processor (#14672)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: Roger Wang <ywang@roblox.com>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Co-authored-by: Roger Wang <ywang@roblox.com>
Cyrus Leung 2025-03-13 17:23:12 +08:00 committed by GitHub
parent a73122de96
commit 382403921f
9 changed files with 315 additions and 81 deletions

View File

@@ -763,7 +763,7 @@ See [this page](#generative-models) for more information on how to use generativ
* `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc.
* ✅︎
* ✅︎
* ✅︎\*
* ⚠️
- * `GLM4VForCausalLM`<sup>^</sup>
* GLM-4V
* T + I
@@ -856,12 +856,12 @@ See [this page](#generative-models) for more information on how to use generativ
* ✅︎
* ✅︎
- * `PaliGemmaForConditionalGeneration`
* PaliGemma ⚠️, PaliGemma 2 ⚠️
* PaliGemma, PaliGemma 2
* T + I<sup>E</sup>
* `google/paligemma-3b-pt-224`, `google/paligemma-3b-mix-224`, `google/paligemma2-3b-ft-docci-448`, etc.
*
* ✅︎
* ✅︎
* ⚠️
- * `Phi3VForCausalLM`
* Phi-3-Vision, Phi-3.5-Vision
* T + I<sup>E+</sup>
@@ -926,34 +926,15 @@ See [this page](#generative-models) for more information on how to use generativ
<sup>E</sup> Pre-computed embeddings can be inputted for this modality.
<sup>+</sup> Multiple items can be inputted per text prompt for this modality.
:::{warning}
vLLM does not currently support the PrefixLM attention mask, so our PaliGemma implementation uses regular causal attention, which causes the model output to be unstable.
We may deprecate this model series in a future release.
:::
:::{note}
`h2oai/h2ovl-mississippi-2b` will be available in V1 once we support backends other than FlashAttention.
:::
:::{note}
To use `TIGER-Lab/Mantis-8B-siglip-llama3`, you have to pass `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'` when running vLLM.
:::
:::{note}
The official `openbmb/MiniCPM-V-2` doesn't work yet, so we need to use a fork (`HwwwH/MiniCPM-V-2`) for now.
For more details, please see: <gh-pr:4087#issuecomment-2250397630>
:::
:::{note}
To use Qwen2.5-VL series models, you have to install the Hugging Face Transformers library from source via `pip install git+https://github.com/huggingface/transformers`.
:::
:::{note}
:::{important}
To use Gemma3 series models, you have to install the Hugging Face Transformers library from source via
`pip install git+https://github.com/huggingface/transformers`.
The earliest commit that supports this is [`50d3530aa04e7a7d003e6b255a98f79fd0447357`](https://github.com/huggingface/transformers/commit/50d3530aa04e7a7d003e6b255a98f79fd0447357).
Pan-and-scan image pre-processing is currently supported on V0 (but not V1).
You can enable it by passing `--mm-processor-kwargs '{"do_pan_and_scan": true}'`.
:::
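For offline inference, the same flag can be passed programmatically through `mm_processor_kwargs`; a minimal sketch (the model name and prompt are illustrative, and as noted above pan-and-scan currently only takes effect on V0):

```python
from vllm import LLM, SamplingParams
from PIL import Image

# Illustrative sketch: enable Gemma 3 pan-and-scan for offline inference (V0 only).
llm = LLM(
    model="google/gemma-3-4b-it",
    max_model_len=2048,
    mm_processor_kwargs={"do_pan_and_scan": True},  # default is False
)

image = Image.open("example.jpg")  # hypothetical local image
prompt = ("<bos><start_of_turn>user\n"
          "<start_of_image>Describe this image.<end_of_turn>\n"
          "<start_of_turn>model\n")
outputs = llm.generate(
    {"prompt": prompt, "multi_modal_data": {"image": image}},
    SamplingParams(max_tokens=64),
)
print(outputs[0].outputs[0].text)
```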
:::{warning}
Both V0 and V1 support `Gemma3ForConditionalGeneration` for text-only inputs.
However, there are differences in how they handle text + image inputs:
@@ -969,9 +950,23 @@ V1 currently uses a simplified attention pattern:
- Will be updated in the future to support the correct behavior
This limitation exists because the model's mixed attention pattern (bidirectional for images, causal otherwise) is not yet supported by vLLM's attention backends.
:::
Additionally, vLLM's current Gemma 3 implementation does not support the pan-and-scan image pre-processing algorithm, which helps handle images with skewed aspect ratios by intelligently cropping them into multiple views.
Without this feature, model performance may degrade when processing images that deviate significantly from square dimensions.
:::{note}
`h2oai/h2ovl-mississippi-2b` will be available in V1 once we support backends other than FlashAttention.
:::
:::{note}
To use `TIGER-Lab/Mantis-8B-siglip-llama3`, you have to pass `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'` when running vLLM.
:::
:::{note}
The official `openbmb/MiniCPM-V-2` doesn't work yet, so we need to use a fork (`HwwwH/MiniCPM-V-2`) for now.
For more details, please see: <gh-pr:4087#issuecomment-2250397630>
:::
:::{warning}
Our PaliGemma implementations have the same problem as Gemma 3 (see above) for both V0 and V1.
:::
### Pooling Models

View File

@@ -123,10 +123,14 @@ def run_gemma3(questions: list[str], modality: str):
assert modality == "image"
model_name = "google/gemma-3-4b-it"
llm = LLM(model=model_name,
llm = LLM(
model=model_name,
max_model_len=2048,
max_num_seqs=2,
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
# Default is False; setting it to True is not supported in V1 yet
mm_processor_kwargs={"do_pan_and_scan": True},
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
)
prompts = [("<bos><start_of_turn>user\n"
f"<start_of_image>{question}<end_of_turn>\n"

View File

@@ -83,10 +83,14 @@ def load_deepseek_vl2(question: str, image_urls: list[str]):
def load_gemma3(question, image_urls: list[str]) -> ModelRequestData:
model_name = "google/gemma-3-4b-it"
llm = LLM(model=model_name,
llm = LLM(
model=model_name,
max_model_len=8192,
max_num_seqs=2,
limit_mm_per_prompt={"image": len(image_urls)})
# Default is False; setting it to True is not supported in V1 yet
mm_processor_kwargs={"do_pan_and_scan": True},
limit_mm_per_prompt={"image": len(image_urls)},
)
placeholders = [{"type": "image", "image": url} for url in image_urls]
messages = [{

View File

@@ -9,7 +9,7 @@ from pathlib import PosixPath
import pytest
from packaging.version import Version
from transformers import AutoModelForVision2Seq
from transformers import AutoModelForPreTraining, AutoModelForVision2Seq
from transformers import __version__ as TRANSFORMERS_VERSION
from vllm.platforms import current_platform
@@ -234,6 +234,23 @@ VLM_TEST_SETTINGS = {
num_logprobs=10,
image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)],
),
"gemma3": VLMTestInfo(
models=["google/gemma-3-4b-it"],
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
prompt_formatter=lambda img_prompt: f"<bos><start_of_turn>user\n{img_prompt}<end_of_turn>\n<start_of_turn>model\n", # noqa: E501
single_image_prompts=IMAGE_ASSETS.prompts({
"stop_sign": "<start_of_image>What's the content in the center of the image?", # noqa: E501
"cherry_blossom": "<start_of_image>What is the season?", # noqa: E501
}),
multi_image_prompt="<start_of_image><start_of_image>Describe the two images in detail.", # noqa: E501
max_model_len=4096,
max_num_seqs=2,
# TODO: Use AutoModelForVision2Seq once transformers supports this
auto_cls=AutoModelForPreTraining,
dtype="bfloat16",
vllm_runner_kwargs={"mm_processor_kwargs": {"do_pan_and_scan": True}},
patch_hf_runner=model_utils.gemma3_patch_hf_runner,
),
"glm4v": VLMTestInfo(
models=["THUDM/glm-4v-9b"],
test_type=VLMTestType.IMAGE,

View File

@@ -304,6 +304,18 @@ def deepseekvl2_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
return hf_model
def gemma3_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
"""Patches and returns an instance of the HfRunner to use for Gemma 3."""
hf_processor = hf_model.processor
def processor(*args, **kwargs):
return hf_processor(*args, do_pan_and_scan=True, **kwargs)
hf_model.processor = processor
return hf_model
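The closure above could equally be written with `functools.partial`; a tiny equivalent sketch (the only behavioural difference is that `partial` lets a later call override the keyword instead of raising):

```python
import functools

# Equivalent sketch: always pass do_pan_and_scan=True to the HF processor.
hf_model.processor = functools.partial(hf_model.processor, do_pan_and_scan=True)
```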
def glm_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
"""Patches and returns an instance of the HfRunner to use for GLM4."""
hf_processor = hf_model.processor

View File

@@ -348,7 +348,11 @@ class InputRegistry:
dummy_factory = self._get_dummy_data_factory(model_cls)
mm_counts = mm_registry.get_mm_limits_per_prompt(model_config)
mm_processor_kwargs = get_allowed_kwarg_only_overrides(
dummy_factory, overrides=model_config.mm_processor_kwargs)
dummy_factory,
overrides=model_config.mm_processor_kwargs,
requires_kw_only=False,
allow_var_kwargs=True,
)
dummy_data = dummy_factory(InputContext(model_config), seq_len,
_MultiModalCounts(mm_counts),
@@ -381,6 +385,7 @@ class InputRegistry:
self,
ctx: InputContext,
inputs: ProcessorInputs,
**kwargs: object,
) -> ProcessorInputs:
"""The default input processor is a no-op."""
return inputs
@@ -447,6 +452,8 @@ class InputRegistry:
model_config.mm_processor_kwargs,
inputs.get("mm_processor_kwargs", {}), # type: ignore
processor,
requires_kw_only=False,
allow_var_kwargs=True,
)
processed_inputs = processor(

View File

@@ -1,10 +1,12 @@
# SPDX-License-Identifier: Apache-2.0
import math
from typing import (Any, Iterable, Literal, Mapping, Optional, Sequence, Set,
Tuple, TypedDict, Union)
import torch
from torch import nn
from transformers import BatchFeature, Gemma3Config, ProcessorMixin
from transformers import BatchFeature, Gemma3Config, Gemma3Processor
from transformers.models.gemma3.processing_gemma3 import Gemma3ProcessorKwargs
from vllm.config import VllmConfig
from vllm.logger import init_logger
@@ -14,10 +16,11 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
NestedTensors)
from vllm.multimodal.parse import ImageSize, MultiModalDataItems
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
PromptUpdate, PromptUpdateDetails)
PromptUpdate, encode_tokens)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
@@ -31,8 +34,15 @@ logger = init_logger(__name__)
class Gemma3ImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
data: torch.Tensor
"""Shape: `(batch_size * num_images, num_channels, height, width)`"""
pixel_values: torch.Tensor
"""
Shape: `(num_crops_total, num_channels, height, width)`
`num_crops_total` is the total number of crops
over each image over each prompt in the batch.
"""
num_crops: torch.Tensor
"""Shape: `(batch_size * num_images,)`"""
Gemma3ImageInputs = Gemma3ImagePixelInputs
@@ -40,6 +50,9 @@ Gemma3ImageInputs = Gemma3ImagePixelInputs
class Gemma3ProcessingInfo(BaseProcessingInfo):
def get_hf_processor(self, **kwargs: object):
return self.ctx.get_hf_processor(Gemma3Processor, **kwargs)
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}
@@ -48,22 +61,160 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
hf_config = self.ctx.get_hf_config()
return {"image": hf_config.mm_tokens_per_image}
return {"image": self.get_max_image_tokens()}
def _resolve_image_kwargs(
self,
processor: Gemma3Processor,
keys: set[str],
) -> dict[str, Any]:
image_processor = processor.image_processor
kwargs = processor._merge_kwargs(
Gemma3ProcessorKwargs,
tokenizer_init_kwargs=processor.tokenizer.init_kwargs,
)
images_kwargs = kwargs["images_kwargs"]
def _resolve_kw(key: str):
val = getattr(image_processor, key)
if val is None:
val = images_kwargs[key]
return val
return {k: _resolve_kw(k) for k in keys}
def get_num_crops(
self,
*,
image_width: int,
image_height: int,
processor: Optional[Gemma3Processor],
) -> int:
if processor is None:
processor = self.get_hf_processor()
images_kwargs = self._resolve_image_kwargs(
processor, {
"do_pan_and_scan", "pan_and_scan_min_crop_size",
"pan_and_scan_max_num_crops",
"pan_and_scan_min_ratio_to_activate"
})
do_pan_and_scan = images_kwargs["do_pan_and_scan"]
pan_and_scan_min_crop_size = images_kwargs[
"pan_and_scan_min_crop_size"]
pan_and_scan_max_num_crops = images_kwargs[
"pan_and_scan_max_num_crops"]
pan_and_scan_min_ratio_to_activate = images_kwargs[
"pan_and_scan_min_ratio_to_activate"]
if not do_pan_and_scan:
return 0
# Based on Gemma3ImageProcessor.pan_and_scan
if image_width >= image_height:
if image_width / image_height < pan_and_scan_min_ratio_to_activate:
return 0
num_crops_w = min(
int(math.floor(image_width / pan_and_scan_min_crop_size)),
int(math.floor(image_width / image_height + 0.5)),
)
num_crops_w = max(2, num_crops_w)
num_crops_w = min(pan_and_scan_max_num_crops, num_crops_w)
num_crops_h = 1
else:
if image_height / image_width < pan_and_scan_min_ratio_to_activate:
return 0
num_crops_h = min(
int(math.floor(image_height / pan_and_scan_min_crop_size)),
int(math.floor(image_height / image_width + 0.5)),
)
num_crops_h = max(2, num_crops_h)
num_crops_h = min(pan_and_scan_max_num_crops, num_crops_h)
num_crops_w = 1
crop_size_w = int(math.ceil(image_width / num_crops_w))
crop_size_h = int(math.ceil(image_height / num_crops_h))
if min(crop_size_w, crop_size_h) < pan_and_scan_min_crop_size:
return 0
return num_crops_w * num_crops_h
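As a concrete illustration of the crop arithmetic above, here is a standalone sketch of the landscape branch for a single wide image; the parameter values (minimum crop size 256, at most 4 crops, activation ratio 1.2) are assumed defaults used for illustration, not values read from the processor:

```python
import math

def pan_and_scan_num_crops_landscape(width: int, height: int,
                                     min_crop_size: int = 256,        # assumed default
                                     max_num_crops: int = 4,          # assumed default
                                     min_ratio_to_activate: float = 1.2) -> int:
    """Crop count for a landscape image (width >= height), mirroring the logic above."""
    if width / height < min_ratio_to_activate:
        return 0  # aspect ratio too close to square: no extra crops
    num_crops_w = min(math.floor(width / min_crop_size),
                      math.floor(width / height + 0.5))
    num_crops_w = min(max_num_crops, max(2, num_crops_w))
    # Each crop must still be at least min_crop_size on its shorter side.
    if min(math.ceil(width / num_crops_w), height) < min_crop_size:
        return 0
    return num_crops_w  # num_crops_h == 1 for landscape images

# A 1024x256 image has ratio 4.0, which activates pan-and-scan and yields 4 crops,
# so get_image_repl() emits 1 original + 4 crop image sequences for it.
print(pan_and_scan_num_crops_landscape(1024, 256))  # -> 4
```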
def get_image_repl(
self,
*,
image_width: int,
image_height: int,
processor: Optional[Gemma3Processor],
) -> str:
if processor is None:
processor = self.get_hf_processor()
image_token = processor.boi_token
num_crops = self.get_num_crops(
image_width=image_width,
image_height=image_height,
processor=processor,
)
if num_crops == 0:
image_text = image_token
else:
crops_image_tokens = " ".join(image_token
for _ in range(num_crops))
image_text = (
f"Here is the original image {image_token} and here are some "
f"crops to help you see better {crops_image_tokens}")
return image_text.replace(image_token, processor.full_image_sequence)
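To make the replacement string concrete, here is a small sketch of the same assembly with placeholder tokens; the real `boi_token` and `full_image_sequence` come from the HF processor, so the literals below are stand-ins only:

```python
# Stand-ins for processor.boi_token and processor.full_image_sequence.
BOI = "<start_of_image>"
FULL_IMAGE_SEQ = "<expanded image token sequence>"

def image_repl(num_crops: int) -> str:
    if num_crops == 0:
        text = BOI
    else:
        crops = " ".join(BOI for _ in range(num_crops))
        text = (f"Here is the original image {BOI} and here are some "
                f"crops to help you see better {crops}")
    # Every BOI marker is expanded to the full per-image token sequence.
    return text.replace(BOI, FULL_IMAGE_SEQ)

print(image_repl(0))  # one image sequence only
print(image_repl(4))  # original + 4 crops = 5 image sequences
```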
def get_num_image_tokens(
self,
*,
image_width: int,
image_height: int,
processor: Optional[ProcessorMixin],
processor: Optional[Gemma3Processor],
) -> int:
hf_config = self.ctx.get_hf_config()
return hf_config.mm_tokens_per_image
tokenizer = self.get_tokenizer()
image_repl = self.get_image_repl(
image_width=image_width,
image_height=image_height,
processor=processor,
)
image_repl_tokens = encode_tokens(
tokenizer,
image_repl,
add_special_tokens=False,
)
return len(image_repl_tokens)
def get_image_size_with_most_features(self) -> ImageSize:
# Result in the max possible feature size (h:w = 16:1)
return ImageSize(height=8000, width=50)
processor = self.get_hf_processor()
images_kwargs = self._resolve_image_kwargs(
processor, {"pan_and_scan_max_num_crops"})
max_num_crops = images_kwargs["pan_and_scan_max_num_crops"]
# Result in the max possible feature size (h:w = max_num_crops:1)
return ImageSize(height=50 * max_num_crops, width=50)
def get_max_image_tokens(self) -> int:
target_width, target_height = self.get_image_size_with_most_features()
return self.get_num_image_tokens(
image_width=target_width,
image_height=target_height,
processor=None,
)
class Gemma3DummyInputsBuilder(BaseDummyInputsBuilder[Gemma3ProcessingInfo]):
@@ -73,10 +224,11 @@ class Gemma3DummyInputsBuilder(BaseDummyInputsBuilder[Gemma3ProcessingInfo]):
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
tokenizer = self.info.get_tokenizer()
boi_token = tokenizer.boi_token
processor = self.info.get_hf_processor()
image_token = processor.boi_token
num_images = mm_counts.get("image", 0)
target_width, target_height = \
self.info.get_image_size_with_most_features()
@@ -86,8 +238,13 @@ class Gemma3DummyInputsBuilder(BaseDummyInputsBuilder[Gemma3ProcessingInfo]):
height=target_height,
num_images=num_images)
}
# NOTE: We need to separate the image tokens here because
# encode("\n\n\n\n") != encode("\n\n") * 2, which interferes
# with the detection of prompt updates when the image tokens are
# right next to each other
return ProcessorInputs(
prompt_text=" ".join([boi_token] * num_images),
prompt_text=" ".join([image_token] * num_images),
mm_data=mm_data,
)
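The tokenization caveat in the comment above can be checked directly against a given tokenizer; a quick sketch (the model name is illustrative):

```python
from transformers import AutoTokenizer

# If consecutive newlines do not tokenize as the concatenation of the
# individual pieces, adjacent image placeholders can confuse prompt-update
# detection, hence the " " separator between dummy image tokens above.
tok = AutoTokenizer.from_pretrained("google/gemma-3-4b-it")
print(tok.encode("\n\n\n\n", add_special_tokens=False))
print(tok.encode("\n\n", add_special_tokens=False) * 2)
```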
@@ -100,22 +257,49 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
) -> BatchFeature:
# TODO(woosuk): Support pan-and-scan.
img_kwargs = mm_kwargs.get("images_kwargs", {})
img_kwargs["do_pan_and_scan"] = False
mm_kwargs["images_kwargs"] = img_kwargs
return super()._call_hf_processor(
prompt=prompt,
mm_data=mm_data,
mm_kwargs=mm_kwargs,
processed_outputs = super()._call_hf_processor(
prompt,
mm_data,
mm_kwargs,
)
# HF processor pops the `num_crops` kwarg, which is needed by vLLM
if (images := mm_data.get("images")) is not None:
assert isinstance(images, list)
parsed_images = (self._get_data_parser().parse_mm_data({
"image":
images
}).get_items("image", ImageProcessorItems))
image_sizes = [
parsed_images.get_image_size(i)
for i in range(len(parsed_images))
]
hf_processor = self.info.get_hf_processor(**mm_kwargs)
num_crops = [
self.info.get_num_crops(image_width=size.width,
image_height=size.height,
processor=hf_processor)
for size in image_sizes
]
processed_outputs["num_crops"] = torch.tensor(num_crops)
return processed_outputs
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return dict(pixel_values=MultiModalFieldConfig.batched("image"))
num_crops = hf_inputs.get("num_crops", torch.empty(0))
return dict(
pixel_values=MultiModalFieldConfig.flat_from_sizes(
"image", num_crops + 1),
num_crops=MultiModalFieldConfig.batched("image"),
)
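The `num_crops + 1` sizes encode one original view plus its crops per image; a small sketch of how such per-image sizes partition the flat `pixel_values` batch (this shows only the slicing idea, not vLLM's `flat_from_sizes` implementation):

```python
import torch

# Per-image sizes: 1 original view + num_crops crops each.
num_crops = torch.tensor([0, 4])      # image 0: no crops, image 1: 4 crops
sizes = num_crops + 1                 # -> tensor([1, 5])
pixel_values = torch.randn(int(sizes.sum()), 3, 224, 224)  # flat crop batch (toy shape)

# Recover the per-image slices from the flat tensor.
ends = torch.cumsum(sizes, dim=0)
starts = torch.cat([torch.zeros(1, dtype=sizes.dtype), ends[:-1]])
per_image = [pixel_values[s:e] for s, e in zip(starts.tolist(), ends.tolist())]
print([t.shape[0] for t in per_image])  # -> [1, 5]
```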
def _get_prompt_updates(
self,
@@ -123,25 +307,23 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
hf_processor_mm_kwargs: Mapping[str, Any],
out_mm_kwargs: MultiModalKwargs,
) -> Sequence[PromptUpdate]:
tokenizer = self.info.get_tokenizer()
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
hf_config = self.info.get_hf_config()
boi_token = tokenizer.boi_token
image_token = tokenizer.image_token
mm_tokens_per_image = hf_config.mm_tokens_per_image
image_tokens_expanded = "".join([image_token] * mm_tokens_per_image)
image_token = hf_processor.boi_token
def get_replacement_gemma3(item_idx: int):
return PromptUpdateDetails(
full=hf_processor.full_image_sequence,
features=image_tokens_expanded,
images = mm_items.get_items("image", ImageProcessorItems)
image_size = images.get_image_size(item_idx)
return self.info.get_image_repl(
image_width=image_size.width,
image_height=image_size.height,
processor=hf_processor,
)
return [
PromptReplacement(
modality="image",
target=boi_token,
target=image_token,
replacement=get_replacement_gemma3,
)
]
@@ -254,19 +436,27 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal,
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[Gemma3ImageInputs]:
pixel_values = kwargs.pop("pixel_values", None)
num_crops = kwargs.pop("num_crops", None)
image_embeds = kwargs.pop("image_embeds", None)
assert image_embeds is None, "Gemma3 does not support image_embeds."
if pixel_values is None:
return None
if not isinstance(pixel_values, (torch.Tensor, list[torch.Tensor])):
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
if not isinstance(num_crops, (torch.Tensor, list)):
raise ValueError("Incorrect type of num_crops values. "
f"Got type: {type(num_crops)}")
pixel_values = flatten_bn(pixel_values, concat=True)
num_crops = flatten_bn(num_crops, concat=True)
return Gemma3ImagePixelInputs(
type="pixel_values",
data=self._validate_pixel_values(pixel_values),
pixel_values=self._validate_pixel_values(pixel_values),
num_crops=num_crops,
)
def _image_pixels_to_features(
@@ -283,7 +473,8 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal,
image_input: Gemma3ImageInputs,
) -> torch.Tensor:
assert self.vision_tower is not None
pixel_values = image_input["data"]
pixel_values = image_input["pixel_values"]
vision_outputs = self._image_pixels_to_features(
self.vision_tower,
pixel_values,

View File

@@ -226,7 +226,11 @@ class MultiModalPlugin(ABC):
if callable(max_mm_tokens):
mm_processor_kwargs = get_allowed_kwarg_only_overrides(
max_mm_tokens, overrides=model_config.mm_processor_kwargs)
max_mm_tokens,
overrides=model_config.mm_processor_kwargs,
requires_kw_only=False,
allow_var_kwargs=True,
)
max_mm_tokens = max_mm_tokens(InputContext(model_config),
**mm_processor_kwargs)
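The calls above lean on `get_allowed_kwarg_only_overrides` to keep only the overrides the target callable can accept; `requires_kw_only=False` also admits positional-or-keyword parameters, and `allow_var_kwargs=True` lets a `**kwargs` catch-all accept everything. A rough standalone sketch of that filtering idea (not vLLM's implementation):

```python
import inspect
from typing import Any, Callable, Mapping


def filter_overrides(fn: Callable, overrides: Mapping[str, Any],
                     requires_kw_only: bool = True,
                     allow_var_kwargs: bool = False) -> dict[str, Any]:
    """Keep only the overrides that `fn` can receive as keyword arguments."""
    params = inspect.signature(fn).parameters
    has_var_kwargs = any(p.kind is inspect.Parameter.VAR_KEYWORD
                         for p in params.values())
    allowed_kinds = {inspect.Parameter.KEYWORD_ONLY}
    if not requires_kw_only:
        allowed_kinds.add(inspect.Parameter.POSITIONAL_OR_KEYWORD)

    kept = {}
    for name, value in overrides.items():
        param = params.get(name)
        if ((param is not None and param.kind in allowed_kinds)
                or (allow_var_kwargs and has_var_kwargs)):
            kept[name] = value  # anything else would be dropped (and logged in vLLM)
    return kept


def dummy_processor(ctx, inputs, *, do_pan_and_scan: bool = False, **kwargs):
    ...


print(filter_overrides(dummy_processor,
                       {"do_pan_and_scan": True, "unused": 1},
                       requires_kw_only=False,
                       allow_var_kwargs=True))
# -> {'do_pan_and_scan': True, 'unused': 1}  (everything passes via **kwargs)
```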

View File

@@ -1488,11 +1488,11 @@ def get_allowed_kwarg_only_overrides(
if requires_kw_only:
logger.warning(
"The following intended overrides are not keyword-only args "
"and and will be dropped: %s", dropped_keys)
"and will be dropped: %s", dropped_keys)
else:
logger.warning(
"The following intended overrides are not keyword args "
"and and will be dropped: %s", dropped_keys)
"and will be dropped: %s", dropped_keys)
return filtered_overrides