diff --git a/docs/models/hardware_supported_models/tpu.md b/docs/models/hardware_supported_models/tpu.md
index 7b0a5ba6e72d..8d3e28c259ec 100644
--- a/docs/models/hardware_supported_models/tpu.md
+++ b/docs/models/hardware_supported_models/tpu.md
@@ -16,8 +16,8 @@
| meta-llama/Llama-4-* | Llama4ForConditionalGeneration | ❌ |
| microsoft/Phi-3-mini-128k-instruct | Phi3ForCausalLM | 🟨 |
| microsoft/phi-4 | Phi3ForCausalLM | ❌ |
-| google/gemma-3-27b-it | Gemma3ForConditionalGeneration | 🟨 |
-| google/gemma-3-4b-it | Gemma3ForConditionalGeneration | ❌ |
+| google/gemma-3-27b-it | TransformersForMultimodalLM | 🟨 |
+| google/gemma-3-4b-it | TransformersForMultimodalLM | ❌ |
| deepseek-ai/DeepSeek-R1 | DeepseekV3ForCausalLM | ❌ |
| deepseek-ai/DeepSeek-V3 | DeepseekV3ForCausalLM | ❌ |
| RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8 | LlamaForCausalLM | ✅ |
diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md
index f98b06188ede..3bba63dda03d 100644
--- a/docs/models/supported_models.md
+++ b/docs/models/supported_models.md
@@ -650,7 +650,6 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
| `DeepseekVLV2ForCausalLM`^ | DeepSeek-VL2 | T + I+ | `deepseek-ai/deepseek-vl2-tiny`, `deepseek-ai/deepseek-vl2-small`, `deepseek-ai/deepseek-vl2`, etc. | | ✅︎ |
| `Ernie4_5_VLMoeForConditionalGeneration` | Ernie4.5-VL | T + I+/ V+ | `baidu/ERNIE-4.5-VL-28B-A3B-PT`, `baidu/ERNIE-4.5-VL-424B-A47B-PT` | | ✅︎ |
| `FuyuForCausalLM` | Fuyu | T + I | `adept/fuyu-8b`, etc. | | ✅︎ |
-| `Gemma3ForConditionalGeneration` | Gemma 3 | T + I+ | `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc. | ✅︎ | ✅︎ |
| `Gemma3nForConditionalGeneration` | Gemma 3n | T + I + A | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | | |
| `GLM4VForCausalLM`^ | GLM-4V | T + I | `zai-org/glm-4v-9b`, `zai-org/cogagent-9b-20241220`, etc. | ✅︎ | ✅︎ |
| `Glm4vForConditionalGeneration` | GLM-4.1V-Thinking | T + IE+ + VE+ | `zai-org/GLM-4.1V-9B-Thinking`, etc. | ✅︎ | ✅︎ |
@@ -679,7 +678,6 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
| `NVLM_D_Model` | NVLM-D 1.0 | T + I+ | `nvidia/NVLM-D-72B`, etc. | | ✅︎ |
| `Ovis` | Ovis2, Ovis1.6 | T + I+ | `AIDC-AI/Ovis2-1B`, `AIDC-AI/Ovis1.6-Llama3.2-3B`, etc. | | ✅︎ |
| `Ovis2_5` | Ovis2.5 | T + I+ + V | `AIDC-AI/Ovis2.5-9B`, etc. | | |
-| `PaliGemmaForConditionalGeneration` | PaliGemma, PaliGemma 2 | T + IE | `google/paligemma-3b-pt-224`, `google/paligemma-3b-mix-224`, `google/paligemma2-3b-ft-docci-448`, etc. | | ✅︎ |
| `Phi3VForCausalLM` | Phi-3-Vision, Phi-3.5-Vision | T + IE+ | `microsoft/Phi-3-vision-128k-instruct`, `microsoft/Phi-3.5-vision-instruct`, etc. | | ✅︎ |
| `Phi4MMForCausalLM` | Phi-4-multimodal | T + I+ / T + A+ / I+ + A+ | `microsoft/Phi-4-multimodal-instruct`, etc. | ✅︎ | ✅︎ |
| `Phi4MultimodalForCausalLM` | Phi-4-multimodal (HF Transformers) | T + I+ / T + A+ / I+ + A+ | `microsoft/Phi-4-multimodal-instruct` (with revision `refs/pr/70`), etc. | ✅︎ | ✅︎ |
@@ -704,6 +702,8 @@ Some models are supported only via the [Transformers backend](#transformers). Th
| Architecture | Models | Inputs | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) |
|--------------|--------|--------|-------------------|-----------------------------|-----------------------------------------|
| `Emu3ForConditionalGeneration` | Emu3 | T + I | `BAAI/Emu3-Chat-hf` | ✅︎ | ✅︎ |
+| `Gemma3ForConditionalGeneration` | Gemma 3 | T + I+ | `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc. | ✅︎ | ✅︎ |
+| `PaliGemmaForConditionalGeneration` | PaliGemma, PaliGemma 2 | T + IE | `google/paligemma-3b-pt-224`, `google/paligemma-3b-mix-224`, `google/paligemma2-3b-ft-docci-448`, etc. | ✅︎ | ✅︎ |
^ You need to set the architecture name via `--hf-overrides` to match the one in vLLM.
• For example, to use DeepSeek-VL2 series models:
@@ -712,21 +712,7 @@ Some models are supported only via the [Transformers backend](#transformers). Th
+ Multiple items can be inputted per text prompt for this modality.
!!! warning
- Both V0 and V1 support `Gemma3ForConditionalGeneration` for text-only inputs.
- However, there are differences in how they handle text + image inputs:
-
- V0 correctly implements the model's attention pattern:
- - Uses bidirectional attention between the image tokens corresponding to the same image
- - Uses causal attention for other tokens
- - Implemented via (naive) PyTorch SDPA with masking tensors
- - Note: May use significant memory for long prompts with image
-
- V1 currently uses a simplified attention pattern:
- - Uses causal attention for all tokens, including image tokens
- - Generates reasonable outputs but does not match the original model's attention for text + image inputs, especially when `{"do_pan_and_scan": true}`
- - Will be updated in the future to support the correct behavior
-
- This limitation exists because the model's mixed attention pattern (bidirectional for images, causal otherwise) is not yet supported by vLLM's attention backends.
+ For `Gemma3ForConditionalGeneration`, `{"do_pan_and_scan": true}` is not supported in Transformers backend yet.
!!! note
`Gemma3nForConditionalGeneration` is only supported on V1 due to shared KV caching and it depends on `timm>=1.0.17` to make use of its
@@ -778,9 +764,6 @@ Some models are supported only via the [Transformers backend](#transformers). Th
The official `openbmb/MiniCPM-V-2` doesn't work yet, so we need to use a fork (`HwwwH/MiniCPM-V-2`) for now.
For more details, please see:
-!!! warning
- Our PaliGemma implementations have the same problem as Gemma 3 (see above) for both V0 and V1.
-
!!! note
For Qwen2.5-Omni and Qwen3-Omni, reading audio from video pre-processing (`--mm-processor-kwargs '{"use_audio_in_video": true}'`) is currently work in progress and not yet supported.
diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py
index 1f09dabaf74c..a92304837be1 100644
--- a/examples/offline_inference/vision_language.py
+++ b/examples/offline_inference/vision_language.py
@@ -248,7 +248,8 @@ def run_gemma3(questions: list[str], modality: str) -> ModelRequestData:
model=model_name,
max_model_len=2048,
max_num_seqs=2,
- mm_processor_kwargs={"do_pan_and_scan": True},
+ # TODO: Support this in transformers backend
+ # mm_processor_kwargs={"do_pan_and_scan": True},
limit_mm_per_prompt={modality: 1},
)
diff --git a/tests/models/language/generation/test_gemma.py b/tests/models/language/generation/test_gemma.py
index 246b893be315..5108da68cb0b 100644
--- a/tests/models/language/generation/test_gemma.py
+++ b/tests/models/language/generation/test_gemma.py
@@ -3,7 +3,7 @@
import numpy as np
import pytest
-MODELS = ["google/gemma-2b", "google/gemma-2-2b", "google/gemma-3-4b-it"]
+MODELS = ["google/gemma-2b", "google/gemma-2-2b"]
@pytest.mark.parametrize("model", MODELS)
@@ -14,14 +14,8 @@ def test_dummy_loader(vllm_runner, monkeypatch, model: str) -> None:
model,
load_format="dummy",
) as llm:
- if model == "google/gemma-3-4b-it":
- normalizers = llm.llm.collective_rpc(
- lambda self: self.model_runner.model.language_model.model.normalizer.cpu().item() # noqa: E501
- )
- config = llm.llm.llm_engine.model_config.hf_config.text_config
- else:
- normalizers = llm.llm.collective_rpc(
- lambda self: self.model_runner.model.model.normalizer.cpu().item()
- )
- config = llm.llm.llm_engine.model_config.hf_config
+ normalizers = llm.apply_model(
+ lambda model: model.model.normalizer.cpu().item()
+ )
+ config = llm.llm.llm_engine.model_config.hf_config
assert np.allclose(normalizers, config.hidden_size**0.5, rtol=2e-3)
diff --git a/tests/models/multimodal/generation/test_common.py b/tests/models/multimodal/generation/test_common.py
index 7baca7800bb9..d94a3d5cf3c4 100644
--- a/tests/models/multimodal/generation/test_common.py
+++ b/tests/models/multimodal/generation/test_common.py
@@ -113,25 +113,6 @@ VLM_TEST_SETTINGS = {
dtype="bfloat16" if current_platform.is_cpu() else "auto",
marks=[pytest.mark.core_model, pytest.mark.cpu_model],
),
- "paligemma": VLMTestInfo(
- models=["google/paligemma-3b-mix-224"],
- test_type=VLMTestType.IMAGE,
- prompt_formatter=identity,
- img_idx_to_prompt=lambda idx: "",
- # Paligemma uses its own sample prompts because the default one fails
- single_image_prompts=IMAGE_ASSETS.prompts(
- {
- "stop_sign": "caption es",
- "cherry_blossom": "What is in the picture?",
- }
- ),
- auto_cls=AutoModelForImageTextToText,
- vllm_output_post_proc=model_utils.paligemma_vllm_to_hf_output,
- dtype="bfloat16",
- marks=[
- pytest.mark.skip(reason="vLLM does not support PrefixLM attention mask")
- ],
- ),
"qwen2_5_vl": VLMTestInfo(
models=["Qwen/Qwen2.5-VL-3B-Instruct"],
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE, VLMTestType.VIDEO),
@@ -196,14 +177,24 @@ VLM_TEST_SETTINGS = {
# Gemma3 has bidirectional mask on images
"gemma3-transformers": VLMTestInfo(
models=["google/gemma-3-4b-it"],
- test_type=VLMTestType.IMAGE,
- prompt_formatter=lambda vid_prompt: f"<'user\n{vid_prompt}\nmodel\n", # noqa: E501
- max_model_len=4096,
+ test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
+ prompt_formatter=lambda img_prompt: f"user\n{img_prompt}\nmodel\n", # noqa: E501
+ single_image_prompts=IMAGE_ASSETS.prompts(
+ {
+ "stop_sign": "What's the content in the center of the image?", # noqa: E501
+ "cherry_blossom": "What is the season?",
+ }
+ ),
+ multi_image_prompt="Describe the two images in detail.", # noqa: E501
+ max_model_len=8192,
auto_cls=AutoModelForImageTextToText,
+ # TODO: Support `do_pan_and_scan` in transformers backend
+ # patch_hf_runner=model_utils.gemma3_patch_hf_runner,
vllm_output_post_proc=model_utils.gemma3_vllm_to_hf_output,
image_size_factors=[(0.25, 0.5, 1.0)],
vllm_runner_kwargs={
"model_impl": "transformers",
+ # "mm_processor_kwargs": {"do_pan_and_scan": True},
},
marks=[pytest.mark.core_model],
),
@@ -222,6 +213,27 @@ VLM_TEST_SETTINGS = {
},
marks=[pytest.mark.core_model],
),
+ # PaliGemma has PrefixLM attention
+ "paligemma-transformers": VLMTestInfo(
+ models=["google/paligemma-3b-mix-224"],
+ test_type=VLMTestType.IMAGE,
+ prompt_formatter=identity,
+ img_idx_to_prompt=lambda idx: "",
+ # PaliGemma uses its own sample prompts because the default one fails
+ single_image_prompts=IMAGE_ASSETS.prompts(
+ {
+ "stop_sign": "caption es",
+ "cherry_blossom": "What is in the picture?",
+ }
+ ),
+ auto_cls=AutoModelForImageTextToText,
+ vllm_output_post_proc=model_utils.paligemma_vllm_to_hf_output,
+ image_size_factors=[(0.25, 0.5, 1.0)],
+ vllm_runner_kwargs={
+ "model_impl": "transformers",
+ },
+ marks=[pytest.mark.core_model],
+ ),
# Pixel values from processor are not 4D or 5D arrays
"qwen2_5_vl-transformers": VLMTestInfo(
models=["Qwen/Qwen2.5-VL-3B-Instruct"],
@@ -348,24 +360,6 @@ VLM_TEST_SETTINGS = {
image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)],
marks=[large_gpu_mark(min_gb=32)],
),
- "gemma3": VLMTestInfo(
- models=["google/gemma-3-4b-it"],
- test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
- prompt_formatter=lambda img_prompt: f"user\n{img_prompt}\nmodel\n", # noqa: E501
- single_image_prompts=IMAGE_ASSETS.prompts(
- {
- "stop_sign": "What's the content in the center of the image?", # noqa: E501
- "cherry_blossom": "What is the season?",
- }
- ),
- multi_image_prompt="Describe the two images in detail.", # noqa: E501
- max_model_len=4096,
- max_num_seqs=2,
- auto_cls=AutoModelForImageTextToText,
- vllm_runner_kwargs={"mm_processor_kwargs": {"do_pan_and_scan": True}},
- patch_hf_runner=model_utils.gemma3_patch_hf_runner,
- num_logprobs=10,
- ),
"glm4v": VLMTestInfo(
models=["zai-org/glm-4v-9b"],
test_type=VLMTestType.IMAGE,
diff --git a/tests/models/multimodal/generation/vlm_utils/model_utils.py b/tests/models/multimodal/generation/vlm_utils/model_utils.py
index c110f5598bee..832954258485 100644
--- a/tests/models/multimodal/generation/vlm_utils/model_utils.py
+++ b/tests/models/multimodal/generation/vlm_utils/model_utils.py
@@ -328,16 +328,6 @@ def gemma3_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
hf_model.processor = processor
- orig_generate = hf_model.model.generate
-
- def _generate(self, *args, **kwargs):
- # FIXME: https://github.com/huggingface/transformers/issues/38333
- kwargs["disable_compile"] = True
-
- return orig_generate(*args, **kwargs)
-
- hf_model.model.generate = types.MethodType(_generate, hf_model.model)
-
return hf_model
diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py
index 23f183e1d5bb..78bd284b565f 100644
--- a/tests/models/multimodal/processing/test_common.py
+++ b/tests/models/multimodal/processing/test_common.py
@@ -222,7 +222,6 @@ def _test_processing_correctness(
_ADD_SPECIAL_TOKENS_OVERRIDES = {
"ovis": False,
"ovis2_5": False,
- "paligemma": False,
"ultravox": False,
"whisper": False,
}
@@ -333,7 +332,6 @@ def _test_processing_correctness_one(
"deepseek-ai/deepseek-vl2-tiny",
"baidu/ERNIE-4.5-VL-28B-A3B-PT",
"adept/fuyu-8b",
- "google/gemma-3-4b-it",
"google/gemma-3n-E2B-it",
"zai-org/glm-4v-9b",
"zai-org/GLM-4.1V-9B-Thinking",
@@ -370,8 +368,6 @@ def _test_processing_correctness_one(
"AIDC-AI/Ovis1.6-Llama3.2-3B",
"AIDC-AI/Ovis2-1B",
"AIDC-AI/Ovis2.5-2B",
- "google/paligemma-3b-mix-224",
- "google/paligemma2-3b-ft-docci-448",
"microsoft/Phi-3.5-vision-instruct",
"microsoft/Phi-4-multimodal-instruct",
"mistralai/Pixtral-12B-2409",
diff --git a/tests/models/multimodal/processing/test_tensor_schema.py b/tests/models/multimodal/processing/test_tensor_schema.py
index 00c46082df66..166709329a2c 100644
--- a/tests/models/multimodal/processing/test_tensor_schema.py
+++ b/tests/models/multimodal/processing/test_tensor_schema.py
@@ -48,7 +48,6 @@ ARCH_NEEDS_EXTRAS = [
"Idefics3ForConditionalGeneration",
"LlavaForConditionalGeneration",
"MiniCPMV",
- "PaliGemmaForConditionalGeneration",
]
REPO_ID_TO_SKIP = {
"nm-testing/pixtral-12b-FP8-dynamic": "duplicated test",
diff --git a/vllm/model_executor/models/gemma3_mm.py b/vllm/model_executor/models/gemma3_mm.py
deleted file mode 100644
index 7c628fe93ce3..000000000000
--- a/vllm/model_executor/models/gemma3_mm.py
+++ /dev/null
@@ -1,710 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-import math
-from collections.abc import Iterable, Mapping, Sequence
-from typing import Annotated, Any, Literal
-
-import torch
-from torch import nn
-from transformers import BatchFeature, Gemma3Config, Gemma3Processor
-from transformers.models.gemma3.processing_gemma3 import Gemma3ProcessorKwargs
-
-import vllm.envs as envs
-from vllm.config import VllmConfig
-from vllm.config.multimodal import BaseDummyOptions
-from vllm.logger import init_logger
-from vllm.model_executor.layers.layernorm import GemmaRMSNorm
-from vllm.model_executor.models.module_mapping import MultiModelKeys
-from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import (
- MultiModalDataDict,
- MultiModalFieldConfig,
- MultiModalKwargsItems,
-)
-from vllm.multimodal.parse import ImageProcessorItems, ImageSize, MultiModalDataItems
-from vllm.multimodal.processing import (
- BaseMultiModalProcessor,
- BaseProcessingInfo,
- MultiModalPromptUpdates,
- MultiModalPromptUpdatesApplyResult,
- PlaceholderFeaturesInfo,
- PromptReplacement,
- PromptUpdate,
- PromptUpdateDetails,
- replace_token_matches,
-)
-from vllm.multimodal.profiling import BaseDummyInputsBuilder
-from vllm.sequence import IntermediateTensors
-from vllm.utils.tensor_schema import TensorSchema, TensorShape
-
-from .interfaces import (
- MultiModalEmbeddings,
- SupportsLoRA,
- SupportsMultiModal,
- SupportsPP,
-)
-from .siglip import SiglipVisionModel
-from .utils import (
- AutoWeightsLoader,
- WeightsMapper,
- init_vllm_registered_model,
- maybe_prefix,
-)
-
-logger = init_logger(__name__)
-
-
-class Gemma3ImagePixelInputs(TensorSchema):
- """
- Dimensions:
- - p: Number of patches total (over each image over each prompt in the
- batch)
- - c: Number of channels (3)
- - h: Height of each patch
- - w: Width of each patch
- - bn: Batch size * number of images
- """
-
- type: Literal["pixel_values"] = "pixel_values"
-
- pixel_values: Annotated[torch.Tensor, TensorShape("p", 3, "h", "w")]
-
- num_patches: Annotated[torch.Tensor, TensorShape("bn")]
-
-
-Gemma3ImageInputs = Gemma3ImagePixelInputs
-
-
-class Gemma3ProcessingInfo(BaseProcessingInfo):
- def get_hf_config(self):
- return self.ctx.get_hf_config(Gemma3Config)
-
- def get_hf_processor(self, **kwargs: object):
- return self.ctx.get_hf_processor(Gemma3Processor, **kwargs)
-
- def get_supported_mm_limits(self) -> Mapping[str, int | None]:
- return {"image": None}
-
- def _resolve_image_kwargs(
- self,
- processor: Gemma3Processor,
- keys: set[str],
- ) -> dict[str, Any]:
- image_processor = processor.image_processor
- kwargs = processor._merge_kwargs(
- Gemma3ProcessorKwargs,
- tokenizer_init_kwargs=processor.tokenizer.init_kwargs,
- )
-
- images_kwargs = kwargs["images_kwargs"]
-
- def _resolve_kw(key: str):
- val = getattr(image_processor, key)
- if val is None:
- val = images_kwargs[key]
-
- return val
-
- return {k: _resolve_kw(k) for k in keys}
-
- def get_num_crops(
- self,
- *,
- image_width: int,
- image_height: int,
- processor: Gemma3Processor | None,
- ) -> int:
- if processor is None:
- processor = self.get_hf_processor()
-
- images_kwargs = self._resolve_image_kwargs(
- processor,
- {
- "do_pan_and_scan",
- "pan_and_scan_min_crop_size",
- "pan_and_scan_max_num_crops",
- "pan_and_scan_min_ratio_to_activate",
- },
- )
-
- do_pan_and_scan = images_kwargs["do_pan_and_scan"]
- pan_and_scan_min_crop_size = images_kwargs["pan_and_scan_min_crop_size"]
- pan_and_scan_max_num_crops = images_kwargs["pan_and_scan_max_num_crops"]
- pan_and_scan_min_ratio_to_activate = images_kwargs[
- "pan_and_scan_min_ratio_to_activate"
- ]
-
- if not do_pan_and_scan:
- return 0
-
- if envs.VLLM_USE_V1:
- logger.warning_once(
- "`do_pan_and_scan=True` has suboptimal results on V1 "
- "because of the simplified attention pattern being used."
- )
-
- # Based on Gemma3ImageProcessor.pan_and_scan
- if image_width >= image_height:
- if image_width / image_height < pan_and_scan_min_ratio_to_activate:
- return 0
-
- num_crops_w = min(
- int(math.floor(image_width / pan_and_scan_min_crop_size)),
- int(math.floor(image_width / image_height + 0.5)),
- )
-
- num_crops_w = max(2, num_crops_w)
- num_crops_w = min(pan_and_scan_max_num_crops, num_crops_w)
- num_crops_h = 1
- else:
- if image_height / image_width < pan_and_scan_min_ratio_to_activate:
- return 0
-
- num_crops_h = min(
- int(math.floor(image_height / pan_and_scan_min_crop_size)),
- int(math.floor(image_height / image_width + 0.5)),
- )
-
- num_crops_h = max(2, num_crops_h)
- num_crops_h = min(pan_and_scan_max_num_crops, num_crops_h)
- num_crops_w = 1
-
- crop_size_w = int(math.ceil(image_width / num_crops_w))
- crop_size_h = int(math.ceil(image_height / num_crops_h))
-
- if min(crop_size_w, crop_size_h) < pan_and_scan_min_crop_size:
- return 0
-
- return num_crops_w * num_crops_h
-
- def get_image_repl(
- self,
- *,
- image_width: int,
- image_height: int,
- processor: Gemma3Processor | None,
- ) -> PromptUpdateDetails[str]:
- if processor is None:
- processor = self.get_hf_processor()
-
- boi_token = processor.boi_token
-
- num_crops = self.get_num_crops(
- image_width=image_width,
- image_height=image_height,
- processor=processor,
- )
-
- if num_crops == 0:
- image_text = boi_token
- else:
- crops_image_tokens = " ".join(boi_token for _ in range(num_crops))
- image_text = (
- f"Here is the original image {boi_token} and here are some "
- f"crops to help you see better {crops_image_tokens}"
- )
-
- repl_full = image_text.replace(boi_token, processor.full_image_sequence)
-
- tokenizer = processor.tokenizer
- vocab = tokenizer.get_vocab()
- image_token_id = vocab[tokenizer.image_token]
-
- return PromptUpdateDetails.select_token_id(repl_full, image_token_id)
-
- def get_num_image_tokens(
- self,
- *,
- image_width: int,
- image_height: int,
- processor: Gemma3Processor | None,
- ) -> int:
- if processor is None:
- processor = self.get_hf_processor()
-
- num_crops = self.get_num_crops(
- image_width=image_width,
- image_height=image_height,
- processor=processor,
- )
- image_seq_len = processor.image_seq_length
-
- return (num_crops + 1) * image_seq_len
-
- def get_image_size_with_most_features(self) -> ImageSize:
- processor = self.get_hf_processor()
-
- images_kwargs = self._resolve_image_kwargs(
- processor, {"pan_and_scan_max_num_crops"}
- )
- max_num_crops = images_kwargs["pan_and_scan_max_num_crops"]
-
- # Result in the max possible feature size (h:w = max_num_crops:1)
- return ImageSize(height=50 * max_num_crops, width=50)
-
-
-class Gemma3DummyInputsBuilder(BaseDummyInputsBuilder[Gemma3ProcessingInfo]):
- def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
- num_images = mm_counts.get("image", 0)
-
- processor = self.info.get_hf_processor()
- image_token = processor.boi_token
-
- return image_token * num_images
-
- def get_dummy_mm_data(
- self,
- seq_len: int,
- mm_counts: Mapping[str, int],
- mm_options: Mapping[str, BaseDummyOptions] | None = None,
- ) -> MultiModalDataDict:
- num_images = mm_counts.get("image", 0)
-
- target_width, target_height = self.info.get_image_size_with_most_features()
-
- image_overrides = mm_options.get("image") if mm_options else None
-
- return {
- "image": self._get_dummy_images(
- width=target_width,
- height=target_height,
- num_images=num_images,
- overrides=image_overrides,
- )
- }
-
-
-class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
- def _call_hf_processor(
- self,
- prompt: str,
- mm_data: Mapping[str, object],
- mm_kwargs: Mapping[str, object],
- tok_kwargs: Mapping[str, object],
- ) -> BatchFeature:
- processed_outputs = super()._call_hf_processor(
- prompt,
- mm_data,
- mm_kwargs,
- tok_kwargs,
- )
-
- # HF processor pops the `num_crops` kwarg, which is needed by vLLM
- if (images := mm_data.get("images")) is not None:
- parsed_images = (
- self._get_data_parser()
- .parse_mm_data({"image": images})
- .get_items("image", ImageProcessorItems)
- )
- image_sizes = [
- parsed_images.get_image_size(i) for i in range(len(parsed_images))
- ]
- hf_processor = self.info.get_hf_processor(**mm_kwargs)
-
- num_crops = [
- self.info.get_num_crops(
- image_width=size.width,
- image_height=size.height,
- processor=hf_processor,
- )
- for size in image_sizes
- ]
- processed_outputs["num_patches"] = torch.tensor(num_crops) + 1
-
- return processed_outputs
-
- def _get_mm_fields_config(
- self,
- hf_inputs: BatchFeature,
- hf_processor_mm_kwargs: Mapping[str, object],
- ) -> Mapping[str, MultiModalFieldConfig]:
- num_patches = hf_inputs.get("num_patches", torch.empty(0))
-
- return dict(
- pixel_values=MultiModalFieldConfig.flat_from_sizes("image", num_patches),
- num_patches=MultiModalFieldConfig.batched("image"),
- )
-
- def _get_prompt_updates(
- self,
- mm_items: MultiModalDataItems,
- hf_processor_mm_kwargs: Mapping[str, Any],
- out_mm_kwargs: MultiModalKwargsItems,
- ) -> Sequence[PromptUpdate]:
- hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
- image_token = hf_processor.boi_token
-
- def get_replacement_gemma3(item_idx: int):
- images = mm_items.get_items("image", ImageProcessorItems)
-
- image_size = images.get_image_size(item_idx)
- return self.info.get_image_repl(
- image_width=image_size.width,
- image_height=image_size.height,
- processor=hf_processor,
- )
-
- return [
- PromptReplacement(
- modality="image",
- target=image_token,
- replacement=get_replacement_gemma3,
- )
- ]
-
- def _apply_token_matches(
- self,
- prompt: list[int],
- mm_prompt_updates: MultiModalPromptUpdates,
- ) -> tuple[list[int], MultiModalPromptUpdatesApplyResult]:
- token_ids, res = super()._apply_token_matches(prompt, mm_prompt_updates)
-
- # "\n\n\n" and "\n\n\n\n" are single tokens
- # Since our replacement can insert "\n\n" next to "\n"
- # tokens, we have to combine them to be consistent with
- # the output of the tokenizer
- tokenizer = self.info.get_tokenizer()
- vocab = tokenizer.get_vocab()
- newline_1 = vocab["\n"]
- newline_2 = vocab["\n\n"]
- newline_3 = vocab["\n\n\n"]
- newline_4 = vocab["\n\n\n\n"]
-
- token_ids = replace_token_matches(
- token_ids,
- [newline_1, newline_2],
- [newline_3],
- )
- token_ids = replace_token_matches(
- token_ids,
- [newline_2, newline_1],
- [newline_3],
- )
- token_ids = replace_token_matches(
- token_ids,
- [newline_2, newline_2],
- [newline_4],
- )
-
- return token_ids, res
-
- def _find_mm_placeholders(
- self,
- new_token_ids: list[int],
- mm_prompt_updates: MultiModalPromptUpdates,
- ) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
- # We need to detect "\n\n" inside "\n\n\n" and "\n\n\n\n"
- tokenizer = self.info.get_tokenizer()
- vocab = tokenizer.get_vocab()
- newline_1 = vocab["\n"]
- newline_2 = vocab["\n\n"]
- newline_3 = vocab["\n\n\n"]
- newline_4 = vocab["\n\n\n\n"]
-
- def get_repl_toks(tok: int) -> list[int]:
- if tok == newline_3:
- return [newline_1, newline_2]
- if tok == newline_4:
- return [newline_2, newline_2]
-
- return [tok]
-
- repl_token_ids = list[int]()
- repl_orig_idxs = list[int]()
- for orig_idx, orig_tok in enumerate(new_token_ids):
- repl_toks = get_repl_toks(orig_tok)
- repl_token_ids.extend(repl_toks)
- repl_orig_idxs.extend(orig_idx for _ in range(len(repl_toks)))
-
- repls = super()._find_mm_placeholders(repl_token_ids, mm_prompt_updates)
-
- return {
- modality: [
- PlaceholderFeaturesInfo(
- modality=p.modality,
- item_idx=p.item_idx,
- start_idx=repl_orig_idxs[p.start_idx],
- tokens=p.tokens,
- is_embed=p.is_embed,
- )
- for p in placeholders
- ]
- for modality, placeholders in repls.items()
- }
-
-
-class Gemma3MultiModalProjector(nn.Module):
- def __init__(self, config: Gemma3Config):
- super().__init__()
-
- self.mm_input_projection_weight = nn.Parameter(
- torch.zeros(
- config.vision_config.hidden_size, config.text_config.hidden_size
- )
- )
-
- self.mm_soft_emb_norm = GemmaRMSNorm(
- config.vision_config.hidden_size, eps=config.vision_config.layer_norm_eps
- )
-
- self.patches_per_image = int(
- config.vision_config.image_size // config.vision_config.patch_size
- )
- self.tokens_per_side = int(config.mm_tokens_per_image**0.5)
- self.kernel_size = self.patches_per_image // self.tokens_per_side
- self.avg_pool = nn.AvgPool2d(
- kernel_size=self.kernel_size, stride=self.kernel_size
- )
-
- def forward(self, vision_outputs: torch.Tensor):
- batch_size, _, seq_length = vision_outputs.shape
-
- reshaped_vision_outputs = vision_outputs.transpose(1, 2)
- reshaped_vision_outputs = reshaped_vision_outputs.reshape(
- batch_size, seq_length, self.patches_per_image, self.patches_per_image
- )
- reshaped_vision_outputs = reshaped_vision_outputs.contiguous()
-
- pooled_vision_outputs = self.avg_pool(reshaped_vision_outputs)
- pooled_vision_outputs = pooled_vision_outputs.flatten(2)
- pooled_vision_outputs = pooled_vision_outputs.transpose(1, 2)
-
- normed_vision_outputs = self.mm_soft_emb_norm(pooled_vision_outputs)
-
- projected_vision_outputs = torch.matmul(
- normed_vision_outputs, self.mm_input_projection_weight
- )
- return projected_vision_outputs.type_as(vision_outputs)
-
-
-@MULTIMODAL_REGISTRY.register_processor(
- Gemma3MultiModalProcessor,
- info=Gemma3ProcessingInfo,
- dummy_inputs=Gemma3DummyInputsBuilder,
-)
-class Gemma3ForConditionalGeneration(
- nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA
-):
- merge_by_field_config = True
-
- packed_modules_mapping = {
- "qkv_proj": [
- "q_proj",
- "k_proj",
- "v_proj",
- ],
- "gate_up_proj": [
- "gate_proj",
- "up_proj",
- ],
- }
-
- hf_to_vllm_mapper = WeightsMapper(
- orig_to_new_prefix={
- # mapping for new names in checkpoint saved after transformers v4.52
- "model.language_model.": "language_model.model.",
- "model.vision_tower.": "vision_tower.",
- "model.multi_modal_projector.": "multi_modal_projector.",
- "lm_head.": "language_model.lm_head.",
- }
- )
-
- @classmethod
- def get_placeholder_str(cls, modality: str, i: int) -> str | None:
- if modality.startswith("image"):
- return ""
-
- raise ValueError("Only image modality is supported")
-
- def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
- super().__init__()
- config = vllm_config.model_config.hf_config
- quant_config = vllm_config.quant_config
- multimodal_config = vllm_config.model_config.multimodal_config
- self.config = config
- self.quant_config = quant_config
- self.multimodal_config = multimodal_config
-
- self.vision_tower = SiglipVisionModel(
- config.vision_config,
- quant_config,
- prefix=maybe_prefix(prefix, "vision_tower"),
- )
- self.multi_modal_projector = Gemma3MultiModalProjector(config)
-
- self.language_model = init_vllm_registered_model(
- vllm_config=vllm_config,
- hf_config=config.text_config,
- prefix=maybe_prefix(prefix, "language_model"),
- architectures=["Gemma3ForCausalLM"],
- )
- logit_scale = getattr(config, "logit_scale", 1.0)
-
- if hasattr(self.language_model, "logits_processor"):
- # The logits processor can be unset if we're using
- # automatic conversion to pooling model.
- self.language_model.logits_processor.scale *= logit_scale
-
- self.make_empty_intermediate_tensors = (
- self.language_model.make_empty_intermediate_tensors
- )
-
- @property
- def dtype(self):
- return next(self.parameters()).dtype
-
- def _parse_and_validate_image_input(
- self, **kwargs: object
- ) -> Gemma3ImageInputs | None:
- pixel_values = kwargs.pop("pixel_values", None)
- num_patches = kwargs.pop("num_patches", None)
- image_embeds = kwargs.pop("image_embeds", None)
- assert image_embeds is None, "Gemma3 does not support image_embeds."
- if pixel_values is None:
- return None
-
- image_size = self.config.vision_config.image_size
-
- return Gemma3ImagePixelInputs(
- pixel_values=pixel_values,
- num_patches=num_patches,
- resolve_bindings={"h": image_size, "w": image_size},
- )
-
- def _image_pixels_to_features(
- self,
- vision_tower: SiglipVisionModel,
- pixel_values: torch.Tensor,
- ) -> torch.Tensor:
- return vision_tower(pixel_values)
-
- def _process_image_input(
- self,
- image_input: Gemma3ImageInputs,
- ) -> list[torch.Tensor]:
- assert self.vision_tower is not None
-
- pixel_values = image_input["pixel_values"]
- num_patches = image_input["num_patches"]
-
- image_features = self._image_pixels_to_features(
- self.vision_tower,
- pixel_values,
- )
- image_embeds = self.multi_modal_projector(image_features)
-
- return [e.flatten(0, 1) for e in image_embeds.split(num_patches.tolist())]
-
- def get_language_model(self) -> torch.nn.Module:
- return self.language_model
-
- def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
- image_input = self._parse_and_validate_image_input(**kwargs)
- if image_input is None:
- return []
-
- return self._process_image_input(image_input)
-
- def forward(
- self,
- input_ids: torch.Tensor,
- positions: torch.Tensor,
- intermediate_tensors: IntermediateTensors | None = None,
- inputs_embeds: torch.Tensor | None = None,
- **kwargs: object,
- ) -> IntermediateTensors:
- if intermediate_tensors is not None:
- inputs_embeds = None
-
- hidden_states = self.language_model.model(
- input_ids,
- positions,
- intermediate_tensors,
- inputs_embeds=inputs_embeds,
- **kwargs,
- )
-
- return hidden_states
-
- def prepare_attn_masks(
- self,
- input_ids: torch.Tensor,
- positions: torch.Tensor,
- mask_dtype: torch.dtype,
- **kwargs,
- ):
- kwargs["has_images"] = True
- # NOTE(woosuk): Here, we distinguish the sequences by the position id 0.
- # This is a HACK. Fix this.
- start_indices = (positions == 0).cpu().nonzero()
- num_seqs = len(start_indices)
- seq_lens = []
- for i in range(num_seqs):
- start_idx = start_indices[i].item()
- if i < num_seqs - 1:
- end_idx = start_indices[i + 1].item()
- else:
- end_idx = len(input_ids)
- seq_lens.append(end_idx - start_idx)
- kwargs["seq_lens"] = seq_lens
-
- global_attn_masks = []
- local_attn_masks = []
- start_idx = 0
- for seq_len in seq_lens:
- end_idx = start_idx + seq_len
- input_token_ids = input_ids[start_idx:end_idx]
- start_idx = end_idx
- # Create a global causal mask.
- global_attn_mask = torch.empty(
- 1,
- 1,
- seq_len,
- seq_len,
- dtype=mask_dtype,
- device=input_ids.device,
- )
- global_attn_mask.fill_(float("-inf"))
- # Fill the lower triangle with 0.
- global_attn_mask = global_attn_mask.triu(diagonal=1)
-
- # Consider the bidirectional attention between image tokens.
- img_mask = torch.zeros_like(global_attn_mask)
- img_pos = input_token_ids == self.config.image_token_index
- img_mask[:, :, :, img_pos] += 1
- img_mask[:, :, img_pos, :] += 1
- global_attn_mask = torch.where(img_mask == 2, 0, global_attn_mask)
- global_attn_masks.append(global_attn_mask)
-
- sliding_window = self.config.text_config.sliding_window
- if sliding_window is not None:
- # Create a local causal mask with sliding window (1024).
- local_attn_mask = torch.ones_like(global_attn_mask)
- local_attn_mask = torch.tril(local_attn_mask, diagonal=-sliding_window)
- local_attn_mask = torch.where(
- local_attn_mask == 0, global_attn_mask, float("-inf")
- )
- local_attn_masks.append(local_attn_mask)
- kwargs["global_attn_masks"] = global_attn_masks
- kwargs["local_attn_masks"] = local_attn_masks
- return kwargs
-
- def compute_logits(
- self,
- hidden_states: torch.Tensor,
- ) -> torch.Tensor | None:
- return self.language_model.compute_logits(hidden_states)
-
- def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
- loader = AutoWeightsLoader(self)
- return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
-
- def get_mm_mapping(self) -> MultiModelKeys:
- """
- Get the module prefix in multimodal models
- """
- return MultiModelKeys.from_string_field(
- language_model="language_model",
- connector="multi_modal_projector",
- tower_model="vision_tower",
- )
diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py
deleted file mode 100644
index fb0b4b290467..000000000000
--- a/vllm/model_executor/models/paligemma.py
+++ /dev/null
@@ -1,412 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from collections.abc import Iterable, Mapping, Sequence
-from typing import Annotated, Literal, TypeAlias
-
-import torch
-from torch import nn
-from transformers import BatchFeature, PaliGemmaConfig
-
-from vllm.config import VllmConfig
-from vllm.config.multimodal import BaseDummyOptions
-from vllm.logger import init_logger
-from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import (
- MultiModalDataDict,
- MultiModalFieldConfig,
- MultiModalInputs,
- MultiModalKwargsItems,
- MultiModalUUIDDict,
-)
-from vllm.multimodal.parse import (
- ImageEmbeddingItems,
- ImageProcessorItems,
- MultiModalDataItems,
-)
-from vllm.multimodal.processing import (
- BaseMultiModalProcessor,
- BaseProcessingInfo,
- PromptIndexTargets,
- PromptInsertion,
- PromptUpdate,
- PromptUpdateDetails,
-)
-from vllm.multimodal.profiling import BaseDummyInputsBuilder
-from vllm.sequence import IntermediateTensors
-from vllm.utils.tensor_schema import TensorSchema, TensorShape
-
-from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
-from .siglip import SiglipVisionModel
-from .utils import (
- AutoWeightsLoader,
- WeightsMapper,
- flatten_bn,
- init_vllm_registered_model,
- maybe_prefix,
-)
-from .vision import get_vision_encoder_info
-
-logger = init_logger(__name__)
-
-
-class PaliGemmaImagePixelInputs(TensorSchema):
- """
- Dimensions:
- - bn: Batch size * number of images
- - c: Number of channels (3)
- - h: Height
- - w: Width
- """
-
- type: Literal["pixel_values"] = "pixel_values"
- data: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]
-
-
-class PaliGemmaImageEmbeddingInputs(TensorSchema):
- """
- Dimensions:
- - bn: Batch size * number of images
- - ifs: Image feature size
- - hs: Hidden size (must match language model backbone)
- """
-
- type: Literal["image_embeds"] = "image_embeds"
- data: Annotated[torch.Tensor, TensorShape("bn", "ifs", "hs")]
-
-
-PaliGemmaImageInputs: TypeAlias = (
- PaliGemmaImagePixelInputs | PaliGemmaImageEmbeddingInputs
-)
-
-
-class PaliGemmaMultiModalProjector(nn.Module):
- def __init__(self, vision_hidden_size: int, projection_dim: int):
- super().__init__()
-
- self.linear = nn.Linear(vision_hidden_size, projection_dim, bias=True)
-
- def forward(self, image_features: torch.Tensor) -> torch.Tensor:
- hidden_states = self.linear(image_features)
- return hidden_states
-
-
-class PaliGemmaProcessingInfo(BaseProcessingInfo):
- def get_hf_config(self):
- return self.ctx.get_hf_config(PaliGemmaConfig)
-
- def get_vision_encoder_info(self):
- return get_vision_encoder_info(self.get_hf_config())
-
- def get_supported_mm_limits(self) -> Mapping[str, int | None]:
- return {"image": 1}
-
- def get_num_image_tokens(
- self,
- *,
- image_width: int,
- image_height: int,
- ) -> int:
- vision_encoder_info = self.get_vision_encoder_info()
-
- return vision_encoder_info.get_num_image_tokens(
- image_width=image_width,
- image_height=image_height,
- )
-
-
-class PaliGemmaDummyInputsBuilder(BaseDummyInputsBuilder[PaliGemmaProcessingInfo]):
- def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
- return ""
-
- def get_dummy_mm_data(
- self,
- seq_len: int,
- mm_counts: Mapping[str, int],
- mm_options: Mapping[str, BaseDummyOptions] | None = None,
- ) -> MultiModalDataDict:
- hf_config = self.info.get_hf_config()
- vision_config = hf_config.vision_config
- max_image_size = vision_config.image_size
-
- num_images = mm_counts.get("image", 0)
-
- image_overrides = mm_options.get("image") if mm_options else None
-
- return {
- "image": self._get_dummy_images(
- width=max_image_size,
- height=max_image_size,
- num_images=num_images,
- overrides=image_overrides,
- )
- }
-
-
-class PaliGemmaMultiModalProcessor(BaseMultiModalProcessor[PaliGemmaProcessingInfo]):
- def _call_hf_processor(
- self,
- prompt: str,
- mm_data: Mapping[str, object],
- mm_kwargs: Mapping[str, object],
- tok_kwargs: Mapping[str, object],
- ) -> BatchFeature:
- tokenizer = self.info.get_tokenizer()
- if not mm_data:
- prompt_ids = tokenizer.encode(prompt, add_special_tokens=False)
- return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")
-
- return super()._call_hf_processor(
- prompt=prompt,
- mm_data=mm_data,
- mm_kwargs=mm_kwargs,
- tok_kwargs=tok_kwargs,
- )
-
- def _get_mm_fields_config(
- self,
- hf_inputs: BatchFeature,
- hf_processor_mm_kwargs: Mapping[str, object],
- ) -> Mapping[str, MultiModalFieldConfig]:
- return dict(pixel_values=MultiModalFieldConfig.batched("image"))
-
- def _get_prompt_updates(
- self,
- mm_items: MultiModalDataItems,
- hf_processor_mm_kwargs: Mapping[str, object],
- out_mm_kwargs: MultiModalKwargsItems,
- ) -> Sequence[PromptUpdate]:
- hf_config = self.info.get_hf_config()
- image_token_id = hf_config.image_token_index
-
- tokenizer = self.info.get_tokenizer()
-
- bos_token_id = tokenizer.bos_token_id
- assert isinstance(bos_token_id, int)
-
- def get_insertion(item_idx: int):
- images = mm_items.get_items(
- "image", (ImageEmbeddingItems, ImageProcessorItems)
- )
-
- if isinstance(images, ImageEmbeddingItems):
- num_image_tokens = images.get_feature_size(item_idx)
- else:
- image_size = images.get_image_size(item_idx)
- num_image_tokens = self.info.get_num_image_tokens(
- image_width=image_size.width,
- image_height=image_size.height,
- )
-
- image_tokens = [image_token_id] * num_image_tokens
-
- return PromptUpdateDetails.select_token_id(
- image_tokens + [bos_token_id],
- embed_token_id=image_token_id,
- )
-
- # Paligemma 1 and 2 have different tokenizer.add_bos_token
- # Insert *n + after for Paligemma 1
- # Insert *n + for Paligemma 2
- return [
- PromptInsertion(
- modality="image",
- target=PromptIndexTargets.prefix(
- [bos_token_id] if tokenizer.add_bos_token else []
- ),
- insertion=get_insertion,
- )
- ]
-
- def apply(
- self,
- prompt: str | list[int],
- mm_data: MultiModalDataDict,
- hf_processor_mm_kwargs: Mapping[str, object],
- tokenization_kwargs: Mapping[str, object] | None = None,
- mm_uuids: MultiModalUUIDDict | None = None,
- ) -> MultiModalInputs:
- mm_inputs = super().apply(
- prompt,
- mm_data,
- hf_processor_mm_kwargs,
- tokenization_kwargs,
- mm_uuids=mm_uuids,
- )
- prompt_token_ids = mm_inputs["prompt_token_ids"]
-
- tokenizer = self.info.get_tokenizer()
- newline_prompt = "\n"
- newline_token_id = tokenizer.encode(newline_prompt)[-1] # 108
- # Force to add newline at the end of prompt for paligemma's format
- # This step can NOT be replacemented by current PromptUpdate methods
- if len(prompt_token_ids) and prompt_token_ids[-1] != newline_token_id:
- prompt_token_ids.append(newline_token_id)
- mm_inputs["prompt_token_ids"] = prompt_token_ids
-
- return mm_inputs
-
-
-@MULTIMODAL_REGISTRY.register_processor(
- PaliGemmaMultiModalProcessor,
- info=PaliGemmaProcessingInfo,
- dummy_inputs=PaliGemmaDummyInputsBuilder,
-)
-class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
- packed_modules_mapping = {
- "qkv_proj": [
- "q_proj",
- "k_proj",
- "v_proj",
- ],
- "gate_up_proj": [
- "gate_proj",
- "up_proj",
- ],
- }
-
- hf_to_vllm_mapper = WeightsMapper(
- orig_to_new_prefix={
- # mapping for new names in checkpoint saved after transformers v4.52
- "model.language_model.": "language_model.model.",
- "model.vision_tower.": "vision_tower.",
- "model.multi_modal_projector.": "multi_modal_projector.",
- "lm_head.": "language_model.lm_head.",
- }
- )
-
- @classmethod
- def get_placeholder_str(cls, modality: str, i: int) -> str | None:
- if modality.startswith("image"):
- return None
-
- raise ValueError("Only image modality is supported")
-
- def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
- super().__init__()
- config = vllm_config.model_config.hf_config
- quant_config = vllm_config.quant_config
- multimodal_config = vllm_config.model_config.multimodal_config
- self.config = config
- self.multimodal_config = multimodal_config
-
- self.vision_tower = SiglipVisionModel(
- config.vision_config,
- quant_config,
- prefix=maybe_prefix(prefix, "vision_tower"),
- )
- self.multi_modal_projector = PaliGemmaMultiModalProjector(
- vision_hidden_size=config.vision_config.hidden_size,
- projection_dim=config.vision_config.projection_dim,
- )
-
- self.quant_config = quant_config
-
- if config.text_config.model_type == "gemma":
- config.text_config.architectures = ["GemmaForCausalLM"]
- else:
- config.text_config.architectures = ["Gemma2ForCausalLM"]
- self.language_model = init_vllm_registered_model(
- vllm_config=vllm_config,
- hf_config=config.text_config,
- prefix=maybe_prefix(prefix, "language_model"),
- )
- logit_scale = getattr(config, "logit_scale", 1.0)
- self.language_model.logits_processor.scale *= logit_scale
-
- self.make_empty_intermediate_tensors = (
- self.language_model.make_empty_intermediate_tensors
- )
-
- def _parse_and_validate_image_input(
- self, **kwargs: object
- ) -> PaliGemmaImageInputs | None:
- pixel_values = kwargs.pop("pixel_values", None)
- image_embeds = kwargs.pop("image_embeds", None)
-
- if pixel_values is None and image_embeds is None:
- return None
-
- if pixel_values is not None:
- pixel_values = flatten_bn(pixel_values, concat=True)
-
- h = w = self.config.vision_config.image_size
- return PaliGemmaImagePixelInputs(
- type="pixel_values",
- data=pixel_values,
- resolve_bindings={"h": h, "w": w},
- )
-
- if image_embeds is not None:
- image_embeds = flatten_bn(image_embeds, concat=True)
-
- return PaliGemmaImageEmbeddingInputs(
- type="image_embeds",
- data=image_embeds,
- )
-
- raise AssertionError("This line should be unreachable.")
-
- def _image_pixels_to_features(
- self,
- vision_tower: SiglipVisionModel,
- pixel_values: torch.Tensor,
- ) -> torch.Tensor:
- target_dtype = vision_tower.get_input_embeddings().weight.dtype
- image_features = vision_tower(pixel_values.to(dtype=target_dtype))
-
- return image_features
-
- def _process_image_input(
- self,
- image_input: PaliGemmaImageInputs,
- ) -> torch.Tensor:
- if image_input["type"] == "image_embeds":
- return image_input["data"]
-
- assert self.vision_tower is not None
- pixel_values = image_input["data"]
- image_features = self._image_pixels_to_features(
- self.vision_tower,
- pixel_values,
- )
-
- return self.multi_modal_projector(image_features)
-
- def get_language_model(self) -> torch.nn.Module:
- return self.language_model
-
- def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
- image_input = self._parse_and_validate_image_input(**kwargs)
- if image_input is None:
- return []
- vision_embeddings = self._process_image_input(image_input)
- # https://github.com/huggingface/transformers/blob/main/src/transformers/models/paligemma/modeling_paligemma.py#L294 # noqa
- vision_embeddings = vision_embeddings * (self.config.hidden_size**-0.5)
- return vision_embeddings
-
- def forward(
- self,
- input_ids: torch.Tensor,
- positions: torch.Tensor,
- intermediate_tensors: IntermediateTensors | None = None,
- inputs_embeds: torch.Tensor | None = None,
- **kwargs: object,
- ) -> IntermediateTensors:
- if intermediate_tensors is not None:
- inputs_embeds = None
-
- hidden_states = self.language_model.model(
- input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
- )
-
- return hidden_states
-
- def compute_logits(
- self,
- hidden_states: torch.Tensor,
- ) -> torch.Tensor | None:
- return self.language_model.compute_logits(hidden_states)
-
- def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
- loader = AutoWeightsLoader(self)
- return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
index 4171ebdbde6d..ea2f14916058 100644
--- a/vllm/model_executor/models/registry.py
+++ b/vllm/model_executor/models/registry.py
@@ -263,7 +263,6 @@ _MULTIMODAL_MODELS = {
"Ernie4_5_VLMoeForConditionalGeneration",
),
"FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
- "Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"), # noqa: E501
"Gemma3nForConditionalGeneration": (
"gemma3n_mm",
"Gemma3nForConditionalGeneration",
@@ -329,10 +328,6 @@ _MULTIMODAL_MODELS = {
"NVLM_D": ("nvlm_d", "NVLM_D_Model"),
"Ovis": ("ovis", "Ovis"),
"Ovis2_5": ("ovis2_5", "Ovis2_5"),
- "PaliGemmaForConditionalGeneration": (
- "paligemma",
- "PaliGemmaForConditionalGeneration",
- ),
"Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
"Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"),
"Phi4MultimodalForCausalLM": ("phi4_multimodal", "Phi4MultimodalForCausalLM"), # noqa: E501
@@ -405,6 +400,14 @@ _TRANSFORMERS_SUPPORTED_MODELS = {
"transformers",
"TransformersMultiModalForCausalLM",
),
+ "Gemma3ForConditionalGeneration": (
+ "transformers",
+ "TransformersMultiModalForCausalLM",
+ ),
+ "PaliGemmaForConditionalGeneration": (
+ "transformers",
+ "TransformersMultiModalForCausalLM",
+ ),
}
_TRANSFORMERS_BACKEND_MODELS = {
diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py
index b25b96889309..d1d94048c0c4 100644
--- a/vllm/platforms/rocm.py
+++ b/vllm/platforms/rocm.py
@@ -59,9 +59,6 @@ _ROCM_PARTIALLY_SUPPORTED_MODELS: dict[str, str] = {
"Qwen2ForCausalLM": _ROCM_SWA_REASON,
"MistralForCausalLM": _ROCM_SWA_REASON,
"MixtralForCausalLM": _ROCM_SWA_REASON,
- "PaliGemmaForConditionalGeneration": (
- "ROCm flash attention does not yet fully support 32-bit precision on PaliGemma"
- ),
"Phi3VForCausalLM": (
"ROCm Triton flash attention may run into compilation errors due to "
"excessive use of shared memory. If this happens, disable Triton FA "