diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md
index 98e7572981de..5db82c8e5567 100644
--- a/docs/source/models/supported_models.md
+++ b/docs/source/models/supported_models.md
@@ -763,7 +763,7 @@ See [this page](#generative-models) for more information on how to use generativ
* `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc.
* ✅︎
* ✅︎
- * ✅︎\*
+ * ⚠️
- * `GLM4VForCausalLM`^
* GLM-4V
* T + I
@@ -856,12 +856,12 @@ See [this page](#generative-models) for more information on how to use generativ
* ✅︎
* ✅︎
- * `PaliGemmaForConditionalGeneration`
- * PaliGemma ⚠️, PaliGemma 2 ⚠️
+ * PaliGemma, PaliGemma 2
* T + IE
* `google/paligemma-3b-pt-224`, `google/paligemma-3b-mix-224`, `google/paligemma2-3b-ft-docci-448`, etc.
*
* ✅︎
- * ✅︎
+ * ⚠️
- * `Phi3VForCausalLM`
* Phi-3-Vision, Phi-3.5-Vision
* T + IE+
@@ -926,34 +926,15 @@ See [this page](#generative-models) for more information on how to use generativ
E Pre-computed embeddings can be inputted for this modality.
+ Multiple items can be inputted per text prompt for this modality.
-:::{warning}
-vLLM does not currently support PrefixLM attention mask, so our PaliGemma implementation uses regular causal attention, which causes the model output to be unstable.
-
-We may deprecate this model series in a future release.
-:::
-
-:::{note}
-`h2oai/h2ovl-mississippi-2b` will be available in V1 once we support backends other than FlashAttention.
-:::
-
-:::{note}
-To use `TIGER-Lab/Mantis-8B-siglip-llama3`, you have to pass `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'` when running vLLM.
-:::
-
-:::{note}
-The official `openbmb/MiniCPM-V-2` doesn't work yet, so we need to use a fork (`HwwwH/MiniCPM-V-2`) for now.
-For more details, please see:
-:::
-
-:::{note}
-To use Qwen2.5-VL series models, you have to install Hugging Face Transformers library from source via `pip install git+https://github.com/huggingface/transformers`.
-:::
-
-:::{note}
+:::{important}
To use Gemma3 series models, you have to install Hugging Face Transformers library from source via
`pip install git+https://github.com/huggingface/transformers`.
-The earliest commit that supports this is [`50d3530aa04e7a7d003e6b255a98f79fd0447357`](https://github.com/huggingface/transformers/commit/50d3530aa04e7a7d003e6b255a98f79fd0447357).
+Pan-and-scan image pre-processing is currently supported on V0 (but not V1).
+You can enable it by passing `--mm-processor-kwargs '{"do_pan_and_scan": True}'`.
+:::
+
+:::{warning}
Both V0 and V1 support `Gemma3ForConditionalGeneration` for text-only inputs.
However, there are differences in how they handle text + image inputs:
@@ -969,9 +950,23 @@ V1 currently uses a simplified attention pattern:
- Will be updated in the future to support the correct behavior
This limitation exists because the model's mixed attention pattern (bidirectional for images, causal otherwise) is not yet supported by vLLM's attention backends.
+:::
-Additionally, vLLM's current Gemma 3 implementation does not support the pan-and-scan image pre-processing algorithm, which helps handle images with skewed aspect ratios by intelligently cropping them into multiple views.
-Without this feature, model performance may degrade when processing images that deviate significantly from square dimensions.
+:::{note}
+`h2oai/h2ovl-mississippi-2b` will be available in V1 once we support backends other than FlashAttention.
+:::
+
+:::{note}
+To use `TIGER-Lab/Mantis-8B-siglip-llama3`, you have to pass `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'` when running vLLM.
+:::
+
+:::{note}
+The official `openbmb/MiniCPM-V-2` doesn't work yet, so we need to use a fork (`HwwwH/MiniCPM-V-2`) for now.
+For more details, please see:
+:::
+
+:::{warning}
+Our PaliGemma implementations have the same problem as Gemma 3 (see above) for both V0 and V1.
:::
### Pooling Models
diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py
index 39acab4765a3..432cda5e2439 100644
--- a/examples/offline_inference/vision_language.py
+++ b/examples/offline_inference/vision_language.py
@@ -123,10 +123,14 @@ def run_gemma3(questions: list[str], modality: str):
assert modality == "image"
model_name = "google/gemma-3-4b-it"
- llm = LLM(model=model_name,
- max_model_len=2048,
- max_num_seqs=2,
- disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
+ llm = LLM(
+ model=model_name,
+ max_model_len=2048,
+ max_num_seqs=2,
+ # Default is False; setting it to True is not supported in V1 yet
+ mm_processor_kwargs={"do_pan_and_scan": True},
+ disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
+ )
prompts = [("user\n"
f"{question}\n"
diff --git a/examples/offline_inference/vision_language_multi_image.py b/examples/offline_inference/vision_language_multi_image.py
index 4963e6a8c4e7..b47004aa9615 100644
--- a/examples/offline_inference/vision_language_multi_image.py
+++ b/examples/offline_inference/vision_language_multi_image.py
@@ -83,10 +83,14 @@ def load_deepseek_vl2(question: str, image_urls: list[str]):
def load_gemma3(question, image_urls: list[str]) -> ModelRequestData:
model_name = "google/gemma-3-4b-it"
- llm = LLM(model=model_name,
- max_model_len=8192,
- max_num_seqs=2,
- limit_mm_per_prompt={"image": len(image_urls)})
+ llm = LLM(
+ model=model_name,
+ max_model_len=8192,
+ max_num_seqs=2,
+ # Default is False; setting it to True is not supported in V1 yet
+ mm_processor_kwargs={"do_pan_and_scan": True},
+ limit_mm_per_prompt={"image": len(image_urls)},
+ )
placeholders = [{"type": "image", "image": url} for url in image_urls]
messages = [{
diff --git a/tests/models/decoder_only/vision_language/test_models.py b/tests/models/decoder_only/vision_language/test_models.py
index 2540933bbc23..880d1bd1dc4e 100644
--- a/tests/models/decoder_only/vision_language/test_models.py
+++ b/tests/models/decoder_only/vision_language/test_models.py
@@ -9,7 +9,7 @@ from pathlib import PosixPath
import pytest
from packaging.version import Version
-from transformers import AutoModelForVision2Seq
+from transformers import AutoModelForPreTraining, AutoModelForVision2Seq
from transformers import __version__ as TRANSFORMERS_VERSION
from vllm.platforms import current_platform
@@ -234,6 +234,23 @@ VLM_TEST_SETTINGS = {
num_logprobs=10,
image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)],
),
+ "gemma3": VLMTestInfo(
+ models=["google/gemma-3-4b-it"],
+ test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
+ prompt_formatter=lambda img_prompt: f"user\n{img_prompt}\nmodel\n", # noqa: E501
+ single_image_prompts=IMAGE_ASSETS.prompts({
+ "stop_sign": "What's the content in the center of the image?", # noqa: E501
+ "cherry_blossom": "What is the season?", # noqa: E501
+ }),
+ multi_image_prompt="Describe the two images in detail.", # noqa: E501
+ max_model_len=4096,
+ max_num_seqs=2,
+ # TODO: Use AutoModelForVision2Seq once transformers supports this
+ auto_cls=AutoModelForPreTraining,
+ dtype="bfloat16",
+ vllm_runner_kwargs={"mm_processor_kwargs": {"do_pan_and_scan": True}},
+ patch_hf_runner=model_utils.gemma3_patch_hf_runner,
+ ),
"glm4v": VLMTestInfo(
models=["THUDM/glm-4v-9b"],
test_type=VLMTestType.IMAGE,
diff --git a/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py b/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py
index 66410f66ca0d..5e1fcfd8f082 100644
--- a/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py
+++ b/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py
@@ -304,6 +304,18 @@ def deepseekvl2_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
return hf_model
+def gemma3_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
+ """Patches and returns an instance of the HfRunner to use for Gemma 3."""
+ hf_processor = hf_model.processor
+
+ def processor(*args, **kwargs):
+ return hf_processor(*args, do_pan_and_scan=True, **kwargs)
+
+ hf_model.processor = processor
+
+ return hf_model
+
+
def glm_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
"""Patches and returns an instance of the HfRunner to use for GLM4."""
hf_processor = hf_model.processor
diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py
index a0bd8f278fd0..b6ceb5fb82d7 100644
--- a/vllm/inputs/registry.py
+++ b/vllm/inputs/registry.py
@@ -348,7 +348,11 @@ class InputRegistry:
dummy_factory = self._get_dummy_data_factory(model_cls)
mm_counts = mm_registry.get_mm_limits_per_prompt(model_config)
mm_processor_kwargs = get_allowed_kwarg_only_overrides(
- dummy_factory, overrides=model_config.mm_processor_kwargs)
+ dummy_factory,
+ overrides=model_config.mm_processor_kwargs,
+ requires_kw_only=False,
+ allow_var_kwargs=True,
+ )
dummy_data = dummy_factory(InputContext(model_config), seq_len,
_MultiModalCounts(mm_counts),
@@ -381,6 +385,7 @@ class InputRegistry:
self,
ctx: InputContext,
inputs: ProcessorInputs,
+ **kwargs: object,
) -> ProcessorInputs:
"""The default input processor is a no-op."""
return inputs
@@ -447,6 +452,8 @@ class InputRegistry:
model_config.mm_processor_kwargs,
inputs.get("mm_processor_kwargs", {}), # type: ignore
processor,
+ requires_kw_only=False,
+ allow_var_kwargs=True,
)
processed_inputs = processor(
diff --git a/vllm/model_executor/models/gemma3_mm.py b/vllm/model_executor/models/gemma3_mm.py
index 121aee51786b..ac80059cbe6d 100644
--- a/vllm/model_executor/models/gemma3_mm.py
+++ b/vllm/model_executor/models/gemma3_mm.py
@@ -1,10 +1,12 @@
# SPDX-License-Identifier: Apache-2.0
+import math
from typing import (Any, Iterable, Literal, Mapping, Optional, Sequence, Set,
Tuple, TypedDict, Union)
import torch
from torch import nn
-from transformers import BatchFeature, Gemma3Config, ProcessorMixin
+from transformers import BatchFeature, Gemma3Config, Gemma3Processor
+from transformers.models.gemma3.processing_gemma3 import Gemma3ProcessorKwargs
from vllm.config import VllmConfig
from vllm.logger import init_logger
@@ -14,10 +16,11 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
NestedTensors)
-from vllm.multimodal.parse import ImageSize, MultiModalDataItems
+from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
+ MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
- PromptUpdate, PromptUpdateDetails)
+ PromptUpdate, encode_tokens)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
@@ -31,8 +34,15 @@ logger = init_logger(__name__)
class Gemma3ImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
- data: torch.Tensor
- """Shape: `(batch_size * num_images, num_channels, height, width)`"""
+ pixel_values: torch.Tensor
+ """
+ Shape: `(num_crops_total, num_channels, height, width)`
+
+ `num_crops_total` is the total number of crops
+ over each image over each prompt in the batch.
+ """
+ num_crops: torch.Tensor
+ """Shape: `(batch_size * num_images,)`"""
Gemma3ImageInputs = Gemma3ImagePixelInputs
@@ -40,6 +50,9 @@ Gemma3ImageInputs = Gemma3ImagePixelInputs
class Gemma3ProcessingInfo(BaseProcessingInfo):
+ def get_hf_processor(self, **kwargs: object):
+ return self.ctx.get_hf_processor(Gemma3Processor, **kwargs)
+
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}
@@ -48,22 +61,160 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
- hf_config = self.ctx.get_hf_config()
- return {"image": hf_config.mm_tokens_per_image}
+ return {"image": self.get_max_image_tokens()}
+
+ def _resolve_image_kwargs(
+ self,
+ processor: Gemma3Processor,
+ keys: set[str],
+ ) -> dict[str, Any]:
+ image_processor = processor.image_processor
+ kwargs = processor._merge_kwargs(
+ Gemma3ProcessorKwargs,
+ tokenizer_init_kwargs=processor.tokenizer.init_kwargs,
+ )
+
+ images_kwargs = kwargs["images_kwargs"]
+
+ def _resolve_kw(key: str):
+ val = getattr(image_processor, key)
+ if val is None:
+ val = images_kwargs[key]
+
+ return val
+
+ return {k: _resolve_kw(k) for k in keys}
+
+ def get_num_crops(
+ self,
+ *,
+ image_width: int,
+ image_height: int,
+ processor: Optional[Gemma3Processor],
+ ) -> int:
+ if processor is None:
+ processor = self.get_hf_processor()
+
+ images_kwargs = self._resolve_image_kwargs(
+ processor, {
+ "do_pan_and_scan", "pan_and_scan_min_crop_size",
+ "pan_and_scan_max_num_crops",
+ "pan_and_scan_min_ratio_to_activate"
+ })
+
+ do_pan_and_scan = images_kwargs["do_pan_and_scan"]
+ pan_and_scan_min_crop_size = images_kwargs[
+ "pan_and_scan_min_crop_size"]
+ pan_and_scan_max_num_crops = images_kwargs[
+ "pan_and_scan_max_num_crops"]
+ pan_and_scan_min_ratio_to_activate = images_kwargs[
+ "pan_and_scan_min_ratio_to_activate"]
+
+ if not do_pan_and_scan:
+ return 0
+
+ # Based on Gemma3ImageProcessor.pan_and_scan
+ if image_width >= image_height:
+ if image_width / image_height < pan_and_scan_min_ratio_to_activate:
+ return 0
+
+ num_crops_w = min(
+ int(math.floor(image_width / pan_and_scan_min_crop_size)),
+ int(math.floor(image_width / image_height + 0.5)),
+ )
+
+ num_crops_w = max(2, num_crops_w)
+ num_crops_w = min(pan_and_scan_max_num_crops, num_crops_w)
+ num_crops_h = 1
+ else:
+ if image_height / image_width < pan_and_scan_min_ratio_to_activate:
+ return 0
+
+ num_crops_h = min(
+ int(math.floor(image_height / pan_and_scan_min_crop_size)),
+ int(math.floor(image_height / image_width + 0.5)),
+ )
+
+ num_crops_h = max(2, num_crops_h)
+ num_crops_h = min(pan_and_scan_max_num_crops, num_crops_h)
+ num_crops_w = 1
+
+ crop_size_w = int(math.ceil(image_width / num_crops_w))
+ crop_size_h = int(math.ceil(image_height / num_crops_h))
+
+ if min(crop_size_w, crop_size_h) < pan_and_scan_min_crop_size:
+ return 0
+
+ return num_crops_w * num_crops_h
+
+ def get_image_repl(
+ self,
+ *,
+ image_width: int,
+ image_height: int,
+ processor: Optional[Gemma3Processor],
+ ) -> str:
+ if processor is None:
+ processor = self.get_hf_processor()
+
+ image_token = processor.boi_token
+
+ num_crops = self.get_num_crops(
+ image_width=image_width,
+ image_height=image_height,
+ processor=processor,
+ )
+
+ if num_crops == 0:
+ image_text = image_token
+ else:
+ crops_image_tokens = " ".join(image_token
+ for _ in range(num_crops))
+ image_text = (
+ f"Here is the original image {image_token} and here are some "
+ f"crops to help you see better {crops_image_tokens}")
+
+ return image_text.replace(image_token, processor.full_image_sequence)
def get_num_image_tokens(
self,
*,
image_width: int,
image_height: int,
- processor: Optional[ProcessorMixin],
+ processor: Optional[Gemma3Processor],
) -> int:
- hf_config = self.ctx.get_hf_config()
- return hf_config.mm_tokens_per_image
+ tokenizer = self.get_tokenizer()
+ image_repl = self.get_image_repl(
+ image_width=image_width,
+ image_height=image_height,
+ processor=processor,
+ )
+
+ image_repl_tokens = encode_tokens(
+ tokenizer,
+ image_repl,
+ add_special_tokens=False,
+ )
+ return len(image_repl_tokens)
def get_image_size_with_most_features(self) -> ImageSize:
- # Result in the max possible feature size (h:w = 16:1)
- return ImageSize(height=8000, width=50)
+ processor = self.get_hf_processor()
+
+ images_kwargs = self._resolve_image_kwargs(
+ processor, {"pan_and_scan_max_num_crops"})
+ max_num_crops = images_kwargs["pan_and_scan_max_num_crops"]
+
+ # Result in the max possible feature size (h:w = max_num_crops:1)
+ return ImageSize(height=50 * max_num_crops, width=50)
+
+ def get_max_image_tokens(self) -> int:
+ target_width, target_height = self.get_image_size_with_most_features()
+
+ return self.get_num_image_tokens(
+ image_width=target_width,
+ image_height=target_height,
+ processor=None,
+ )
class Gemma3DummyInputsBuilder(BaseDummyInputsBuilder[Gemma3ProcessingInfo]):
@@ -73,10 +224,11 @@ class Gemma3DummyInputsBuilder(BaseDummyInputsBuilder[Gemma3ProcessingInfo]):
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
- tokenizer = self.info.get_tokenizer()
- boi_token = tokenizer.boi_token
+ processor = self.info.get_hf_processor()
+ image_token = processor.boi_token
num_images = mm_counts.get("image", 0)
+
target_width, target_height = \
self.info.get_image_size_with_most_features()
@@ -86,8 +238,13 @@ class Gemma3DummyInputsBuilder(BaseDummyInputsBuilder[Gemma3ProcessingInfo]):
height=target_height,
num_images=num_images)
}
+
+ # NOTE: We need to separate the image tokens here because
+ # encode("\n\n\n\n") != encode("\n\n") * 2, which interferes
+ # with the detection of prompt updates when the image tokens are
+ # right next to each other
return ProcessorInputs(
- prompt_text=" ".join([boi_token] * num_images),
+ prompt_text=" ".join([image_token] * num_images),
mm_data=mm_data,
)
@@ -100,22 +257,49 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
) -> BatchFeature:
- # TODO(woosuk): Support pan-and-scan.
- img_kwargs = mm_kwargs.get("images_kwargs", {})
- img_kwargs["do_pan_and_scan"] = False
- mm_kwargs["images_kwargs"] = img_kwargs
- return super()._call_hf_processor(
- prompt=prompt,
- mm_data=mm_data,
- mm_kwargs=mm_kwargs,
+ processed_outputs = super()._call_hf_processor(
+ prompt,
+ mm_data,
+ mm_kwargs,
)
+ # HF processor pops the `num_crops` kwarg, which is needed by vLLM
+ if (images := mm_data.get("images")) is not None:
+ assert isinstance(images, list)
+
+ parsed_images = (self._get_data_parser().parse_mm_data({
+ "image":
+ images
+ }).get_items("image", ImageProcessorItems))
+ image_sizes = [
+ parsed_images.get_image_size(i)
+ for i in range(len(parsed_images))
+ ]
+ hf_processor = self.info.get_hf_processor(**mm_kwargs)
+
+ num_crops = [
+ self.info.get_num_crops(image_width=size.width,
+ image_height=size.height,
+ processor=hf_processor)
+ for size in image_sizes
+ ]
+
+ processed_outputs["num_crops"] = torch.tensor(num_crops)
+
+ return processed_outputs
+
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
- return dict(pixel_values=MultiModalFieldConfig.batched("image"))
+ num_crops = hf_inputs.get("num_crops", torch.empty(0))
+
+ return dict(
+ pixel_values=MultiModalFieldConfig.flat_from_sizes(
+ "image", num_crops + 1),
+ num_crops=MultiModalFieldConfig.batched("image"),
+ )
def _get_prompt_updates(
self,
@@ -123,25 +307,23 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
hf_processor_mm_kwargs: Mapping[str, Any],
out_mm_kwargs: MultiModalKwargs,
) -> Sequence[PromptUpdate]:
- tokenizer = self.info.get_tokenizer()
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
- hf_config = self.info.get_hf_config()
-
- boi_token = tokenizer.boi_token
- image_token = tokenizer.image_token
- mm_tokens_per_image = hf_config.mm_tokens_per_image
- image_tokens_expanded = "".join([image_token] * mm_tokens_per_image)
+ image_token = hf_processor.boi_token
def get_replacement_gemma3(item_idx: int):
- return PromptUpdateDetails(
- full=hf_processor.full_image_sequence,
- features=image_tokens_expanded,
+ images = mm_items.get_items("image", ImageProcessorItems)
+
+ image_size = images.get_image_size(item_idx)
+ return self.info.get_image_repl(
+ image_width=image_size.width,
+ image_height=image_size.height,
+ processor=hf_processor,
)
return [
PromptReplacement(
modality="image",
- target=boi_token,
+ target=image_token,
replacement=get_replacement_gemma3,
)
]
@@ -254,19 +436,27 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal,
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[Gemma3ImageInputs]:
pixel_values = kwargs.pop("pixel_values", None)
+ num_crops = kwargs.pop("num_crops", None)
image_embeds = kwargs.pop("image_embeds", None)
assert image_embeds is None, "Gemma3 does not support image_embeds."
if pixel_values is None:
return None
- if not isinstance(pixel_values, (torch.Tensor, list[torch.Tensor])):
+ if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
+ if not isinstance(num_crops, (torch.Tensor, list)):
+ raise ValueError("Incorrect type of num_crops values. "
+ f"Got type: {type(num_crops)}")
+
pixel_values = flatten_bn(pixel_values, concat=True)
+ num_crops = flatten_bn(num_crops, concat=True)
+
return Gemma3ImagePixelInputs(
type="pixel_values",
- data=self._validate_pixel_values(pixel_values),
+ pixel_values=self._validate_pixel_values(pixel_values),
+ num_crops=num_crops,
)
def _image_pixels_to_features(
@@ -283,7 +473,8 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal,
image_input: Gemma3ImageInputs,
) -> torch.Tensor:
assert self.vision_tower is not None
- pixel_values = image_input["data"]
+
+ pixel_values = image_input["pixel_values"]
vision_outputs = self._image_pixels_to_features(
self.vision_tower,
pixel_values,
diff --git a/vllm/multimodal/base.py b/vllm/multimodal/base.py
index e0b160a65047..5159b0bca8c1 100644
--- a/vllm/multimodal/base.py
+++ b/vllm/multimodal/base.py
@@ -226,7 +226,11 @@ class MultiModalPlugin(ABC):
if callable(max_mm_tokens):
mm_processor_kwargs = get_allowed_kwarg_only_overrides(
- max_mm_tokens, overrides=model_config.mm_processor_kwargs)
+ max_mm_tokens,
+ overrides=model_config.mm_processor_kwargs,
+ requires_kw_only=False,
+ allow_var_kwargs=True,
+ )
max_mm_tokens = max_mm_tokens(InputContext(model_config),
**mm_processor_kwargs)
diff --git a/vllm/utils.py b/vllm/utils.py
index 9cad2b8854a2..a8eba27dbcdb 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -1488,11 +1488,11 @@ def get_allowed_kwarg_only_overrides(
if requires_kw_only:
logger.warning(
"The following intended overrides are not keyword-only args "
- "and and will be dropped: %s", dropped_keys)
+ "and will be dropped: %s", dropped_keys)
else:
logger.warning(
"The following intended overrides are not keyword args "
- "and and will be dropped: %s", dropped_keys)
+ "and will be dropped: %s", dropped_keys)
return filtered_overrides