from functools import cached_property
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
                    TypedDict, Union)

import torch
import torch.nn as nn
from transformers import BatchFeature, LlavaNextConfig, LlavaNextProcessor
from transformers.models.llava_next.modeling_llava_next import (
    get_anyres_image_grid_shape, unpad_image)
from typing_extensions import NotRequired

from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalFieldConfig, NestedTensors
from vllm.multimodal.parse import ImageSize
from vllm.sequence import IntermediateTensors

from .clip import CLIPVisionModel
from .interfaces import SupportsMultiModal, SupportsPP
from .llava import (LlavaMultiModalProcessor, LlavaMultiModalProjector,
                    init_vision_tower_for_llava)
from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, embed_multimodal, flatten_bn,
                    init_vllm_registered_model, maybe_prefix)


class LlavaNextImagePixelInputs(TypedDict):
    """Image input given as raw pixel patches (tagged `"pixel_values"`)."""

    type: Literal["pixel_values"]
    data: Union[torch.Tensor, List[torch.Tensor]]
    """
    Shape:
    `(batch_size * num_images, 1 + num_patches, num_channels, height, width)`

    Note that `num_patches` may be different per batch and image,
    in which case the data is passed as a list instead of a batched tensor.
    """

    image_sizes: NotRequired[torch.Tensor]
    """
    Shape: `(batch_size * num_images, 2)`

    This should be in `(height, width)` format.
    """


class LlavaNextImageEmbeddingInputs(TypedDict):
    """Image input given as precomputed embeddings (`"image_embeds"`)."""

    type: Literal["image_embeds"]
    data: torch.Tensor
    """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`

    `hidden_size` must match the hidden size of language model backbone.
    """


LlavaNextImageInputs = Union[LlavaNextImagePixelInputs,
                             LlavaNextImageEmbeddingInputs]


class LlavaNextMultiModalProcessor(LlavaMultiModalProcessor):
    """Multi-modal processor for LLaVA-NeXT.

    Specializes :class:`LlavaMultiModalProcessor` with the LLaVA-NeXT HF
    config/processor types and with the "anyres" image token count, which
    depends on the size of each input image.
    """

    def _get_hf_config(self) -> LlavaNextConfig:
        """Return the HF config, requested as a `LlavaNextConfig`."""
        return self.ctx.get_hf_config(LlavaNextConfig)

    def _get_hf_processor(self) -> LlavaNextProcessor:
        """Return the HF processor, requested as a `LlavaNextProcessor`."""
        return self.ctx.get_hf_processor(LlavaNextProcessor)

    def _get_image_token(self) -> str:
        """Return the image placeholder token string of the HF processor."""
        return self._get_hf_processor().image_token

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Declare every image field as batched over the "image" modality.

        Neither argument is consulted; the mapping is static.
        """
        return dict(
            pixel_values=MultiModalFieldConfig.batched("image"),
            image_sizes=MultiModalFieldConfig.batched("image"),
            image_embeds=MultiModalFieldConfig.batched("image"),
        )

    def _get_max_image_tokens(self) -> int:
        """Return the largest token count over all grid pinpoints."""
        largest_feature_size, _ = self._get_pinpoint_with_most_features()
        return largest_feature_size

    def _get_dummy_image_size(self) -> ImageSize:
        """Return the pinpoint size that yields the most image tokens."""
        _, pinpoint = self._get_pinpoint_with_most_features()
        return pinpoint

    # Based on: https://github.com/huggingface/text-generation-inference/blob/v2.2.0/server/text_generation_server/models/vlm_causal_lm.py#L106
    def _get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Return the number of image tokens for an image of the given size.

        The result is the sum of three parts: the base feature size (after
        applying the configured feature select strategy), the unpadded
        high-resolution grid features, and the newline features.

        Args:
            image_width: Width of the original image.
            image_height: Height of the original image.
        """
        hf_config = self._get_hf_config()

        base_feature_size = self._apply_feature_select_strategy(
            hf_config.vision_feature_select_strategy,
            self._vision_encoder_info.get_num_image_tokens(
                image_width=image_width,
                image_height=image_height,
            ),
        )
        num_patches = self._vision_encoder_info.get_num_patches()

        # NOTE: `get_anyres_image_grid_shape` takes sizes as (height, width).
        num_patch_height, num_patch_width = get_anyres_image_grid_shape(
            image_size=(image_height, image_width),
            grid_pinpoints=hf_config.image_grid_pinpoints,
            patch_size=self._vision_encoder_info.get_image_size(),
        )

        (
            unpadded_feature_size,
            newline_feature_size,
        ) = self._get_num_unpadded_features(
            original_height=image_height,
            original_width=image_width,
            npatches=num_patches,
            num_patch_height=num_patch_height,
            num_patch_width=num_patch_width,
        )

        return unpadded_feature_size + newline_feature_size + base_feature_size

    # Based on: https://github.com/huggingface/text-generation-inference/blob/v2.2.0/server/text_generation_server/models/vlm_causal_lm.py#L79
    def _get_num_unpadded_features(
        self,
        *,
        original_height: int,
        original_width: int,
        npatches: int,
        num_patch_height: int,
        num_patch_width: int,
    ) -> tuple[int, int]:
        """Count grid features left once aspect-ratio padding is removed.

        The feature grid starts at
        `(npatches * num_patch_height, npatches * num_patch_width)`.
        Depending on which aspect ratio is larger, symmetric padding is
        trimmed from either the height or the width.

        Returns:
            `(unpadded_features, newline_features)`: the area of the trimmed
            grid, and its height (one newline feature per row).
        """
        current_height = npatches * num_patch_height
        current_width = npatches * num_patch_width

        original_aspect_ratio = original_width / original_height
        current_aspect_ratio = current_width / current_height

        if original_aspect_ratio > current_aspect_ratio:
            # Image is relatively wider than the grid: trim rows.
            scale_factor = current_width / original_width
            new_height = int(original_height * scale_factor)
            padding = (current_height - new_height) // 2
            current_height -= 2 * padding
        else:
            # Image is relatively taller (or equal): trim columns.
            scale_factor = current_height / original_height
            new_width = int(original_width * scale_factor)
            padding = (current_width - new_width) // 2
            current_width -= 2 * padding

        unpadded_features = current_height * current_width
        newline_features = current_height

        return (unpadded_features, newline_features)

    def _get_pinpoint_with_most_features(self) -> tuple[int, ImageSize]:
        """
        Get the grid pinpoint with the most features and
        the corresponding feature size.

        Raises:
            ValueError: If no pinpoint yields a positive feature size
                (including when the pinpoint list is empty).
        """
        hf_config = self._get_hf_config()

        largest_feature_size, largest_feature_pinpoint = 0, None
        # Pinpoints are unpacked as (height, width) pairs.
        for (height, width) in hf_config.image_grid_pinpoints:
            feat_size = self._get_num_image_tokens(image_width=width,
                                                   image_height=height)
            if feat_size > largest_feature_size:
                largest_feature_size = feat_size
                largest_feature_pinpoint = ImageSize(width=width,
                                                     height=height)

        if largest_feature_size == 0 or largest_feature_pinpoint is None:
            raise ValueError("Cannot have a largest feature size of 0!")

        return largest_feature_size, largest_feature_pinpoint


@MULTIMODAL_REGISTRY.register_processor(LlavaNextMultiModalProcessor)
class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
                                        SupportsPP):
    """LLaVA-NeXT: vision tower + projector feeding a language model."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        """Build the vision tower, projector and language model.

        Args:
            vllm_config: Source of the HF config, quantization config and
                multimodal config.
            prefix: Prefix forwarded (via `maybe_prefix`) to the
                sub-modules.

        Raises:
            TypeError: If `config.vision_feature_layer` is neither an int
                nor a list/tuple.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        vision_feature_layer = config.vision_feature_layer
        # Determine the layer up to which we will initialize the vision tower
        if isinstance(vision_feature_layer, int):
            vision_hidden_size = config.vision_config.hidden_size
            self.feature_sample_layers = None
        # Used for multimodal granite models to control encoder outputs
        elif isinstance(vision_feature_layer, (list, tuple)):
            # The projector input width scales with the number of layers
            # sampled.
            vision_hidden_size = config.vision_config.hidden_size * len(
                vision_feature_layer)
            self.feature_sample_layers = vision_feature_layer
        else:
            raise TypeError(
                f"vision_feature_layer type: {type(vision_feature_layer)}"
                " is not supported")

        self.config = config
        self.multimodal_config = multimodal_config

        # TODO: Optionally initializes this for supporting embeddings.
        self.vision_tower = init_vision_tower_for_llava(
            config,
            quant_config,
            require_post_norm=False,
            prefix=maybe_prefix(prefix, "vision_tower"))
        self.image_newline = nn.Parameter(
            torch.empty(config.text_config.hidden_size))
        self.multi_modal_projector = LlavaMultiModalProjector(
            vision_hidden_size=vision_hidden_size,
            text_hidden_size=config.text_config.hidden_size,
            projector_hidden_act=config.projector_hidden_act)

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    @cached_property
    def sampler(self):
        """Return the language model's sampler, or a default one."""
        if hasattr(self.language_model, "sampler"):
            return self.language_model.sampler

        return get_sampler()

    def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
        """Check that every entry of `data` has shape `(2,)`.

        Returns:
            `data` unchanged.

        Raises:
            ValueError: If any entry has a different shape.
        """
        expected_dims = (2, )

        def _validate_shape(d: torch.Tensor):
            actual_dims = tuple(d.shape)

            if actual_dims != expected_dims:
                expected_expr = str(expected_dims)
                raise ValueError(
                    f"The expected shape of image sizes per image per batch "
                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")

        for d in data:
            _validate_shape(d)

        return data

    def _validate_pixel_values(
        self, data: Union[torch.Tensor, List[torch.Tensor]]
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        """Check that every entry of `data` is `(num_patches, 3, h, w)`.

        `h` and `w` are both `config.vision_config.image_size`; the leading
        `num_patches` dimension is not checked.

        Returns:
            `data` unchanged.

        Raises:
            ValueError: If any entry's trailing dimensions do not match.
        """
        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)

        def _validate_shape(d: torch.Tensor):
            # Skip the leading num_patches dimension.
            actual_dims = tuple(d.shape[1:])

            if actual_dims != expected_dims:
                expected_expr = ("num_patches", *map(str, expected_dims))
                raise ValueError(
                    "The expected shape of pixel values per image per batch "
                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")

        for d in data:
            _validate_shape(d)

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[LlavaNextImageInputs]:
        """Extract and validate the image input from keyword arguments.

        Reads `pixel_values`, `image_sizes` and `image_embeds` from
        `kwargs`. Pixel values take precedence over embeddings when both
        are given.

        Returns:
            `None` if neither pixel values nor embeddings are supplied,
            otherwise the corresponding typed input dict.

        Raises:
            ValueError: If a supplied value has an unexpected type or
                shape. `image_sizes` must be a tensor or list whenever
                `pixel_values` is supplied.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_sizes = kwargs.pop("image_sizes", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            if not isinstance(pixel_values, (torch.Tensor, list)):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")

            if not isinstance(image_sizes, (torch.Tensor, list)):
                raise ValueError("Incorrect type of image sizes. "
                                 f"Got type: {type(image_sizes)}")

            return LlavaNextImagePixelInputs(
                type="pixel_values",
                data=self._validate_pixel_values(flatten_bn(pixel_values)),
                image_sizes=self._validate_image_sizes(
                    flatten_bn(image_sizes, concat=True)),
            )

        if image_embeds is not None:
            if not isinstance(image_embeds, torch.Tensor):
                raise ValueError("Incorrect type of image embeds. "
                                 f"Got type: {type(image_embeds)}")

            return LlavaNextImageEmbeddingInputs(
                type="image_embeds",
                data=flatten_bn(image_embeds),
            )

        raise AssertionError("This line should be unreachable.")

    def _select_image_features(self, image_features: torch.Tensor, *,
                               strategy: str) -> torch.Tensor:
        """Select features according to `strategy`.

        `"default"` drops index 0 along dimension 1; `"full"` keeps
        everything.

        Raises:
            ValueError: For any other strategy.
        """
        # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421  # noqa
        if strategy == "default":
            return image_features[:, 1:]
        elif strategy == "full":
            return image_features

        raise ValueError(f"Unexpected select feature strategy: {strategy}")

    def _image_pixels_to_features(
        self,
        vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:
        """Run the vision tower and apply the feature select strategy."""
        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        image_features = vision_tower(
            pixel_values, feature_sample_layers=self.feature_sample_layers)

        return self._select_image_features(
            image_features,
            strategy=self.config.vision_feature_select_strategy,
        )

    # Based on: https://github.com/haotian-liu/LLaVA/blob/main/llava/model/llava_arch.py
    def _merge_image_patch_embeddings(self, image_size: torch.Tensor,
                                      patch_embeddings: torch.Tensor, *,
                                      strategy: str) -> torch.Tensor:
        """Merge the per-patch embeddings of one image into one sequence.

        Args:
            image_size: Original `(height, width)` of the image.
            patch_embeddings: Embeddings of the base patch (index 0)
                followed by the high-resolution grid patches.
            strategy: `"flat"`, or a name starting with `"spatial"`. If the
                name contains `"unpad"`, padding is removed and
                `image_newline` embeddings are added.

        Raises:
            ValueError: If the strategy is unknown, or if the base patch
                length is inconsistent with the vision config.
        """
        if strategy == "flat":
            return patch_embeddings.flatten(0, 1)

        if strategy.startswith("spatial"):
            height = width = self.config.vision_config.image_size \
                // self.config.vision_config.patch_size

            base_patch_embeds = patch_embeddings[0]
            if height * width != base_patch_embeds.shape[0]:
                raise ValueError(
                    "The number of patches is not consistent with the "
                    "image size.")

            if patch_embeddings.shape[0] > 1:
                other_patch_embeds = patch_embeddings[1:]

                # Move to CPU to avoid floating-point errors
                orig_height, orig_width = image_size.tolist()

                # image_aspect_ratio == "anyres"
                num_patch_height, num_patch_width = get_anyres_image_grid_shape(
                    (orig_height, orig_width),
                    self.config.image_grid_pinpoints,
                    self.config.vision_config.image_size,
                )
                num_patches = num_patch_height * num_patch_width

                # Image patches might be padded for batch processing
                other_patch_embeds = other_patch_embeds[:num_patches] \
                    .view(num_patch_height, num_patch_width, height, width, -1)

                if "unpad" in strategy:
                    # Rearrange to (hidden, grid_height, grid_width) so that
                    # `unpad_image` can crop the spatial dimensions.
                    other_patch_embeds = other_patch_embeds \
                        .permute(4, 0, 2, 1, 3).contiguous() \
                        .flatten(1, 2).flatten(2, 3)
                    other_patch_embeds = unpad_image(other_patch_embeds,
                                                     (orig_height, orig_width))
                    # Append one newline embedding at the end of every row.
                    other_patch_embeds = torch.cat((
                        other_patch_embeds,
                        self.image_newline[:, None, None] \
                            .expand(*other_patch_embeds.shape[:-1], 1) \
                            .to(other_patch_embeds.device),
                    ), dim=-1)
                    other_patch_embeds = other_patch_embeds \
                        .flatten(1, 2).transpose(0, 1)
                else:
                    other_patch_embeds = other_patch_embeds \
                        .permute(0, 2, 1, 3, 4).contiguous() \
                        .flatten(0, 3)

                merged_patch_embeddings = torch.cat(
                    (base_patch_embeds, other_patch_embeds), dim=0)
            else:
                if "unpad" in strategy:
                    # Only the base patch: append a single newline embedding.
                    merged_patch_embeddings = torch.cat(
                        (base_patch_embeds,
                         self.image_newline[None] \
                            .to(base_patch_embeds.device)
                    ), dim=0)
                else:
                    merged_patch_embeddings = base_patch_embeds

            return merged_patch_embeddings

        raise ValueError(f"Unexpected patch merge strategy: {strategy}")

    def _process_image_pixels(
        self,
        inputs: LlavaNextImagePixelInputs,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
        """Encode and project pixel patches into patch embeddings.

        Returns:
            For a batched tensor input, a tensor whose first two dimensions
            are `(b, num_patches)`. For a list input, a tuple with one
            tensor per image, split by each image's patch count.
        """
        assert self.vision_tower is not None

        pixel_values = inputs["data"]

        if isinstance(pixel_values, torch.Tensor):
            # Fold the patch dimension into the batch for the vision tower,
            # then unfold it again afterwards.
            b, num_patches, c, h, w = pixel_values.shape
            stacked_pixel_values = pixel_values.view(b * num_patches, c, h, w)
            stacked_image_features = self._image_pixels_to_features(
                self.vision_tower, stacked_pixel_values)
            stacked_patch_embeddings = self.multi_modal_projector(
                stacked_image_features)

            return stacked_patch_embeddings.view(
                b, num_patches, *stacked_patch_embeddings.shape[1:])

        # Variable patch counts: concatenate, encode once, split per image.
        num_patches_per_batch = [v.shape[0] for v in pixel_values]
        stacked_pixel_values = torch.cat(pixel_values)
        stacked_image_features = self._image_pixels_to_features(
            self.vision_tower, stacked_pixel_values)

        return torch.split(self.multi_modal_projector(stacked_image_features),
                           num_patches_per_batch)

    def _process_image_input(
        self,
        image_input: LlavaNextImageInputs,
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        """Convert a validated image input into per-image embeddings.

        Precomputed embeddings are returned as-is, wrapped in a list.
        Pixel inputs are encoded and then merged per image with the
        `"spatial_unpad"` strategy.
        """
        if image_input["type"] == "image_embeds":
            return [image_input["data"]]

        patch_embeddings = self._process_image_pixels(image_input)

        image_sizes = image_input.get("image_sizes")
        if image_sizes is None:
            # Fall back to the vision encoder's square input size.
            batch_size = len(image_input["data"])
            vision_config = self.config.vision_config
            default_height = default_width = vision_config.image_size
            image_sizes = torch.as_tensor([[default_height, default_width]
                                           for _ in range(batch_size)])

        return [
            self._merge_image_patch_embeddings(image_sizes[i],
                                               patch_features_batch,
                                               strategy="spatial_unpad")
            for i, patch_features_batch in enumerate(patch_embeddings)
        ]

    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
        """Return image embeddings for `kwargs`, or `None` without images."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return None
        vision_embeddings = self._process_image_input(image_input)
        return vision_embeddings

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[NestedTensors] = None,
    ) -> torch.Tensor:
        """Embed `input_ids`, merging in image embeddings when provided.

        With `multimodal_embeddings`, the merge is delegated to
        `embed_multimodal`, keyed on `config.image_token_index`.
        """
        if multimodal_embeddings is None:
            return self.language_model.get_input_embeddings(input_ids)

        inputs_embeds = embed_multimodal(
            input_ids,
            self.config.image_token_index,
            self.language_model.model.get_input_embeddings,
            multimodal_embeddings,
        )
        return inputs_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run forward pass for LlaVA-NeXT.

        One key thing to understand is the `input_ids` already accounts for the
        positions of the to-be-inserted image embeddings.

        Concretely, consider a text prompt:
        `"A chat between a curious human and an artificial intelligence
        assistant. The assistant gives helpful, detailed, and polite answers to
        the human's questions.
        USER: <image>\\nWhat is shown in this image? ASSISTANT:"`.

        Tokenizer outputs:
        `[1, 319, 13563, 1546, 263, 12758, 5199, 322, 385, 23116, 21082, 20255,
        29889, 450, 20255, 4076, 8444, 29892, 13173, 29892, 322, 1248, 568,
        6089, 304, 278, 5199, 29915, 29879, 5155, 29889, 3148, 1001, 29901,
        29871, 32000, 13, 5618, 338, 4318, 297, 445, 1967, 29973, 319, 1799,
        9047, 13566, 29901]`.

        To reserve space in KV cache, we have to insert placeholder tokens
        before they are inputted to the model, so the input processor prepends
        additional image tokens (denoted as `32000`), resulting in:
        `[1, 319, 13563, 1546, 263, 12758, 5199, 322, 385, 23116, 21082, 20255,
        29889, 450, 20255, 4076, 8444, 29892, 13173, 29892, 322, 1248, 568,
        6089, 304, 278, 5199, 29915, 29879, 5155, 29889, 3148, 1001, 29901,
        29871, 32000, ..., 32000, 13, 5618, 338, 4318, 297, 445, 1967, 29973,
        319, 1799, 9047, 13566, 29901]`.

        Unlike in LLaVA-1.5, the number of image tokens inputted to the language
        model depends on the original size of the input image. Including the
        original image token in the input, the required number of image tokens
        is given by
        :meth:`LlavaNextMultiModalProcessor._get_num_image_tokens`.

        This way, the `positions` and `attn_metadata` are consistent
        with the `input_ids`.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            pixel_values: The pixels in each grid patch for each input image.
            image_sizes: The original `(height, width)` for each input image.

        See also:
            :class:`LlavaNextImageInputs`
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        # NOTE: In v1, inputs_embeds is always generated at model runner, this
        # condition is for v0 compatibility.
        elif inputs_embeds is None:
            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      vision_embeddings)
            # The language model consumes the embeddings instead of the ids.
            input_ids = None

        hidden_states = self.language_model.model(input_ids,
                                                  positions,
                                                  kv_caches,
                                                  attn_metadata,
                                                  intermediate_tensors,
                                                  inputs_embeds=inputs_embeds)

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Delegate logit computation to the language model."""
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Delegate sampling to the language model."""
        return self.language_model.sample(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load `(name, tensor)` weights through an `AutoWeightsLoader`.

        Returns:
            Whatever `AutoWeightsLoader.load_weights` returns (annotated as
            a set of names).
        """
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)