From 8c017b34908f8d4a877d862dd21b99aef7057c55 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Fri, 17 Oct 2025 13:03:35 +0800 Subject: [PATCH] [Model] Always use Transformers backend for PaliGemma and Gemma3-MM (#26715) Signed-off-by: DarkLight1337 --- docs/models/hardware_supported_models/tpu.md | 4 +- docs/models/supported_models.md | 23 +- examples/offline_inference/vision_language.py | 3 +- .../models/language/generation/test_gemma.py | 16 +- .../multimodal/generation/test_common.py | 74 +- .../generation/vlm_utils/model_utils.py | 10 - .../multimodal/processing/test_common.py | 4 - .../processing/test_tensor_schema.py | 1 - vllm/model_executor/models/gemma3_mm.py | 710 ------------------ vllm/model_executor/models/paligemma.py | 412 ---------- vllm/model_executor/models/registry.py | 13 +- vllm/platforms/rocm.py | 3 - 12 files changed, 54 insertions(+), 1219 deletions(-) delete mode 100644 vllm/model_executor/models/gemma3_mm.py delete mode 100644 vllm/model_executor/models/paligemma.py diff --git a/docs/models/hardware_supported_models/tpu.md b/docs/models/hardware_supported_models/tpu.md index 7b0a5ba6e72d..8d3e28c259ec 100644 --- a/docs/models/hardware_supported_models/tpu.md +++ b/docs/models/hardware_supported_models/tpu.md @@ -16,8 +16,8 @@ | meta-llama/Llama-4-* | Llama4ForConditionalGeneration | ❌ | | microsoft/Phi-3-mini-128k-instruct | Phi3ForCausalLM | 🟨 | | microsoft/phi-4 | Phi3ForCausalLM | ❌ | -| google/gemma-3-27b-it | Gemma3ForConditionalGeneration | 🟨 | -| google/gemma-3-4b-it | Gemma3ForConditionalGeneration | ❌ | +| google/gemma-3-27b-it | TransformersForMultimodalLM | 🟨 | +| google/gemma-3-4b-it | TransformersForMultimodalLM | ❌ | | deepseek-ai/DeepSeek-R1 | DeepseekV3ForCausalLM | ❌ | | deepseek-ai/DeepSeek-V3 | DeepseekV3ForCausalLM | ❌ | | RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8 | LlamaForCausalLM | ✅ | diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index f98b06188ede..3bba63dda03d 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -650,7 +650,6 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen | `DeepseekVLV2ForCausalLM`^ | DeepSeek-VL2 | T + I+ | `deepseek-ai/deepseek-vl2-tiny`, `deepseek-ai/deepseek-vl2-small`, `deepseek-ai/deepseek-vl2`, etc. | | ✅︎ | | `Ernie4_5_VLMoeForConditionalGeneration` | Ernie4.5-VL | T + I+/ V+ | `baidu/ERNIE-4.5-VL-28B-A3B-PT`, `baidu/ERNIE-4.5-VL-424B-A47B-PT` | | ✅︎ | | `FuyuForCausalLM` | Fuyu | T + I | `adept/fuyu-8b`, etc. | | ✅︎ | -| `Gemma3ForConditionalGeneration` | Gemma 3 | T + I+ | `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc. | ✅︎ | ✅︎ | | `Gemma3nForConditionalGeneration` | Gemma 3n | T + I + A | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | | | | `GLM4VForCausalLM`^ | GLM-4V | T + I | `zai-org/glm-4v-9b`, `zai-org/cogagent-9b-20241220`, etc. | ✅︎ | ✅︎ | | `Glm4vForConditionalGeneration` | GLM-4.1V-Thinking | T + IE+ + VE+ | `zai-org/GLM-4.1V-9B-Thinking`, etc. | ✅︎ | ✅︎ | @@ -679,7 +678,6 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen | `NVLM_D_Model` | NVLM-D 1.0 | T + I+ | `nvidia/NVLM-D-72B`, etc. | | ✅︎ | | `Ovis` | Ovis2, Ovis1.6 | T + I+ | `AIDC-AI/Ovis2-1B`, `AIDC-AI/Ovis1.6-Llama3.2-3B`, etc. | | ✅︎ | | `Ovis2_5` | Ovis2.5 | T + I+ + V | `AIDC-AI/Ovis2.5-9B`, etc. | | | -| `PaliGemmaForConditionalGeneration` | PaliGemma, PaliGemma 2 | T + IE | `google/paligemma-3b-pt-224`, `google/paligemma-3b-mix-224`, `google/paligemma2-3b-ft-docci-448`, etc. | | ✅︎ | | `Phi3VForCausalLM` | Phi-3-Vision, Phi-3.5-Vision | T + IE+ | `microsoft/Phi-3-vision-128k-instruct`, `microsoft/Phi-3.5-vision-instruct`, etc. | | ✅︎ | | `Phi4MMForCausalLM` | Phi-4-multimodal | T + I+ / T + A+ / I+ + A+ | `microsoft/Phi-4-multimodal-instruct`, etc. | ✅︎ | ✅︎ | | `Phi4MultimodalForCausalLM` | Phi-4-multimodal (HF Transformers) | T + I+ / T + A+ / I+ + A+ | `microsoft/Phi-4-multimodal-instruct` (with revision `refs/pr/70`), etc. | ✅︎ | ✅︎ | @@ -704,6 +702,8 @@ Some models are supported only via the [Transformers backend](#transformers). Th | Architecture | Models | Inputs | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | |--------------|--------|--------|-------------------|-----------------------------|-----------------------------------------| | `Emu3ForConditionalGeneration` | Emu3 | T + I | `BAAI/Emu3-Chat-hf` | ✅︎ | ✅︎ | +| `Gemma3ForConditionalGeneration` | Gemma 3 | T + I+ | `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc. | ✅︎ | ✅︎ | +| `PaliGemmaForConditionalGeneration` | PaliGemma, PaliGemma 2 | T + IE | `google/paligemma-3b-pt-224`, `google/paligemma-3b-mix-224`, `google/paligemma2-3b-ft-docci-448`, etc. | ✅︎ | ✅︎ | ^ You need to set the architecture name via `--hf-overrides` to match the one in vLLM.     • For example, to use DeepSeek-VL2 series models: @@ -712,21 +712,7 @@ Some models are supported only via the [Transformers backend](#transformers). Th + Multiple items can be inputted per text prompt for this modality. !!! warning - Both V0 and V1 support `Gemma3ForConditionalGeneration` for text-only inputs. - However, there are differences in how they handle text + image inputs: - - V0 correctly implements the model's attention pattern: - - Uses bidirectional attention between the image tokens corresponding to the same image - - Uses causal attention for other tokens - - Implemented via (naive) PyTorch SDPA with masking tensors - - Note: May use significant memory for long prompts with image - - V1 currently uses a simplified attention pattern: - - Uses causal attention for all tokens, including image tokens - - Generates reasonable outputs but does not match the original model's attention for text + image inputs, especially when `{"do_pan_and_scan": true}` - - Will be updated in the future to support the correct behavior - - This limitation exists because the model's mixed attention pattern (bidirectional for images, causal otherwise) is not yet supported by vLLM's attention backends. + For `Gemma3ForConditionalGeneration`, `{"do_pan_and_scan": true}` is not supported in Transformers backend yet. !!! note `Gemma3nForConditionalGeneration` is only supported on V1 due to shared KV caching and it depends on `timm>=1.0.17` to make use of its @@ -778,9 +764,6 @@ Some models are supported only via the [Transformers backend](#transformers). Th The official `openbmb/MiniCPM-V-2` doesn't work yet, so we need to use a fork (`HwwwH/MiniCPM-V-2`) for now. For more details, please see: -!!! warning - Our PaliGemma implementations have the same problem as Gemma 3 (see above) for both V0 and V1. - !!! note For Qwen2.5-Omni and Qwen3-Omni, reading audio from video pre-processing (`--mm-processor-kwargs '{"use_audio_in_video": true}'`) is currently work in progress and not yet supported. diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py index 1f09dabaf74c..a92304837be1 100644 --- a/examples/offline_inference/vision_language.py +++ b/examples/offline_inference/vision_language.py @@ -248,7 +248,8 @@ def run_gemma3(questions: list[str], modality: str) -> ModelRequestData: model=model_name, max_model_len=2048, max_num_seqs=2, - mm_processor_kwargs={"do_pan_and_scan": True}, + # TODO: Support this in transformers backend + # mm_processor_kwargs={"do_pan_and_scan": True}, limit_mm_per_prompt={modality: 1}, ) diff --git a/tests/models/language/generation/test_gemma.py b/tests/models/language/generation/test_gemma.py index 246b893be315..5108da68cb0b 100644 --- a/tests/models/language/generation/test_gemma.py +++ b/tests/models/language/generation/test_gemma.py @@ -3,7 +3,7 @@ import numpy as np import pytest -MODELS = ["google/gemma-2b", "google/gemma-2-2b", "google/gemma-3-4b-it"] +MODELS = ["google/gemma-2b", "google/gemma-2-2b"] @pytest.mark.parametrize("model", MODELS) @@ -14,14 +14,8 @@ def test_dummy_loader(vllm_runner, monkeypatch, model: str) -> None: model, load_format="dummy", ) as llm: - if model == "google/gemma-3-4b-it": - normalizers = llm.llm.collective_rpc( - lambda self: self.model_runner.model.language_model.model.normalizer.cpu().item() # noqa: E501 - ) - config = llm.llm.llm_engine.model_config.hf_config.text_config - else: - normalizers = llm.llm.collective_rpc( - lambda self: self.model_runner.model.model.normalizer.cpu().item() - ) - config = llm.llm.llm_engine.model_config.hf_config + normalizers = llm.apply_model( + lambda model: model.model.normalizer.cpu().item() + ) + config = llm.llm.llm_engine.model_config.hf_config assert np.allclose(normalizers, config.hidden_size**0.5, rtol=2e-3) diff --git a/tests/models/multimodal/generation/test_common.py b/tests/models/multimodal/generation/test_common.py index 7baca7800bb9..d94a3d5cf3c4 100644 --- a/tests/models/multimodal/generation/test_common.py +++ b/tests/models/multimodal/generation/test_common.py @@ -113,25 +113,6 @@ VLM_TEST_SETTINGS = { dtype="bfloat16" if current_platform.is_cpu() else "auto", marks=[pytest.mark.core_model, pytest.mark.cpu_model], ), - "paligemma": VLMTestInfo( - models=["google/paligemma-3b-mix-224"], - test_type=VLMTestType.IMAGE, - prompt_formatter=identity, - img_idx_to_prompt=lambda idx: "", - # Paligemma uses its own sample prompts because the default one fails - single_image_prompts=IMAGE_ASSETS.prompts( - { - "stop_sign": "caption es", - "cherry_blossom": "What is in the picture?", - } - ), - auto_cls=AutoModelForImageTextToText, - vllm_output_post_proc=model_utils.paligemma_vllm_to_hf_output, - dtype="bfloat16", - marks=[ - pytest.mark.skip(reason="vLLM does not support PrefixLM attention mask") - ], - ), "qwen2_5_vl": VLMTestInfo( models=["Qwen/Qwen2.5-VL-3B-Instruct"], test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE, VLMTestType.VIDEO), @@ -196,14 +177,24 @@ VLM_TEST_SETTINGS = { # Gemma3 has bidirectional mask on images "gemma3-transformers": VLMTestInfo( models=["google/gemma-3-4b-it"], - test_type=VLMTestType.IMAGE, - prompt_formatter=lambda vid_prompt: f"<'user\n{vid_prompt}\nmodel\n", # noqa: E501 - max_model_len=4096, + test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), + prompt_formatter=lambda img_prompt: f"user\n{img_prompt}\nmodel\n", # noqa: E501 + single_image_prompts=IMAGE_ASSETS.prompts( + { + "stop_sign": "What's the content in the center of the image?", # noqa: E501 + "cherry_blossom": "What is the season?", + } + ), + multi_image_prompt="Describe the two images in detail.", # noqa: E501 + max_model_len=8192, auto_cls=AutoModelForImageTextToText, + # TODO: Support `do_pan_and_scan` in transformers backend + # patch_hf_runner=model_utils.gemma3_patch_hf_runner, vllm_output_post_proc=model_utils.gemma3_vllm_to_hf_output, image_size_factors=[(0.25, 0.5, 1.0)], vllm_runner_kwargs={ "model_impl": "transformers", + # "mm_processor_kwargs": {"do_pan_and_scan": True}, }, marks=[pytest.mark.core_model], ), @@ -222,6 +213,27 @@ VLM_TEST_SETTINGS = { }, marks=[pytest.mark.core_model], ), + # PaliGemma has PrefixLM attention + "paligemma-transformers": VLMTestInfo( + models=["google/paligemma-3b-mix-224"], + test_type=VLMTestType.IMAGE, + prompt_formatter=identity, + img_idx_to_prompt=lambda idx: "", + # PaliGemma uses its own sample prompts because the default one fails + single_image_prompts=IMAGE_ASSETS.prompts( + { + "stop_sign": "caption es", + "cherry_blossom": "What is in the picture?", + } + ), + auto_cls=AutoModelForImageTextToText, + vllm_output_post_proc=model_utils.paligemma_vllm_to_hf_output, + image_size_factors=[(0.25, 0.5, 1.0)], + vllm_runner_kwargs={ + "model_impl": "transformers", + }, + marks=[pytest.mark.core_model], + ), # Pixel values from processor are not 4D or 5D arrays "qwen2_5_vl-transformers": VLMTestInfo( models=["Qwen/Qwen2.5-VL-3B-Instruct"], @@ -348,24 +360,6 @@ VLM_TEST_SETTINGS = { image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)], marks=[large_gpu_mark(min_gb=32)], ), - "gemma3": VLMTestInfo( - models=["google/gemma-3-4b-it"], - test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), - prompt_formatter=lambda img_prompt: f"user\n{img_prompt}\nmodel\n", # noqa: E501 - single_image_prompts=IMAGE_ASSETS.prompts( - { - "stop_sign": "What's the content in the center of the image?", # noqa: E501 - "cherry_blossom": "What is the season?", - } - ), - multi_image_prompt="Describe the two images in detail.", # noqa: E501 - max_model_len=4096, - max_num_seqs=2, - auto_cls=AutoModelForImageTextToText, - vllm_runner_kwargs={"mm_processor_kwargs": {"do_pan_and_scan": True}}, - patch_hf_runner=model_utils.gemma3_patch_hf_runner, - num_logprobs=10, - ), "glm4v": VLMTestInfo( models=["zai-org/glm-4v-9b"], test_type=VLMTestType.IMAGE, diff --git a/tests/models/multimodal/generation/vlm_utils/model_utils.py b/tests/models/multimodal/generation/vlm_utils/model_utils.py index c110f5598bee..832954258485 100644 --- a/tests/models/multimodal/generation/vlm_utils/model_utils.py +++ b/tests/models/multimodal/generation/vlm_utils/model_utils.py @@ -328,16 +328,6 @@ def gemma3_patch_hf_runner(hf_model: HfRunner) -> HfRunner: hf_model.processor = processor - orig_generate = hf_model.model.generate - - def _generate(self, *args, **kwargs): - # FIXME: https://github.com/huggingface/transformers/issues/38333 - kwargs["disable_compile"] = True - - return orig_generate(*args, **kwargs) - - hf_model.model.generate = types.MethodType(_generate, hf_model.model) - return hf_model diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py index 23f183e1d5bb..78bd284b565f 100644 --- a/tests/models/multimodal/processing/test_common.py +++ b/tests/models/multimodal/processing/test_common.py @@ -222,7 +222,6 @@ def _test_processing_correctness( _ADD_SPECIAL_TOKENS_OVERRIDES = { "ovis": False, "ovis2_5": False, - "paligemma": False, "ultravox": False, "whisper": False, } @@ -333,7 +332,6 @@ def _test_processing_correctness_one( "deepseek-ai/deepseek-vl2-tiny", "baidu/ERNIE-4.5-VL-28B-A3B-PT", "adept/fuyu-8b", - "google/gemma-3-4b-it", "google/gemma-3n-E2B-it", "zai-org/glm-4v-9b", "zai-org/GLM-4.1V-9B-Thinking", @@ -370,8 +368,6 @@ def _test_processing_correctness_one( "AIDC-AI/Ovis1.6-Llama3.2-3B", "AIDC-AI/Ovis2-1B", "AIDC-AI/Ovis2.5-2B", - "google/paligemma-3b-mix-224", - "google/paligemma2-3b-ft-docci-448", "microsoft/Phi-3.5-vision-instruct", "microsoft/Phi-4-multimodal-instruct", "mistralai/Pixtral-12B-2409", diff --git a/tests/models/multimodal/processing/test_tensor_schema.py b/tests/models/multimodal/processing/test_tensor_schema.py index 00c46082df66..166709329a2c 100644 --- a/tests/models/multimodal/processing/test_tensor_schema.py +++ b/tests/models/multimodal/processing/test_tensor_schema.py @@ -48,7 +48,6 @@ ARCH_NEEDS_EXTRAS = [ "Idefics3ForConditionalGeneration", "LlavaForConditionalGeneration", "MiniCPMV", - "PaliGemmaForConditionalGeneration", ] REPO_ID_TO_SKIP = { "nm-testing/pixtral-12b-FP8-dynamic": "duplicated test", diff --git a/vllm/model_executor/models/gemma3_mm.py b/vllm/model_executor/models/gemma3_mm.py deleted file mode 100644 index 7c628fe93ce3..000000000000 --- a/vllm/model_executor/models/gemma3_mm.py +++ /dev/null @@ -1,710 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import math -from collections.abc import Iterable, Mapping, Sequence -from typing import Annotated, Any, Literal - -import torch -from torch import nn -from transformers import BatchFeature, Gemma3Config, Gemma3Processor -from transformers.models.gemma3.processing_gemma3 import Gemma3ProcessorKwargs - -import vllm.envs as envs -from vllm.config import VllmConfig -from vllm.config.multimodal import BaseDummyOptions -from vllm.logger import init_logger -from vllm.model_executor.layers.layernorm import GemmaRMSNorm -from vllm.model_executor.models.module_mapping import MultiModelKeys -from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import ( - MultiModalDataDict, - MultiModalFieldConfig, - MultiModalKwargsItems, -) -from vllm.multimodal.parse import ImageProcessorItems, ImageSize, MultiModalDataItems -from vllm.multimodal.processing import ( - BaseMultiModalProcessor, - BaseProcessingInfo, - MultiModalPromptUpdates, - MultiModalPromptUpdatesApplyResult, - PlaceholderFeaturesInfo, - PromptReplacement, - PromptUpdate, - PromptUpdateDetails, - replace_token_matches, -) -from vllm.multimodal.profiling import BaseDummyInputsBuilder -from vllm.sequence import IntermediateTensors -from vllm.utils.tensor_schema import TensorSchema, TensorShape - -from .interfaces import ( - MultiModalEmbeddings, - SupportsLoRA, - SupportsMultiModal, - SupportsPP, -) -from .siglip import SiglipVisionModel -from .utils import ( - AutoWeightsLoader, - WeightsMapper, - init_vllm_registered_model, - maybe_prefix, -) - -logger = init_logger(__name__) - - -class Gemma3ImagePixelInputs(TensorSchema): - """ - Dimensions: - - p: Number of patches total (over each image over each prompt in the - batch) - - c: Number of channels (3) - - h: Height of each patch - - w: Width of each patch - - bn: Batch size * number of images - """ - - type: Literal["pixel_values"] = "pixel_values" - - pixel_values: Annotated[torch.Tensor, TensorShape("p", 3, "h", "w")] - - num_patches: Annotated[torch.Tensor, TensorShape("bn")] - - -Gemma3ImageInputs = Gemma3ImagePixelInputs - - -class Gemma3ProcessingInfo(BaseProcessingInfo): - def get_hf_config(self): - return self.ctx.get_hf_config(Gemma3Config) - - def get_hf_processor(self, **kwargs: object): - return self.ctx.get_hf_processor(Gemma3Processor, **kwargs) - - def get_supported_mm_limits(self) -> Mapping[str, int | None]: - return {"image": None} - - def _resolve_image_kwargs( - self, - processor: Gemma3Processor, - keys: set[str], - ) -> dict[str, Any]: - image_processor = processor.image_processor - kwargs = processor._merge_kwargs( - Gemma3ProcessorKwargs, - tokenizer_init_kwargs=processor.tokenizer.init_kwargs, - ) - - images_kwargs = kwargs["images_kwargs"] - - def _resolve_kw(key: str): - val = getattr(image_processor, key) - if val is None: - val = images_kwargs[key] - - return val - - return {k: _resolve_kw(k) for k in keys} - - def get_num_crops( - self, - *, - image_width: int, - image_height: int, - processor: Gemma3Processor | None, - ) -> int: - if processor is None: - processor = self.get_hf_processor() - - images_kwargs = self._resolve_image_kwargs( - processor, - { - "do_pan_and_scan", - "pan_and_scan_min_crop_size", - "pan_and_scan_max_num_crops", - "pan_and_scan_min_ratio_to_activate", - }, - ) - - do_pan_and_scan = images_kwargs["do_pan_and_scan"] - pan_and_scan_min_crop_size = images_kwargs["pan_and_scan_min_crop_size"] - pan_and_scan_max_num_crops = images_kwargs["pan_and_scan_max_num_crops"] - pan_and_scan_min_ratio_to_activate = images_kwargs[ - "pan_and_scan_min_ratio_to_activate" - ] - - if not do_pan_and_scan: - return 0 - - if envs.VLLM_USE_V1: - logger.warning_once( - "`do_pan_and_scan=True` has suboptimal results on V1 " - "because of the simplified attention pattern being used." - ) - - # Based on Gemma3ImageProcessor.pan_and_scan - if image_width >= image_height: - if image_width / image_height < pan_and_scan_min_ratio_to_activate: - return 0 - - num_crops_w = min( - int(math.floor(image_width / pan_and_scan_min_crop_size)), - int(math.floor(image_width / image_height + 0.5)), - ) - - num_crops_w = max(2, num_crops_w) - num_crops_w = min(pan_and_scan_max_num_crops, num_crops_w) - num_crops_h = 1 - else: - if image_height / image_width < pan_and_scan_min_ratio_to_activate: - return 0 - - num_crops_h = min( - int(math.floor(image_height / pan_and_scan_min_crop_size)), - int(math.floor(image_height / image_width + 0.5)), - ) - - num_crops_h = max(2, num_crops_h) - num_crops_h = min(pan_and_scan_max_num_crops, num_crops_h) - num_crops_w = 1 - - crop_size_w = int(math.ceil(image_width / num_crops_w)) - crop_size_h = int(math.ceil(image_height / num_crops_h)) - - if min(crop_size_w, crop_size_h) < pan_and_scan_min_crop_size: - return 0 - - return num_crops_w * num_crops_h - - def get_image_repl( - self, - *, - image_width: int, - image_height: int, - processor: Gemma3Processor | None, - ) -> PromptUpdateDetails[str]: - if processor is None: - processor = self.get_hf_processor() - - boi_token = processor.boi_token - - num_crops = self.get_num_crops( - image_width=image_width, - image_height=image_height, - processor=processor, - ) - - if num_crops == 0: - image_text = boi_token - else: - crops_image_tokens = " ".join(boi_token for _ in range(num_crops)) - image_text = ( - f"Here is the original image {boi_token} and here are some " - f"crops to help you see better {crops_image_tokens}" - ) - - repl_full = image_text.replace(boi_token, processor.full_image_sequence) - - tokenizer = processor.tokenizer - vocab = tokenizer.get_vocab() - image_token_id = vocab[tokenizer.image_token] - - return PromptUpdateDetails.select_token_id(repl_full, image_token_id) - - def get_num_image_tokens( - self, - *, - image_width: int, - image_height: int, - processor: Gemma3Processor | None, - ) -> int: - if processor is None: - processor = self.get_hf_processor() - - num_crops = self.get_num_crops( - image_width=image_width, - image_height=image_height, - processor=processor, - ) - image_seq_len = processor.image_seq_length - - return (num_crops + 1) * image_seq_len - - def get_image_size_with_most_features(self) -> ImageSize: - processor = self.get_hf_processor() - - images_kwargs = self._resolve_image_kwargs( - processor, {"pan_and_scan_max_num_crops"} - ) - max_num_crops = images_kwargs["pan_and_scan_max_num_crops"] - - # Result in the max possible feature size (h:w = max_num_crops:1) - return ImageSize(height=50 * max_num_crops, width=50) - - -class Gemma3DummyInputsBuilder(BaseDummyInputsBuilder[Gemma3ProcessingInfo]): - def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: - num_images = mm_counts.get("image", 0) - - processor = self.info.get_hf_processor() - image_token = processor.boi_token - - return image_token * num_images - - def get_dummy_mm_data( - self, - seq_len: int, - mm_counts: Mapping[str, int], - mm_options: Mapping[str, BaseDummyOptions] | None = None, - ) -> MultiModalDataDict: - num_images = mm_counts.get("image", 0) - - target_width, target_height = self.info.get_image_size_with_most_features() - - image_overrides = mm_options.get("image") if mm_options else None - - return { - "image": self._get_dummy_images( - width=target_width, - height=target_height, - num_images=num_images, - overrides=image_overrides, - ) - } - - -class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]): - def _call_hf_processor( - self, - prompt: str, - mm_data: Mapping[str, object], - mm_kwargs: Mapping[str, object], - tok_kwargs: Mapping[str, object], - ) -> BatchFeature: - processed_outputs = super()._call_hf_processor( - prompt, - mm_data, - mm_kwargs, - tok_kwargs, - ) - - # HF processor pops the `num_crops` kwarg, which is needed by vLLM - if (images := mm_data.get("images")) is not None: - parsed_images = ( - self._get_data_parser() - .parse_mm_data({"image": images}) - .get_items("image", ImageProcessorItems) - ) - image_sizes = [ - parsed_images.get_image_size(i) for i in range(len(parsed_images)) - ] - hf_processor = self.info.get_hf_processor(**mm_kwargs) - - num_crops = [ - self.info.get_num_crops( - image_width=size.width, - image_height=size.height, - processor=hf_processor, - ) - for size in image_sizes - ] - processed_outputs["num_patches"] = torch.tensor(num_crops) + 1 - - return processed_outputs - - def _get_mm_fields_config( - self, - hf_inputs: BatchFeature, - hf_processor_mm_kwargs: Mapping[str, object], - ) -> Mapping[str, MultiModalFieldConfig]: - num_patches = hf_inputs.get("num_patches", torch.empty(0)) - - return dict( - pixel_values=MultiModalFieldConfig.flat_from_sizes("image", num_patches), - num_patches=MultiModalFieldConfig.batched("image"), - ) - - def _get_prompt_updates( - self, - mm_items: MultiModalDataItems, - hf_processor_mm_kwargs: Mapping[str, Any], - out_mm_kwargs: MultiModalKwargsItems, - ) -> Sequence[PromptUpdate]: - hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) - image_token = hf_processor.boi_token - - def get_replacement_gemma3(item_idx: int): - images = mm_items.get_items("image", ImageProcessorItems) - - image_size = images.get_image_size(item_idx) - return self.info.get_image_repl( - image_width=image_size.width, - image_height=image_size.height, - processor=hf_processor, - ) - - return [ - PromptReplacement( - modality="image", - target=image_token, - replacement=get_replacement_gemma3, - ) - ] - - def _apply_token_matches( - self, - prompt: list[int], - mm_prompt_updates: MultiModalPromptUpdates, - ) -> tuple[list[int], MultiModalPromptUpdatesApplyResult]: - token_ids, res = super()._apply_token_matches(prompt, mm_prompt_updates) - - # "\n\n\n" and "\n\n\n\n" are single tokens - # Since our replacement can insert "\n\n" next to "\n" - # tokens, we have to combine them to be consistent with - # the output of the tokenizer - tokenizer = self.info.get_tokenizer() - vocab = tokenizer.get_vocab() - newline_1 = vocab["\n"] - newline_2 = vocab["\n\n"] - newline_3 = vocab["\n\n\n"] - newline_4 = vocab["\n\n\n\n"] - - token_ids = replace_token_matches( - token_ids, - [newline_1, newline_2], - [newline_3], - ) - token_ids = replace_token_matches( - token_ids, - [newline_2, newline_1], - [newline_3], - ) - token_ids = replace_token_matches( - token_ids, - [newline_2, newline_2], - [newline_4], - ) - - return token_ids, res - - def _find_mm_placeholders( - self, - new_token_ids: list[int], - mm_prompt_updates: MultiModalPromptUpdates, - ) -> Mapping[str, list[PlaceholderFeaturesInfo]]: - # We need to detect "\n\n" inside "\n\n\n" and "\n\n\n\n" - tokenizer = self.info.get_tokenizer() - vocab = tokenizer.get_vocab() - newline_1 = vocab["\n"] - newline_2 = vocab["\n\n"] - newline_3 = vocab["\n\n\n"] - newline_4 = vocab["\n\n\n\n"] - - def get_repl_toks(tok: int) -> list[int]: - if tok == newline_3: - return [newline_1, newline_2] - if tok == newline_4: - return [newline_2, newline_2] - - return [tok] - - repl_token_ids = list[int]() - repl_orig_idxs = list[int]() - for orig_idx, orig_tok in enumerate(new_token_ids): - repl_toks = get_repl_toks(orig_tok) - repl_token_ids.extend(repl_toks) - repl_orig_idxs.extend(orig_idx for _ in range(len(repl_toks))) - - repls = super()._find_mm_placeholders(repl_token_ids, mm_prompt_updates) - - return { - modality: [ - PlaceholderFeaturesInfo( - modality=p.modality, - item_idx=p.item_idx, - start_idx=repl_orig_idxs[p.start_idx], - tokens=p.tokens, - is_embed=p.is_embed, - ) - for p in placeholders - ] - for modality, placeholders in repls.items() - } - - -class Gemma3MultiModalProjector(nn.Module): - def __init__(self, config: Gemma3Config): - super().__init__() - - self.mm_input_projection_weight = nn.Parameter( - torch.zeros( - config.vision_config.hidden_size, config.text_config.hidden_size - ) - ) - - self.mm_soft_emb_norm = GemmaRMSNorm( - config.vision_config.hidden_size, eps=config.vision_config.layer_norm_eps - ) - - self.patches_per_image = int( - config.vision_config.image_size // config.vision_config.patch_size - ) - self.tokens_per_side = int(config.mm_tokens_per_image**0.5) - self.kernel_size = self.patches_per_image // self.tokens_per_side - self.avg_pool = nn.AvgPool2d( - kernel_size=self.kernel_size, stride=self.kernel_size - ) - - def forward(self, vision_outputs: torch.Tensor): - batch_size, _, seq_length = vision_outputs.shape - - reshaped_vision_outputs = vision_outputs.transpose(1, 2) - reshaped_vision_outputs = reshaped_vision_outputs.reshape( - batch_size, seq_length, self.patches_per_image, self.patches_per_image - ) - reshaped_vision_outputs = reshaped_vision_outputs.contiguous() - - pooled_vision_outputs = self.avg_pool(reshaped_vision_outputs) - pooled_vision_outputs = pooled_vision_outputs.flatten(2) - pooled_vision_outputs = pooled_vision_outputs.transpose(1, 2) - - normed_vision_outputs = self.mm_soft_emb_norm(pooled_vision_outputs) - - projected_vision_outputs = torch.matmul( - normed_vision_outputs, self.mm_input_projection_weight - ) - return projected_vision_outputs.type_as(vision_outputs) - - -@MULTIMODAL_REGISTRY.register_processor( - Gemma3MultiModalProcessor, - info=Gemma3ProcessingInfo, - dummy_inputs=Gemma3DummyInputsBuilder, -) -class Gemma3ForConditionalGeneration( - nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA -): - merge_by_field_config = True - - packed_modules_mapping = { - "qkv_proj": [ - "q_proj", - "k_proj", - "v_proj", - ], - "gate_up_proj": [ - "gate_proj", - "up_proj", - ], - } - - hf_to_vllm_mapper = WeightsMapper( - orig_to_new_prefix={ - # mapping for new names in checkpoint saved after transformers v4.52 - "model.language_model.": "language_model.model.", - "model.vision_tower.": "vision_tower.", - "model.multi_modal_projector.": "multi_modal_projector.", - "lm_head.": "language_model.lm_head.", - } - ) - - @classmethod - def get_placeholder_str(cls, modality: str, i: int) -> str | None: - if modality.startswith("image"): - return "" - - raise ValueError("Only image modality is supported") - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() - config = vllm_config.model_config.hf_config - quant_config = vllm_config.quant_config - multimodal_config = vllm_config.model_config.multimodal_config - self.config = config - self.quant_config = quant_config - self.multimodal_config = multimodal_config - - self.vision_tower = SiglipVisionModel( - config.vision_config, - quant_config, - prefix=maybe_prefix(prefix, "vision_tower"), - ) - self.multi_modal_projector = Gemma3MultiModalProjector(config) - - self.language_model = init_vllm_registered_model( - vllm_config=vllm_config, - hf_config=config.text_config, - prefix=maybe_prefix(prefix, "language_model"), - architectures=["Gemma3ForCausalLM"], - ) - logit_scale = getattr(config, "logit_scale", 1.0) - - if hasattr(self.language_model, "logits_processor"): - # The logits processor can be unset if we're using - # automatic conversion to pooling model. - self.language_model.logits_processor.scale *= logit_scale - - self.make_empty_intermediate_tensors = ( - self.language_model.make_empty_intermediate_tensors - ) - - @property - def dtype(self): - return next(self.parameters()).dtype - - def _parse_and_validate_image_input( - self, **kwargs: object - ) -> Gemma3ImageInputs | None: - pixel_values = kwargs.pop("pixel_values", None) - num_patches = kwargs.pop("num_patches", None) - image_embeds = kwargs.pop("image_embeds", None) - assert image_embeds is None, "Gemma3 does not support image_embeds." - if pixel_values is None: - return None - - image_size = self.config.vision_config.image_size - - return Gemma3ImagePixelInputs( - pixel_values=pixel_values, - num_patches=num_patches, - resolve_bindings={"h": image_size, "w": image_size}, - ) - - def _image_pixels_to_features( - self, - vision_tower: SiglipVisionModel, - pixel_values: torch.Tensor, - ) -> torch.Tensor: - return vision_tower(pixel_values) - - def _process_image_input( - self, - image_input: Gemma3ImageInputs, - ) -> list[torch.Tensor]: - assert self.vision_tower is not None - - pixel_values = image_input["pixel_values"] - num_patches = image_input["num_patches"] - - image_features = self._image_pixels_to_features( - self.vision_tower, - pixel_values, - ) - image_embeds = self.multi_modal_projector(image_features) - - return [e.flatten(0, 1) for e in image_embeds.split(num_patches.tolist())] - - def get_language_model(self) -> torch.nn.Module: - return self.language_model - - def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: - image_input = self._parse_and_validate_image_input(**kwargs) - if image_input is None: - return [] - - return self._process_image_input(image_input) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: IntermediateTensors | None = None, - inputs_embeds: torch.Tensor | None = None, - **kwargs: object, - ) -> IntermediateTensors: - if intermediate_tensors is not None: - inputs_embeds = None - - hidden_states = self.language_model.model( - input_ids, - positions, - intermediate_tensors, - inputs_embeds=inputs_embeds, - **kwargs, - ) - - return hidden_states - - def prepare_attn_masks( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - mask_dtype: torch.dtype, - **kwargs, - ): - kwargs["has_images"] = True - # NOTE(woosuk): Here, we distinguish the sequences by the position id 0. - # This is a HACK. Fix this. - start_indices = (positions == 0).cpu().nonzero() - num_seqs = len(start_indices) - seq_lens = [] - for i in range(num_seqs): - start_idx = start_indices[i].item() - if i < num_seqs - 1: - end_idx = start_indices[i + 1].item() - else: - end_idx = len(input_ids) - seq_lens.append(end_idx - start_idx) - kwargs["seq_lens"] = seq_lens - - global_attn_masks = [] - local_attn_masks = [] - start_idx = 0 - for seq_len in seq_lens: - end_idx = start_idx + seq_len - input_token_ids = input_ids[start_idx:end_idx] - start_idx = end_idx - # Create a global causal mask. - global_attn_mask = torch.empty( - 1, - 1, - seq_len, - seq_len, - dtype=mask_dtype, - device=input_ids.device, - ) - global_attn_mask.fill_(float("-inf")) - # Fill the lower triangle with 0. - global_attn_mask = global_attn_mask.triu(diagonal=1) - - # Consider the bidirectional attention between image tokens. - img_mask = torch.zeros_like(global_attn_mask) - img_pos = input_token_ids == self.config.image_token_index - img_mask[:, :, :, img_pos] += 1 - img_mask[:, :, img_pos, :] += 1 - global_attn_mask = torch.where(img_mask == 2, 0, global_attn_mask) - global_attn_masks.append(global_attn_mask) - - sliding_window = self.config.text_config.sliding_window - if sliding_window is not None: - # Create a local causal mask with sliding window (1024). - local_attn_mask = torch.ones_like(global_attn_mask) - local_attn_mask = torch.tril(local_attn_mask, diagonal=-sliding_window) - local_attn_mask = torch.where( - local_attn_mask == 0, global_attn_mask, float("-inf") - ) - local_attn_masks.append(local_attn_mask) - kwargs["global_attn_masks"] = global_attn_masks - kwargs["local_attn_masks"] = local_attn_masks - return kwargs - - def compute_logits( - self, - hidden_states: torch.Tensor, - ) -> torch.Tensor | None: - return self.language_model.compute_logits(hidden_states) - - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - loader = AutoWeightsLoader(self) - return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) - - def get_mm_mapping(self) -> MultiModelKeys: - """ - Get the module prefix in multimodal models - """ - return MultiModelKeys.from_string_field( - language_model="language_model", - connector="multi_modal_projector", - tower_model="vision_tower", - ) diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py deleted file mode 100644 index fb0b4b290467..000000000000 --- a/vllm/model_executor/models/paligemma.py +++ /dev/null @@ -1,412 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from collections.abc import Iterable, Mapping, Sequence -from typing import Annotated, Literal, TypeAlias - -import torch -from torch import nn -from transformers import BatchFeature, PaliGemmaConfig - -from vllm.config import VllmConfig -from vllm.config.multimodal import BaseDummyOptions -from vllm.logger import init_logger -from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import ( - MultiModalDataDict, - MultiModalFieldConfig, - MultiModalInputs, - MultiModalKwargsItems, - MultiModalUUIDDict, -) -from vllm.multimodal.parse import ( - ImageEmbeddingItems, - ImageProcessorItems, - MultiModalDataItems, -) -from vllm.multimodal.processing import ( - BaseMultiModalProcessor, - BaseProcessingInfo, - PromptIndexTargets, - PromptInsertion, - PromptUpdate, - PromptUpdateDetails, -) -from vllm.multimodal.profiling import BaseDummyInputsBuilder -from vllm.sequence import IntermediateTensors -from vllm.utils.tensor_schema import TensorSchema, TensorShape - -from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP -from .siglip import SiglipVisionModel -from .utils import ( - AutoWeightsLoader, - WeightsMapper, - flatten_bn, - init_vllm_registered_model, - maybe_prefix, -) -from .vision import get_vision_encoder_info - -logger = init_logger(__name__) - - -class PaliGemmaImagePixelInputs(TensorSchema): - """ - Dimensions: - - bn: Batch size * number of images - - c: Number of channels (3) - - h: Height - - w: Width - """ - - type: Literal["pixel_values"] = "pixel_values" - data: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")] - - -class PaliGemmaImageEmbeddingInputs(TensorSchema): - """ - Dimensions: - - bn: Batch size * number of images - - ifs: Image feature size - - hs: Hidden size (must match language model backbone) - """ - - type: Literal["image_embeds"] = "image_embeds" - data: Annotated[torch.Tensor, TensorShape("bn", "ifs", "hs")] - - -PaliGemmaImageInputs: TypeAlias = ( - PaliGemmaImagePixelInputs | PaliGemmaImageEmbeddingInputs -) - - -class PaliGemmaMultiModalProjector(nn.Module): - def __init__(self, vision_hidden_size: int, projection_dim: int): - super().__init__() - - self.linear = nn.Linear(vision_hidden_size, projection_dim, bias=True) - - def forward(self, image_features: torch.Tensor) -> torch.Tensor: - hidden_states = self.linear(image_features) - return hidden_states - - -class PaliGemmaProcessingInfo(BaseProcessingInfo): - def get_hf_config(self): - return self.ctx.get_hf_config(PaliGemmaConfig) - - def get_vision_encoder_info(self): - return get_vision_encoder_info(self.get_hf_config()) - - def get_supported_mm_limits(self) -> Mapping[str, int | None]: - return {"image": 1} - - def get_num_image_tokens( - self, - *, - image_width: int, - image_height: int, - ) -> int: - vision_encoder_info = self.get_vision_encoder_info() - - return vision_encoder_info.get_num_image_tokens( - image_width=image_width, - image_height=image_height, - ) - - -class PaliGemmaDummyInputsBuilder(BaseDummyInputsBuilder[PaliGemmaProcessingInfo]): - def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: - return "" - - def get_dummy_mm_data( - self, - seq_len: int, - mm_counts: Mapping[str, int], - mm_options: Mapping[str, BaseDummyOptions] | None = None, - ) -> MultiModalDataDict: - hf_config = self.info.get_hf_config() - vision_config = hf_config.vision_config - max_image_size = vision_config.image_size - - num_images = mm_counts.get("image", 0) - - image_overrides = mm_options.get("image") if mm_options else None - - return { - "image": self._get_dummy_images( - width=max_image_size, - height=max_image_size, - num_images=num_images, - overrides=image_overrides, - ) - } - - -class PaliGemmaMultiModalProcessor(BaseMultiModalProcessor[PaliGemmaProcessingInfo]): - def _call_hf_processor( - self, - prompt: str, - mm_data: Mapping[str, object], - mm_kwargs: Mapping[str, object], - tok_kwargs: Mapping[str, object], - ) -> BatchFeature: - tokenizer = self.info.get_tokenizer() - if not mm_data: - prompt_ids = tokenizer.encode(prompt, add_special_tokens=False) - return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt") - - return super()._call_hf_processor( - prompt=prompt, - mm_data=mm_data, - mm_kwargs=mm_kwargs, - tok_kwargs=tok_kwargs, - ) - - def _get_mm_fields_config( - self, - hf_inputs: BatchFeature, - hf_processor_mm_kwargs: Mapping[str, object], - ) -> Mapping[str, MultiModalFieldConfig]: - return dict(pixel_values=MultiModalFieldConfig.batched("image")) - - def _get_prompt_updates( - self, - mm_items: MultiModalDataItems, - hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargsItems, - ) -> Sequence[PromptUpdate]: - hf_config = self.info.get_hf_config() - image_token_id = hf_config.image_token_index - - tokenizer = self.info.get_tokenizer() - - bos_token_id = tokenizer.bos_token_id - assert isinstance(bos_token_id, int) - - def get_insertion(item_idx: int): - images = mm_items.get_items( - "image", (ImageEmbeddingItems, ImageProcessorItems) - ) - - if isinstance(images, ImageEmbeddingItems): - num_image_tokens = images.get_feature_size(item_idx) - else: - image_size = images.get_image_size(item_idx) - num_image_tokens = self.info.get_num_image_tokens( - image_width=image_size.width, - image_height=image_size.height, - ) - - image_tokens = [image_token_id] * num_image_tokens - - return PromptUpdateDetails.select_token_id( - image_tokens + [bos_token_id], - embed_token_id=image_token_id, - ) - - # Paligemma 1 and 2 have different tokenizer.add_bos_token - # Insert *n + after for Paligemma 1 - # Insert *n + for Paligemma 2 - return [ - PromptInsertion( - modality="image", - target=PromptIndexTargets.prefix( - [bos_token_id] if tokenizer.add_bos_token else [] - ), - insertion=get_insertion, - ) - ] - - def apply( - self, - prompt: str | list[int], - mm_data: MultiModalDataDict, - hf_processor_mm_kwargs: Mapping[str, object], - tokenization_kwargs: Mapping[str, object] | None = None, - mm_uuids: MultiModalUUIDDict | None = None, - ) -> MultiModalInputs: - mm_inputs = super().apply( - prompt, - mm_data, - hf_processor_mm_kwargs, - tokenization_kwargs, - mm_uuids=mm_uuids, - ) - prompt_token_ids = mm_inputs["prompt_token_ids"] - - tokenizer = self.info.get_tokenizer() - newline_prompt = "\n" - newline_token_id = tokenizer.encode(newline_prompt)[-1] # 108 - # Force to add newline at the end of prompt for paligemma's format - # This step can NOT be replacemented by current PromptUpdate methods - if len(prompt_token_ids) and prompt_token_ids[-1] != newline_token_id: - prompt_token_ids.append(newline_token_id) - mm_inputs["prompt_token_ids"] = prompt_token_ids - - return mm_inputs - - -@MULTIMODAL_REGISTRY.register_processor( - PaliGemmaMultiModalProcessor, - info=PaliGemmaProcessingInfo, - dummy_inputs=PaliGemmaDummyInputsBuilder, -) -class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): - packed_modules_mapping = { - "qkv_proj": [ - "q_proj", - "k_proj", - "v_proj", - ], - "gate_up_proj": [ - "gate_proj", - "up_proj", - ], - } - - hf_to_vllm_mapper = WeightsMapper( - orig_to_new_prefix={ - # mapping for new names in checkpoint saved after transformers v4.52 - "model.language_model.": "language_model.model.", - "model.vision_tower.": "vision_tower.", - "model.multi_modal_projector.": "multi_modal_projector.", - "lm_head.": "language_model.lm_head.", - } - ) - - @classmethod - def get_placeholder_str(cls, modality: str, i: int) -> str | None: - if modality.startswith("image"): - return None - - raise ValueError("Only image modality is supported") - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() - config = vllm_config.model_config.hf_config - quant_config = vllm_config.quant_config - multimodal_config = vllm_config.model_config.multimodal_config - self.config = config - self.multimodal_config = multimodal_config - - self.vision_tower = SiglipVisionModel( - config.vision_config, - quant_config, - prefix=maybe_prefix(prefix, "vision_tower"), - ) - self.multi_modal_projector = PaliGemmaMultiModalProjector( - vision_hidden_size=config.vision_config.hidden_size, - projection_dim=config.vision_config.projection_dim, - ) - - self.quant_config = quant_config - - if config.text_config.model_type == "gemma": - config.text_config.architectures = ["GemmaForCausalLM"] - else: - config.text_config.architectures = ["Gemma2ForCausalLM"] - self.language_model = init_vllm_registered_model( - vllm_config=vllm_config, - hf_config=config.text_config, - prefix=maybe_prefix(prefix, "language_model"), - ) - logit_scale = getattr(config, "logit_scale", 1.0) - self.language_model.logits_processor.scale *= logit_scale - - self.make_empty_intermediate_tensors = ( - self.language_model.make_empty_intermediate_tensors - ) - - def _parse_and_validate_image_input( - self, **kwargs: object - ) -> PaliGemmaImageInputs | None: - pixel_values = kwargs.pop("pixel_values", None) - image_embeds = kwargs.pop("image_embeds", None) - - if pixel_values is None and image_embeds is None: - return None - - if pixel_values is not None: - pixel_values = flatten_bn(pixel_values, concat=True) - - h = w = self.config.vision_config.image_size - return PaliGemmaImagePixelInputs( - type="pixel_values", - data=pixel_values, - resolve_bindings={"h": h, "w": w}, - ) - - if image_embeds is not None: - image_embeds = flatten_bn(image_embeds, concat=True) - - return PaliGemmaImageEmbeddingInputs( - type="image_embeds", - data=image_embeds, - ) - - raise AssertionError("This line should be unreachable.") - - def _image_pixels_to_features( - self, - vision_tower: SiglipVisionModel, - pixel_values: torch.Tensor, - ) -> torch.Tensor: - target_dtype = vision_tower.get_input_embeddings().weight.dtype - image_features = vision_tower(pixel_values.to(dtype=target_dtype)) - - return image_features - - def _process_image_input( - self, - image_input: PaliGemmaImageInputs, - ) -> torch.Tensor: - if image_input["type"] == "image_embeds": - return image_input["data"] - - assert self.vision_tower is not None - pixel_values = image_input["data"] - image_features = self._image_pixels_to_features( - self.vision_tower, - pixel_values, - ) - - return self.multi_modal_projector(image_features) - - def get_language_model(self) -> torch.nn.Module: - return self.language_model - - def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: - image_input = self._parse_and_validate_image_input(**kwargs) - if image_input is None: - return [] - vision_embeddings = self._process_image_input(image_input) - # https://github.com/huggingface/transformers/blob/main/src/transformers/models/paligemma/modeling_paligemma.py#L294 # noqa - vision_embeddings = vision_embeddings * (self.config.hidden_size**-0.5) - return vision_embeddings - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: IntermediateTensors | None = None, - inputs_embeds: torch.Tensor | None = None, - **kwargs: object, - ) -> IntermediateTensors: - if intermediate_tensors is not None: - inputs_embeds = None - - hidden_states = self.language_model.model( - input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds - ) - - return hidden_states - - def compute_logits( - self, - hidden_states: torch.Tensor, - ) -> torch.Tensor | None: - return self.language_model.compute_logits(hidden_states) - - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - loader = AutoWeightsLoader(self) - return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 4171ebdbde6d..ea2f14916058 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -263,7 +263,6 @@ _MULTIMODAL_MODELS = { "Ernie4_5_VLMoeForConditionalGeneration", ), "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"), - "Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"), # noqa: E501 "Gemma3nForConditionalGeneration": ( "gemma3n_mm", "Gemma3nForConditionalGeneration", @@ -329,10 +328,6 @@ _MULTIMODAL_MODELS = { "NVLM_D": ("nvlm_d", "NVLM_D_Model"), "Ovis": ("ovis", "Ovis"), "Ovis2_5": ("ovis2_5", "Ovis2_5"), - "PaliGemmaForConditionalGeneration": ( - "paligemma", - "PaliGemmaForConditionalGeneration", - ), "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"), "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"), "Phi4MultimodalForCausalLM": ("phi4_multimodal", "Phi4MultimodalForCausalLM"), # noqa: E501 @@ -405,6 +400,14 @@ _TRANSFORMERS_SUPPORTED_MODELS = { "transformers", "TransformersMultiModalForCausalLM", ), + "Gemma3ForConditionalGeneration": ( + "transformers", + "TransformersMultiModalForCausalLM", + ), + "PaliGemmaForConditionalGeneration": ( + "transformers", + "TransformersMultiModalForCausalLM", + ), } _TRANSFORMERS_BACKEND_MODELS = { diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index b25b96889309..d1d94048c0c4 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -59,9 +59,6 @@ _ROCM_PARTIALLY_SUPPORTED_MODELS: dict[str, str] = { "Qwen2ForCausalLM": _ROCM_SWA_REASON, "MistralForCausalLM": _ROCM_SWA_REASON, "MixtralForCausalLM": _ROCM_SWA_REASON, - "PaliGemmaForConditionalGeneration": ( - "ROCm flash attention does not yet fully support 32-bit precision on PaliGemma" - ), "Phi3VForCausalLM": ( "ROCm Triton flash attention may run into compilation errors due to " "excessive use of shared memory. If this happens, disable Triton FA "