# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable, Mapping, Sequence
from typing import Annotated, Literal, TypeAlias

import torch
from torch import nn
from transformers import BatchFeature, PaliGemmaConfig

from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.logger import init_logger
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalInputs,
    MultiModalKwargsItems,
    MultiModalUUIDDict,
)
from vllm.multimodal.parse import (
    ImageEmbeddingItems,
    ImageProcessorItems,
    MultiModalDataItems,
)
from vllm.multimodal.processing import (
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptIndexTargets,
    PromptInsertion,
    PromptUpdate,
    PromptUpdateDetails,
)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape

from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .siglip import SiglipVisionModel
from .utils import (
    AutoWeightsLoader,
    WeightsMapper,
    init_vllm_registered_model,
    maybe_prefix,
)
from .vision import get_vision_encoder_info

logger = init_logger(__name__)


class PaliGemmaImagePixelInputs(TensorSchema):
    """
    Dimensions:
        - bn: Batch size * number of images
        - c: Number of channels (3)
        - h: Height
        - w: Width
    """

    type: Literal["pixel_values"] = "pixel_values"
    data: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]


class PaliGemmaImageEmbeddingInputs(TensorSchema):
    """
    Dimensions:
        - bn: Batch size * number of images
        - ifs: Image feature size
        - hs: Hidden size (must match language model backbone)
    """

    type: Literal["image_embeds"] = "image_embeds"
    data: Annotated[torch.Tensor, TensorShape("bn", "ifs", "hs")]


PaliGemmaImageInputs: TypeAlias = (
    PaliGemmaImagePixelInputs | PaliGemmaImageEmbeddingInputs
)


class PaliGemmaMultiModalProjector(nn.Module):
    """Project vision-tower features with a single biased linear layer."""

    def __init__(self, vision_hidden_size: int, projection_dim: int):
        """Create the ``vision_hidden_size -> projection_dim`` linear layer."""
        super().__init__()
        self.linear = nn.Linear(vision_hidden_size, projection_dim, bias=True)

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        """Return ``image_features`` passed through the linear projection."""
        hidden_states = self.linear(image_features)
        return hidden_states


class PaliGemmaProcessingInfo(BaseProcessingInfo):
    """Configuration lookups used while processing PaliGemma inputs."""

    def get_hf_config(self):
        """Return the HF config, requested from the context as PaliGemmaConfig."""
        return self.ctx.get_hf_config(PaliGemmaConfig)

    def get_vision_encoder_info(self):
        """Return the vision-encoder info object built from the HF config."""
        return get_vision_encoder_info(self.get_hf_config())

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        """Return the per-modality item limits: one ``"image"`` item."""
        return {"image": 1}

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Return the image token count reported by the vision encoder info.

        Args:
            image_width: Width forwarded to the vision encoder info.
            image_height: Height forwarded to the vision encoder info.
        """
        vision_encoder_info = self.get_vision_encoder_info()

        return vision_encoder_info.get_num_image_tokens(
            image_width=image_width,
            image_height=image_height,
        )


class PaliGemmaDummyInputsBuilder(BaseDummyInputsBuilder[PaliGemmaProcessingInfo]):
    """Build the dummy text and image inputs for PaliGemma."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return an empty string; ``mm_counts`` is not consulted."""
        return ""

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        """Return dummy square images sized by ``vision_config.image_size``.

        Args:
            seq_len: Unused by this implementation.
            mm_counts: Number of items per modality; ``"image"`` defaults to 0.
            mm_options: Optional per-modality overrides; only the ``"image"``
                entry is read and forwarded to ``_get_dummy_images``.
        """
        hf_config = self.info.get_hf_config()
        vision_config = hf_config.vision_config
        max_image_size = vision_config.image_size

        num_images = mm_counts.get("image", 0)

        image_overrides = mm_options.get("image") if mm_options else None

        return {
            "image": self._get_dummy_images(
                width=max_image_size,
                height=max_image_size,
                num_images=num_images,
                overrides=image_overrides,
            )
        }


class PaliGemmaMultiModalProcessor(BaseMultiModalProcessor[PaliGemmaProcessingInfo]):
    """Multi-modal processor that lays out prompts in PaliGemma's format."""

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor, or only the tokenizer when there is no data.

        With empty ``mm_data`` the prompt is tokenized directly, without
        special tokens, and wrapped in a ``BatchFeature``. Otherwise the call
        is delegated to the base class.
        """
        tokenizer = self.info.get_tokenizer()
        if not mm_data:
            prompt_ids = tokenizer.encode(prompt, add_special_tokens=False)
            return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")

        return super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Declare which processed fields are batched per image.

        ``image_embeds`` is declared alongside ``pixel_values`` because both
        ``_get_prompt_updates`` (``ImageEmbeddingItems``) and the model's
        ``_parse_and_validate_image_input`` accept pre-computed embeddings;
        the field config must agree with them.
        """
        return dict(
            pixel_values=MultiModalFieldConfig.batched("image"),
            image_embeds=MultiModalFieldConfig.batched("image"),
        )

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Return the prompt insertion that places image tokens at the prefix.

        Raises:
            AssertionError: If the tokenizer's ``bos_token_id`` is not an int.
        """
        hf_config = self.info.get_hf_config()
        image_token_id = hf_config.image_token_index

        tokenizer = self.info.get_tokenizer()

        bos_token_id = tokenizer.bos_token_id
        assert isinstance(bos_token_id, int)

        def get_insertion(item_idx: int):
            """Build ``<image>*n + <bos>`` for the image at ``item_idx``."""
            images = mm_items.get_items(
                "image", (ImageEmbeddingItems, ImageProcessorItems)
            )

            if isinstance(images, ImageEmbeddingItems):
                # Pre-computed embeddings carry their own feature size.
                num_image_tokens = images.get_feature_size(item_idx)
            else:
                image_size = images.get_image_size(item_idx)
                num_image_tokens = self.info.get_num_image_tokens(
                    image_width=image_size.width,
                    image_height=image_size.height,
                )

            image_tokens = [image_token_id] * num_image_tokens

            # Only the image-token positions are selected as embedding slots;
            # the trailing <bos> is ordinary text.
            return PromptUpdateDetails.select_token_id(
                image_tokens + [bos_token_id],
                embed_token_id=image_token_id,
            )

        # Paligemma 1 and 2 have different tokenizer.add_bos_token
        # Insert <image>*n + <bos> after <bos> for Paligemma 1
        # Insert <image>*n + <bos> for Paligemma 2
        return [
            PromptInsertion(
                modality="image",
                target=PromptIndexTargets.prefix(
                    [bos_token_id] if tokenizer.add_bos_token else []
                ),
                insertion=get_insertion,
            )
        ]

    def apply(
        self,
        prompt: str | list[int],
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object] | None = None,
        mm_uuids: MultiModalUUIDDict | None = None,
    ) -> MultiModalInputs:
        """Run the base ``apply`` and make sure the prompt ends with a newline.

        The newline token id is taken as the last id produced by encoding
        ``"\\n"``. It is appended only when the prompt is non-empty and does
        not already end with that id.
        """
        mm_inputs = super().apply(
            prompt,
            mm_data,
            hf_processor_mm_kwargs,
            tokenization_kwargs,
            mm_uuids=mm_uuids,
        )
        prompt_token_ids = mm_inputs["prompt_token_ids"]

        tokenizer = self.info.get_tokenizer()
        newline_prompt = "\n"
        newline_token_id = tokenizer.encode(newline_prompt)[-1]  # 108
        # Force to add newline at the end of prompt for paligemma's format
        # This step can NOT be replaced by current PromptUpdate methods
        if prompt_token_ids and prompt_token_ids[-1] != newline_token_id:
            prompt_token_ids.append(newline_token_id)
            mm_inputs["prompt_token_ids"] = prompt_token_ids

        return mm_inputs


@MULTIMODAL_REGISTRY.register_processor(
    PaliGemmaMultiModalProcessor,
    info=PaliGemmaProcessingInfo,
    dummy_inputs=PaliGemmaDummyInputsBuilder,
)
class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
    """PaliGemma: SigLIP vision tower + linear projector + Gemma language model."""

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            # mapping for new names in checkpoint saved after transformers v4.52
            "model.language_model.": "language_model.model.",
            "model.vision_tower.": "vision_tower.",
            "model.multi_modal_projector.": "multi_modal_projector.",
            "lm_head.": "language_model.lm_head.",
        }
    )

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the placeholder string for ``modality`` (``None`` for images).

        Raises:
            ValueError: If ``modality`` does not start with ``"image"``.
        """
        if modality.startswith("image"):
            return None

        raise ValueError("Only image modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the vision tower, projector and language model.

        Args:
            vllm_config: Supplies the HF config, quantization config and
                multi-modal config.
            prefix: Prefix prepended to sub-module names via ``maybe_prefix``.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config
        self.config = config
        self.multimodal_config = multimodal_config

        self.vision_tower = SiglipVisionModel(
            config.vision_config,
            quant_config,
            prefix=maybe_prefix(prefix, "vision_tower"),
        )
        self.multi_modal_projector = PaliGemmaMultiModalProjector(
            vision_hidden_size=config.vision_config.hidden_size,
            projection_dim=config.vision_config.projection_dim,
        )

        self.quant_config = quant_config

        # Pick the language-model architecture from the text config's
        # model_type; anything other than "gemma" is treated as Gemma 2.
        if config.text_config.model_type == "gemma":
            config.text_config.architectures = ["GemmaForCausalLM"]
        else:
            config.text_config.architectures = ["Gemma2ForCausalLM"]
        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )
        # Fold the optional top-level logit_scale (default 1.0) into the
        # language model's logits processor.
        logit_scale = getattr(config, "logit_scale", 1.0)
        self.language_model.logits_processor.scale *= logit_scale

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> PaliGemmaImageInputs | None:
        """Extract image inputs from ``kwargs`` as a validated schema object.

        ``pixel_values`` takes precedence over ``image_embeds`` when both are
        present. Returns ``None`` when neither is supplied.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            # Bind the symbolic "h"/"w" dims to the configured image size.
            h = w = self.config.vision_config.image_size
            return PaliGemmaImagePixelInputs(
                type="pixel_values",
                data=pixel_values,
                resolve_bindings={"h": h, "w": w},
            )

        if image_embeds is not None:
            return PaliGemmaImageEmbeddingInputs(
                type="image_embeds",
                data=image_embeds,
            )

        raise AssertionError("This line should be unreachable.")

    def _image_pixels_to_features(
        self,
        vision_tower: SiglipVisionModel,
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:
        """Run ``vision_tower`` on ``pixel_values`` cast to its embedding dtype."""
        target_dtype = vision_tower.get_input_embeddings().weight.dtype
        image_features = vision_tower(pixel_values.to(dtype=target_dtype))

        return image_features

    def _process_image_input(
        self,
        image_input: PaliGemmaImageInputs,
    ) -> torch.Tensor:
        """Turn a parsed image input into projected embeddings.

        Pre-computed ``image_embeds`` are returned as-is; pixel inputs go
        through the vision tower and then the multi-modal projector.
        """
        if image_input["type"] == "image_embeds":
            return image_input["data"]

        assert self.vision_tower is not None
        pixel_values = image_input["data"]
        image_features = self._image_pixels_to_features(
            self.vision_tower,
            pixel_values,
        )

        return self.multi_modal_projector(image_features)

    def get_language_model(self) -> torch.nn.Module:
        """Return the wrapped language model module."""
        return self.language_model

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        """Return scaled image embeddings, or ``[]`` when no image is given.

        The embeddings are multiplied by ``config.hidden_size ** -0.5``.
        """
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []
        vision_embeddings = self._process_image_input(image_input)
        # https://github.com/huggingface/transformers/blob/main/src/transformers/models/paligemma/modeling_paligemma.py#L294 # noqa
        vision_embeddings = vision_embeddings * (self.config.hidden_size**-0.5)
        return vision_embeddings

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the language model and return whatever it produces.

        When ``intermediate_tensors`` is provided, ``inputs_embeds`` is
        discarded before calling the language model. Extra ``kwargs`` are
        accepted but not used here.
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        hidden_states = self.language_model.model(
            input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
        )

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Delegate logit computation to the language model."""
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load ``weights`` via ``AutoWeightsLoader`` using the HF name mapper.

        Returns:
            The set returned by ``AutoWeightsLoader.load_weights``.
        """
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)