# SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from abc import abstractmethod from collections.abc import Iterable, Mapping from typing import Annotated, Final, Literal, Protocol, TypeAlias, TypeVar import torch import torch.nn as nn from transformers import BatchFeature, LlavaNextConfig, LlavaNextProcessor from transformers.models.llava_next.modeling_llava_next import ( get_anyres_image_grid_shape, unpad_image, ) from vllm.config import VllmConfig from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import MultiModalFieldConfig from vllm.multimodal.parse import ImageSize from vllm.sequence import IntermediateTensors from vllm.utils.tensor_schema import TensorSchema, TensorShape from .clip import CLIPVisionModel from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .llava import ( BaseLlavaMultiModalProcessor, BaseLlavaProcessingInfo, LlavaDummyInputsBuilder, LlavaLikeConfig, LlavaMultiModalProjector, init_vision_tower_for_llava, ) from .siglip import SiglipVisionModel from .utils import ( AutoWeightsLoader, WeightsMapper, init_vllm_registered_model, maybe_prefix, ) from .vision import get_num_selected_vision_tokens class LlavaNextImagePixelInputs(TensorSchema): """ Dimensions: - bn: Batch size * number of images - np: Number of patches + 1 - c: Number of channels (3) - h: Height - w: Width Note that `num_patches` may be different per batch and image, in which case the data is passed as a list instead of a batched tensor. """ type: Literal["pixel_values"] = "pixel_values" pixel_values: Annotated[ torch.Tensor | list[torch.Tensor], TensorShape("bn", "np", 3, "h", "w", dynamic_dims={"np"}), ] image_sizes: Annotated[torch.Tensor | None, TensorShape("bn", 2)] # This should be in `(height, width)` format. 
class LlavaNextImageEmbeddingInputs(TensorSchema):
    """
    Precomputed image embeddings for LLaVA-NeXT.

    Dimensions:
        - bn: Batch size * number of images
        - ifs: Image feature size
        - hs: Hidden size (must match language model backbone)
    """

    type: Literal["image_embeds"] = "image_embeds"
    data: Annotated[torch.Tensor, TensorShape("bn", "ifs", "hs")]


LlavaNextImageInputs: TypeAlias = (
    LlavaNextImagePixelInputs | LlavaNextImageEmbeddingInputs
)


class LlavaNextLikeConfig(LlavaLikeConfig, Protocol):
    """Structural type for LLaVA-like configs that also carry anyres pinpoints."""

    image_grid_pinpoints: Final[list[list[int]]]


class LlavaNextProcessingInfo(BaseLlavaProcessingInfo):
    """Access to the HF config/processor and image-token accounting for LLaVA-NeXT."""

    def get_hf_config(self) -> LlavaNextLikeConfig:
        return self.ctx.get_hf_config(LlavaNextConfig)

    def get_hf_processor(self, **kwargs: object):
        """Return the HF `LlavaNextProcessor`.

        If the processor config does not set `patch_size`, it is filled in
        from the vision encoder.
        """
        hf_processor = self.ctx.get_hf_processor(LlavaNextProcessor, **kwargs)

        # In case patch_size is omitted from `processor_config.json`
        # e.g. for E5-V: https://huggingface.co/royokong/e5-v
        if hf_processor.patch_size is None:
            patch_size = self.get_vision_encoder_info().get_patch_size()
            hf_processor.patch_size = patch_size

        return hf_processor

    # Based on: https://github.com/huggingface/text-generation-inference/blob/v3.0.1/server/text_generation_server/models/vlm_causal_lm.py#L113
    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Return how many placeholder tokens an image of this size expands to.

        The total is the sum of:
        - the selected vision tokens of the base tile,
        - the unpadded anyres-grid features, and
        - one newline token per unpadded grid row.
        """
        hf_config = self.get_hf_config()
        vision_encoder_info = self.get_vision_encoder_info()

        base_feature_size = get_num_selected_vision_tokens(
            vision_encoder_info.get_num_image_tokens(
                image_width=image_width,
                image_height=image_height,
            ),
            hf_config.vision_feature_select_strategy,
        )

        num_patch_height, num_patch_width = get_anyres_image_grid_shape(
            image_size=(image_height, image_width),
            grid_pinpoints=hf_config.image_grid_pinpoints,
            patch_size=vision_encoder_info.get_image_size(),
        )

        (
            unpadded_feature_size,
            newline_feature_size,
        ) = self._get_num_unpadded_features(
            original_height=image_height,
            original_width=image_width,
            npatches=vision_encoder_info.get_patch_grid_length(),
            num_patch_height=num_patch_height,
            num_patch_width=num_patch_width,
        )

        return unpadded_feature_size + newline_feature_size + base_feature_size

    # Based on: https://github.com/huggingface/text-generation-inference/blob/v3.0.1/server/text_generation_server/models/vlm_causal_lm.py#L86
    def _get_num_unpadded_features(
        self,
        *,
        original_height: int,
        original_width: int,
        npatches: int,
        num_patch_height: int,
        num_patch_width: int,
    ) -> tuple[int, int]:
        """Count the grid features left after removing aspect-ratio padding.

        Returns:
            A tuple `(unpadded_features, newline_features)`. The first is
            `height * width` of the unpadded feature grid. The second is its
            height, i.e. one newline token per row.
        """
        current_height = npatches * num_patch_height
        current_width = npatches * num_patch_width

        aspect_ratio = original_width / original_height
        current_aspect_ratio = current_width / current_height

        if aspect_ratio > current_aspect_ratio:
            # Image is relatively wider than the grid: padding sits on the
            # top/bottom, so trim the height.
            new_height = int(
                round(original_height * (current_width / original_width), 7)
            )
            padding = (current_height - new_height) // 2
            current_height = current_height - (2 * padding)
        else:
            # Otherwise padding sits on the left/right: trim the width.
            new_width = int(
                round(original_width * (current_height / original_height), 7)
            )
            padding = (current_width - new_width) // 2
            current_width = current_width - (2 * padding)

        unpadded_features = current_height * current_width
        newline_features = current_height

        return (unpadded_features, newline_features)

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return the grid pinpoint that yields the most image tokens.

        Raises:
            ValueError: If no pinpoint yields a positive token count.
        """
        hf_config = self.get_hf_config()
        largest_feature_size, largest_feature_pinpoint = 0, None
        for height, width in hf_config.image_grid_pinpoints:
            feat_size = self.get_num_image_tokens(
                image_width=width, image_height=height
            )
            if feat_size > largest_feature_size:
                largest_feature_size = feat_size
                largest_feature_pinpoint = ImageSize(width=width, height=height)

        if largest_feature_size == 0 or largest_feature_pinpoint is None:
            raise ValueError("Cannot have a largest feature size of 0!")

        return largest_feature_pinpoint


_I = TypeVar("_I", bound=LlavaNextProcessingInfo)


class BaseLlavaNextMultiModalProcessor(BaseLlavaMultiModalProcessor[_I]):
    """Abstract LLaVA-NeXT processor; subclasses declare their multimodal fields."""

    # Copied from BaseMultiModalProcessor
    @abstractmethod
    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        raise NotImplementedError


class LlavaNextMultiModalProcessor(
    BaseLlavaNextMultiModalProcessor[LlavaNextProcessingInfo]
):
    """Processor whose image fields are all batched per image."""

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        return dict(
            pixel_values=MultiModalFieldConfig.batched("image"),
            image_sizes=MultiModalFieldConfig.batched("image"),
            image_embeds=MultiModalFieldConfig.batched("image"),
        )


@MULTIMODAL_REGISTRY.register_processor(
    LlavaNextMultiModalProcessor,
    info=LlavaNextProcessingInfo,
    dummy_inputs=LlavaDummyInputsBuilder,
)
class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
    """LLaVA-NeXT model.

    Components: a CLIP/SigLIP vision tower, a multimodal projector, a learned
    image-newline embedding, and a registered vLLM language model.
    """

    merge_by_field_config = True

    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            # mapping for new names in checkpoint saved after transformers v4.52
            "model.language_model.": "language_model.model.",
            "model.vision_tower.": "vision_tower.",
            "model.multi_modal_projector.": "multi_modal_projector.",
            "model.image_newline": "image_newline",
            "lm_head.": "language_model.lm_head.",
        }
    )

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt placeholder text for a multimodal item.

        Raises:
            ValueError: For any non-image modality.
        """
        if modality.startswith("image"):
            # Must be the textual image token so the processor can locate and
            # expand it; an empty string would insert no placeholder at all.
            return "<image>"

        raise ValueError("Only image modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__()

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        vision_feature_layer = config.vision_feature_layer
        # Determine the layer up to which we will initialize the vision tower
        if isinstance(vision_feature_layer, int):
            vision_hidden_size = config.vision_config.hidden_size
            self.select_layers = None
        # Used for multimodal granite models to control encoder outputs
        elif isinstance(vision_feature_layer, (list, tuple)):
            # Features from several layers are concatenated, so the projector
            # input width scales with the number of selected layers.
            vision_hidden_size = config.vision_config.hidden_size * len(
                vision_feature_layer
            )
            self.select_layers = vision_feature_layer
        else:
            raise TypeError(
                f"vision_feature_layer type: {type(vision_feature_layer)}"
                " is not supported"
            )

        self.config = config
        self.multimodal_config = multimodal_config

        # TODO: Optionally initializes this for supporting embeddings.
        self.vision_tower = init_vision_tower_for_llava(
            config,
            quant_config,
            require_post_norm=False,
            prefix=maybe_prefix(prefix, "vision_tower"),
        )
        self.image_newline = nn.Parameter(torch.empty(config.text_config.hidden_size))
        self.multi_modal_projector = LlavaMultiModalProjector(
            vision_hidden_size=vision_hidden_size,
            text_hidden_size=config.text_config.hidden_size,
            projector_hidden_act=config.projector_hidden_act,
            multimodal_projector_bias=config.multimodal_projector_bias,
        )

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> LlavaNextImageInputs | None:
        """Wrap image kwargs in a validated schema.

        Pixel values take precedence over image embeddings.

        Returns:
            A validated image-input schema, or None if neither `pixel_values`
            nor `image_embeds` was supplied.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_sizes = kwargs.pop("image_sizes", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            expected_h = expected_w = self.config.vision_config.image_size
            return LlavaNextImagePixelInputs(
                type="pixel_values",
                pixel_values=pixel_values,
                image_sizes=image_sizes,
                resolve_bindings={
                    "h": expected_h,
                    "w": expected_w,
                },
            )

        if image_embeds is not None:
            return LlavaNextImageEmbeddingInputs(
                type="image_embeds",
                data=image_embeds,
            )

        raise AssertionError("This line should be unreachable.")

    def _image_pixels_to_features(
        self,
        vision_tower: CLIPVisionModel | SiglipVisionModel,
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:
        """Run the vision tower with this model's layer/feature selection."""
        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        return vision_tower(
            pixel_values,
            select_layers=self.select_layers,
            feature_select_strategy=self.config.vision_feature_select_strategy,
        )

    # Based on: https://github.com/haotian-liu/LLaVA/blob/main/llava/model/llava_arch.py
    def _merge_image_patch_embeddings(
        self, image_size: torch.Tensor, patch_embeddings: torch.Tensor, *, strategy: str
    ) -> torch.Tensor:
        """Merge one image's per-tile embeddings into a single token sequence.

        `patch_embeddings[0]` is the base tile. Any remaining entries are the
        anyres grid tiles, possibly with trailing batch padding.

        With a `"spatial*unpad"` strategy, the grid tiles are:
        1. reassembled spatially,
        2. stripped of aspect-ratio padding via `unpad_image`, and
        3. suffixed with `image_newline` at the end of every row.

        Raises:
            ValueError: If the base tile does not have `(image_size //
                patch_size) ** 2` patches, or if the strategy is unknown.
        """
        if strategy == "flat":
            return patch_embeddings.flatten(0, 1)

        if strategy.startswith("spatial"):
            height = width = (
                self.config.vision_config.image_size
                // self.config.vision_config.patch_size
            )

            base_patch_embeds = patch_embeddings[0]
            if height * width != base_patch_embeds.shape[0]:
                raise ValueError(
                    "The number of patches is not consistent with the image size."
                )

            if patch_embeddings.shape[0] > 1:
                other_patch_embeds = patch_embeddings[1:]

                # Convert to Python numbers on the host to avoid
                # floating-point errors in the grid-shape computation
                orig_height, orig_width = image_size.tolist()

                # image_aspect_ratio == "anyres"
                num_patch_height, num_patch_width = get_anyres_image_grid_shape(
                    (orig_height, orig_width),
                    self.config.image_grid_pinpoints,
                    self.config.vision_config.image_size,
                )
                num_patches = num_patch_height * num_patch_width

                # Image patches might be padded for batch processing
                other_patch_embeds = other_patch_embeds[:num_patches].view(
                    num_patch_height, num_patch_width, height, width, -1
                )

                if "unpad" in strategy:
                    # (grid_h, grid_w, h, w, dim) -> (dim, grid_h*h, grid_w*w)
                    other_patch_embeds = (
                        other_patch_embeds.permute(4, 0, 2, 1, 3)
                        .contiguous()
                        .flatten(1, 2)
                        .flatten(2, 3)
                    )
                    other_patch_embeds = unpad_image(
                        other_patch_embeds, (orig_height, orig_width)
                    )
                    # Append one image_newline column after every row.
                    other_patch_embeds = torch.cat(
                        (
                            other_patch_embeds,
                            self.image_newline[:, None, None]
                            .expand(*other_patch_embeds.shape[:-1], 1)
                            .to(other_patch_embeds.device),
                        ),
                        dim=-1,
                    )
                    other_patch_embeds = other_patch_embeds.flatten(1, 2).transpose(
                        0, 1
                    )
                else:
                    other_patch_embeds = (
                        other_patch_embeds.permute(0, 2, 1, 3, 4)
                        .contiguous()
                        .flatten(0, 3)
                    )

                merged_patch_embeddings = torch.cat(
                    (base_patch_embeds, other_patch_embeds), dim=0
                )
            else:
                if "unpad" in strategy:
                    merged_patch_embeddings = torch.cat(
                        (
                            base_patch_embeds,
                            self.image_newline[None].to(base_patch_embeds.device),
                        ),
                        dim=0,
                    )
                else:
                    merged_patch_embeddings = base_patch_embeds

            return merged_patch_embeddings

        raise ValueError(f"Unexpected patch merge strategy: {strategy}")

    def _process_image_pixels(
        self,
        inputs: LlavaNextImagePixelInputs,
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        """Encode and project every tile of every image.

        Returns:
            For a batched tensor input, a tensor shaped
            `(b, num_patches, ...)`. For a list input (variable tile counts
            per image), a tuple with one tensor per image.
        """
        assert self.vision_tower is not None

        pixel_values = inputs["pixel_values"]

        if isinstance(pixel_values, torch.Tensor):
            b, num_patches, c, h, w = pixel_values.shape
            stacked_pixel_values = pixel_values.view(b * num_patches, c, h, w)
            stacked_image_features = self._image_pixels_to_features(
                self.vision_tower, stacked_pixel_values
            )
            stacked_patch_embeddings = self.multi_modal_projector(
                stacked_image_features
            )

            return stacked_patch_embeddings.view(
                b, num_patches, *stacked_patch_embeddings.shape[1:]
            )

        # Ragged input: run all tiles through the encoder in one pass, then
        # split back per image.
        num_patches_per_batch = [v.shape[0] for v in pixel_values]
        stacked_pixel_values = torch.cat(pixel_values)
        stacked_image_features = self._image_pixels_to_features(
            self.vision_tower, stacked_pixel_values
        )

        return torch.split(
            self.multi_modal_projector(stacked_image_features), num_patches_per_batch
        )

    def _process_image_input(
        self,
        image_input: LlavaNextImageInputs,
    ) -> torch.Tensor | list[torch.Tensor]:
        """Turn a validated image input into per-image embeddings.

        Precomputed embeddings are returned unchanged. Pixel inputs are encoded
        and merged with the `"spatial_unpad"` strategy. When `image_sizes` is
        missing, each image is assumed to be `vision_config.image_size`
        square.
        """
        if image_input["type"] == "image_embeds":
            return image_input["data"]

        patch_embeddings = self._process_image_pixels(image_input)

        image_sizes = image_input.get("image_sizes")
        if image_sizes is None:
            # Pixel inputs store their data under "pixel_values"; the schema
            # has no "data" field, so indexing "data" here would always fail.
            batch_size = len(image_input["pixel_values"])
            vision_config = self.config.vision_config
            default_height = default_width = vision_config.image_size
            image_sizes = torch.as_tensor(
                [[default_height, default_width] for _ in range(batch_size)]
            )

        return [
            self._merge_image_patch_embeddings(
                image_sizes[i], patch_features_batch, strategy="spatial_unpad"
            )
            for i, patch_features_batch in enumerate(patch_embeddings)
        ]

    def get_language_model(self) -> torch.nn.Module:
        return self.language_model

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        """Return image embeddings, or an empty list when no image input is given."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []
        vision_embeddings = self._process_image_input(image_input)
        return vision_embeddings

    def embed_input_ids(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: MultiModalEmbeddings | None = None,
        *,
        is_multimodal: torch.Tensor | None = None,
        # Multi-modal token ID may exceed vocab size
        handle_oov_mm_token: bool = True,
    ) -> torch.Tensor:
        """Embed `input_ids`, scattering multimodal embeddings where flagged.

        Falls back to text-only embedding if either `multimodal_embeddings`
        or `is_multimodal` is None.
        """
        # This is to satisfy the type checker for each overload
        if multimodal_embeddings is None or is_multimodal is None:
            return super().embed_input_ids(input_ids)

        return super().embed_input_ids(
            input_ids,
            multimodal_embeddings=multimodal_embeddings,
            is_multimodal=is_multimodal,
            handle_oov_mm_token=handle_oov_mm_token,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        """Run forward pass for LLaVA-NeXT.

        One key thing to understand is the `input_ids` already accounts for the
        positions of the to-be-inserted image embeddings.

        Concretely, consider a text prompt:
        `"A chat between a curious human and an artificial intelligence
        assistant. The assistant gives helpful, detailed, and polite answers to
        the human's questions.
        USER: <image>\\nWhat is shown in this image? ASSISTANT:"`.

        Tokenizer outputs:
        `[1, 319, 13563, 1546, 263, 12758, 5199, 322, 385, 23116, 21082, 20255,
        29889, 450, 20255, 4076, 8444, 29892, 13173, 29892, 322, 1248, 568,
        6089, 304, 278, 5199, 29915, 29879, 5155, 29889, 3148, 1001, 29901,
        29871, 32000, 13, 5618, 338, 4318, 297, 445, 1967, 29973, 319, 1799,
        9047, 13566, 29901]`.

        To reserve space in KV cache, we have to insert placeholder tokens
        before they are inputted to the model, so the input processor prepends
        additional image tokens (denoted as `32000`), resulting in:
        `[1, 319, 13563, 1546, 263, 12758, 5199, 322, 385, 23116, 21082, 20255,
        29889, 450, 20255, 4076, 8444, 29892, 13173, 29892, 322, 1248, 568,
        6089, 304, 278, 5199, 29915, 29879, 5155, 29889, 3148, 1001, 29901,
        29871, 32000, ..., 32000, 13, 5618, 338, 4318, 297, 445, 1967, 29973,
        319, 1799, 9047, 13566, 29901]`.

        Unlike in LLaVA-1.5, the number of image tokens inputted to the language
        model depends on the original size of the input image. Including the
        original image token in the input, the required number of image tokens
        is given by [`LlavaNextProcessingInfo.get_num_image_tokens`][vllm.\
model_executor.models.llava_next.LlavaNextProcessingInfo.get_num_image_tokens].

        This way, the `positions` and `attn_metadata` are consistent
        with the `input_ids`.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Position indices for the input tokens.
            intermediate_tensors: Intermediate tensors from prior forward pass.
            inputs_embeds: Optional tensor of input embeddings.

        Info:
            [`LlavaNextImageInputs`][vllm.model_executor.models.llava_next.LlavaNextImageInputs]
        """
        if intermediate_tensors is not None:
            # Non-first pipeline stages consume the hidden states carried in
            # intermediate_tensors, not embeddings.
            inputs_embeds = None

        hidden_states = self.language_model.model(
            input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, remapping HF prefixes via `hf_to_vllm_mapper`."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)