diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md
index fc363585b0e7..ca2c4d35d771 100644
--- a/docs/source/models/supported_models.md
+++ b/docs/source/models/supported_models.md
@@ -842,13 +842,13 @@ See [this page](#generative-models) for more information on how to use generativ
*
* ✅︎
* ✅︎
-- * `PaliGemmaForConditionalGeneration`\*
- * PaliGemma, PaliGemma 2
+- * `PaliGemmaForConditionalGeneration`
+ * PaliGemma (see note), PaliGemma 2 (see note)
* T + IE
* `google/paligemma-3b-pt-224`, `google/paligemma-3b-mix-224`, `google/paligemma2-3b-ft-docci-448`, etc.
*
* ✅︎
- *
+ * ✅︎
- * `Phi3VForCausalLM`
* Phi-3-Vision, Phi-3.5-Vision
* T + IE+
diff --git a/tests/models/decoder_only/vision_language/test_models.py b/tests/models/decoder_only/vision_language/test_models.py
index 3f7a7c01aebc..2540933bbc23 100644
--- a/tests/models/decoder_only/vision_language/test_models.py
+++ b/tests/models/decoder_only/vision_language/test_models.py
@@ -116,9 +116,8 @@ VLM_TEST_SETTINGS = {
"pixel_values"
),
vllm_output_post_proc=model_utils.paligemma_vllm_to_hf_output,
- dtype=("half" if current_platform.is_cpu() or current_platform.is_rocm()
- else ("half", "float")),
- marks=[pytest.mark.core_model],
+ dtype="bfloat16",
+ marks=[pytest.mark.skip(reason="vLLM does not support PrefixLM attention mask")], # noqa: E501
),
# TODO(ywang96): Move Qwen2-VL out of core models in favor of Qwen2.5-VL
# once we upgraded to transformers>=4.49.0.
diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py
index 7534f0c97798..629d1012d18e 100644
--- a/tests/models/multimodal/processing/test_common.py
+++ b/tests/models/multimodal/processing/test_common.py
@@ -175,6 +175,8 @@ def _test_processing_correctness(
"Qwen/Qwen2-Audio-7B-Instruct",
"fixie-ai/ultravox-v0_4",
"openai/whisper-large-v3",
+ "google/paligemma-3b-mix-224",
+ "google/paligemma2-3b-ft-docci-448",
])
@pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0])
@pytest.mark.parametrize("num_batches", [32])
diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py
index 0e39389eb633..f3dc87854cba 100644
--- a/vllm/model_executor/models/paligemma.py
+++ b/vllm/model_executor/models/paligemma.py
@@ -5,22 +5,26 @@ from typing import (Iterable, Literal, Mapping, Optional, Set, Tuple,
import torch
from torch import nn
-from transformers import PaliGemmaConfig
+from transformers import BatchFeature, PaliGemmaConfig
from vllm.config import VllmConfig
-from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
- InputContext, token_inputs)
from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import NestedTensors
+from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
+ MultiModalInputs, MultiModalKwargs,
+ NestedTensors)
+from vllm.multimodal.parse import MultiModalDataItems
+from vllm.multimodal.processing import (BaseMultiModalProcessor,
+ BaseProcessingInfo, PromptIndexTargets,
+ PromptInsertion, PromptReplacement,
+ PromptUpdateDetails)
+from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
-from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
-from .interfaces import SupportsMultiModal, SupportsPP, SupportsV0Only
-from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
- dummy_seq_data_for_siglip, get_max_siglip_image_tokens)
+from .interfaces import SupportsMultiModal, SupportsPP
+from .siglip import SiglipVisionModel, get_max_siglip_image_tokens
from .utils import (AutoWeightsLoader, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)
@@ -46,79 +50,6 @@ PaliGemmaImageInputs = Union[PaliGemmaImagePixelInputs,
PaliGemmaImageEmbeddingInputs]
-def get_max_paligemma_image_tokens(ctx: InputContext):
- hf_config = ctx.get_hf_config(PaliGemmaConfig)
- vision_config = hf_config.vision_config
-
- return get_max_siglip_image_tokens(vision_config)
-
-
-def dummy_data_for_paligemma(ctx: InputContext, seq_len: int,
- mm_counts: Mapping[str, int]):
- hf_config = ctx.get_hf_config(PaliGemmaConfig)
- vision_config = hf_config.vision_config
- num_images = mm_counts["image"]
-
- seq_data, ranges = dummy_seq_data_for_siglip(
- vision_config,
- seq_len,
- num_images,
- image_token_id=hf_config.image_token_index,
- )
-
- mm_data = dummy_image_for_siglip(vision_config, num_images)
- return DummyData(seq_data, mm_data, ranges)
-
-
-def input_processor_for_paligemma(ctx: InputContext,
- inputs: DecoderOnlyInputs):
-
- """
- The correct prompt format needs to be:
- '' * image_feature_size + '' + prompt + '\n'
-
- See https://github.com/huggingface/transformers/blob/25245ec26dc29bcf6102e1b4ddd0dfd02e720cf5/src/transformers/models/paligemma/processing_paligemma.py#L55
- """ # noqa
-
- multi_modal_data = inputs.get("multi_modal_data")
- if multi_modal_data is None or "image" not in multi_modal_data:
- return inputs
-
- model_config = ctx.model_config
- hf_config = ctx.get_hf_config(PaliGemmaConfig)
-
- tokenizer = cached_tokenizer_from_config(model_config)
- image_feature_size = hf_config.text_config.num_image_tokens
- image_token_str = tokenizer.decode(hf_config.image_token_index)
- bos_token = tokenizer.decode(hf_config.bos_token_id)
- image_token_str_pad = image_token_str * image_feature_size
- image_token_ids_pad = [hf_config.image_token_index] * image_feature_size
-
- orig_prompt = inputs.get("prompt")
- orig_prompt_ids = inputs.get("prompt_token_ids")
-
- if orig_prompt is not None and image_token_str in orig_prompt:
- logger.warning(
- "The image token '%s' was detected in the prompt and "
- "will be removed. Please follow the proper prompt format"
- " documented on HuggingFace.", image_token_str)
- orig_prompt = orig_prompt.replace(image_token_str, "")
- orig_prompt_ids.remove(hf_config.image_token_index)
-
- new_prompt = f"{image_token_str_pad}{bos_token}{orig_prompt}\n"
-
- # The PaliGemma 2 tokenizer does not include a starting BOS token
- if orig_prompt_ids[0] != hf_config.bos_token_id:
- orig_prompt_ids = [hf_config.bos_token_id] + orig_prompt_ids
-
- new_token_ids = image_token_ids_pad + orig_prompt_ids + [108] #newline
-
- # NOTE: Create a defensive copy of the original inputs
- return token_inputs(prompt_token_ids=new_token_ids,
- prompt=new_prompt,
- multi_modal_data=multi_modal_data)
-
-
class PaliGemmaMultiModalProjector(nn.Module):
def __init__(self, vision_hidden_size: int, projection_dim: int):
@@ -131,12 +62,140 @@ class PaliGemmaMultiModalProjector(nn.Module):
return hidden_states
-@MULTIMODAL_REGISTRY.register_image_input_mapper()
-@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_paligemma_image_tokens)
-@INPUT_REGISTRY.register_dummy_data(dummy_data_for_paligemma)
-@INPUT_REGISTRY.register_input_processor(input_processor_for_paligemma)
+class PaliGemmaProcessingInfo(BaseProcessingInfo):
+
+ def get_hf_config(self):
+ return self.ctx.get_hf_config(PaliGemmaConfig)
+
+ def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
+ return {"image": 1}
+
+ def get_mm_max_tokens_per_item(
+ self,
+ seq_len: int,
+ mm_counts: Mapping[str, int],
+ ) -> Mapping[str, int]:
+ return {"image": self.get_num_image_tokens()}
+
+ def get_num_image_tokens(self) -> int:
+ hf_config = self.get_hf_config()
+ vision_config = hf_config.vision_config
+ return get_max_siglip_image_tokens(vision_config)
+
+
+class PaliGemmaDummyInputsBuilder(
+ BaseDummyInputsBuilder[PaliGemmaProcessingInfo]):
+
+ def get_dummy_processor_inputs(
+ self,
+ seq_len: int,
+ mm_counts: Mapping[str, int],
+ ) -> ProcessorInputs:
+ hf_config = self.info.get_hf_config()
+ vision_config = hf_config.vision_config
+ max_image_size = vision_config.image_size
+
+ num_images = mm_counts.get("image", 0)
+
+ mm_data = {
+ "image":
+ self._get_dummy_images(width=max_image_size,
+ height=max_image_size,
+ num_images=num_images)
+ }
+
+ return ProcessorInputs(
+ prompt_text="",
+ mm_data=mm_data,
+ )
+
+
+class PaliGemmaMultiModalProcessor(
+ BaseMultiModalProcessor[PaliGemmaProcessingInfo]):
+
+ def _call_hf_processor(
+ self,
+ prompt: str,
+ mm_data: Mapping[str, object],
+ mm_kwargs: Mapping[str, object],
+ ) -> BatchFeature:
+ tokenizer = self.info.get_tokenizer()
+ if not mm_data:
+ prompt_ids = tokenizer.encode(prompt)
+ return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")
+
+ return super()._call_hf_processor(
+ prompt=prompt,
+ mm_data=mm_data,
+ mm_kwargs=mm_kwargs,
+ )
+
+ def _get_mm_fields_config(
+ self,
+ hf_inputs: BatchFeature,
+ hf_processor_mm_kwargs: Mapping[str, object],
+ ) -> Mapping[str, MultiModalFieldConfig]:
+ return dict(pixel_values=MultiModalFieldConfig.batched("image"))
+
+ def _get_prompt_updates(
+ self,
+ mm_items: MultiModalDataItems,
+ hf_processor_mm_kwargs: Mapping[str, object],
+ out_mm_kwargs: MultiModalKwargs,
+ ) -> list[PromptReplacement]:
+ hf_config = self.info.get_hf_config()
+ image_token_id = hf_config.image_token_index
+
+ tokenizer = self.info.get_tokenizer()
+ num_image_tokens = self.info.get_num_image_tokens()
+ image_tokens = [image_token_id] * num_image_tokens
+
+ bos_token_id = tokenizer.bos_token_id
+ assert isinstance(bos_token_id, int)
+
+ # Paligemma 1 and 2 have different tokenizer.add_bos_token
+ # Insert *n + after for Paligemma 1
+ # Insert *n + for Paligemma 2
+ return [
+ PromptInsertion(
+ modality="image",
+ target=PromptIndexTargets.prefix(
+ [bos_token_id] if tokenizer.add_bos_token else []),
+ insertion=PromptUpdateDetails(
+ full=image_tokens + [bos_token_id],
+ features=image_tokens,
+ ),
+ )
+ ]
+
+ def apply(
+ self,
+ prompt: Union[str, list[int]],
+ mm_data: MultiModalDataDict,
+ hf_processor_mm_kwargs: Mapping[str, object],
+ ) -> MultiModalInputs:
+ mm_inputs = super().apply(prompt, mm_data, hf_processor_mm_kwargs)
+ prompt_token_ids = mm_inputs["prompt_token_ids"]
+
+ tokenizer = self.info.get_tokenizer()
+ newline_prompt = "\n"
+ newline_token_id = tokenizer.encode(newline_prompt)[-1] # 108
+ # Force to add newline at the end of prompt for paligemma's format
+ # This step can NOT be replacemented by current PromptUpdate methods
+ if len(prompt_token_ids) and prompt_token_ids[-1] != newline_token_id:
+ prompt_token_ids.append(newline_token_id)
+ mm_inputs["prompt_token_ids"] = prompt_token_ids
+ mm_inputs["prompt"] += newline_prompt
+
+ return mm_inputs
+
+
+@MULTIMODAL_REGISTRY.register_processor(
+ PaliGemmaMultiModalProcessor,
+ info=PaliGemmaProcessingInfo,
+ dummy_inputs=PaliGemmaDummyInputsBuilder)
class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
- SupportsPP, SupportsV0Only):
+ SupportsPP):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",