# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# --------------------------------------------------------
# Adapted from
# https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/internvl.py
# under Apache-2.0 License
# LICENSE is in root directory.
# --------------------------------------------------------
import copy
import warnings
from abc import ABC, abstractmethod
from collections.abc import Iterable, Mapping, Sequence
from typing import Annotated, Any, Literal, TypeAlias, TypeVar

import numpy.typing as npt
import regex as re
import torch
import torch.nn as nn
import torchvision.transforms as T
from PIL import Image
from transformers import BatchFeature, PretrainedConfig, TensorType

from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions
from vllm.model_executor.layers.activation import ReLUSquaredActivation
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import (
    HasInnerState,
    IsHybrid,
    MultiModalEmbeddings,
    SupportsMultiModal,
    SupportsMultiModalPruning,
)
from vllm.model_executor.models.internvl import (
    calculate_internvl_targets,
    get_internvl_target_ratios,
)
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.nemotron_h import NemotronHForCausalLM
from vllm.model_executor.models.radio import RadioModel
from vllm.model_executor.models.utils import (
    init_vllm_registered_model,
    maybe_prefix,
)
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.evs import (
    compute_retained_tokens_count,
    compute_retention_mask,
)
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
    VideoItem,
)
from vllm.multimodal.parse import (
    ImageEmbeddingItems,
    ImageProcessorItems,
    ImageSize,
    MultiModalDataItems,
    MultiModalDataParser,
)
from vllm.multimodal.processing import (
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptReplacement,
    PromptUpdate,
    PromptUpdateDetails,
    _seq2tokens,
)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.tokenizers import TokenizerLike, cached_tokenizer_from_config
from vllm.transformers_utils.configs.radio import RadioConfig
from vllm.utils.tensor_schema import TensorSchema, TensorShape

from .utils import _merge_multimodal_embeddings

# Configure PIL to handle large images without warnings
# This prevents DecompressionBombWarning for legitimate large images
Image.MAX_IMAGE_PIXELS = None  # Disable the limit entirely
# Alternative: Set a specific higher limit
# Image.MAX_IMAGE_PIXELS = 300000000  # ~300M pixels

IMG_START = "<img>"
IMG_END = "</img>"
IMG_CONTEXT = "<image>"

# Profiling
# MAX_FRAMES = 16
DEFAULT_NUM_TILES = 12


class NanoNemotronVLImagePixelInputs(TensorSchema):
    """
    Dimensions:
        - bn: Batch size * number of images
        - bnp: Batch size * number of images * (1 + num_patches)
        - c: Number of channels (3)
        - h: Height of each image patch
        - w: Width of each image patch
    """

    type: Literal["pixel_values"]
    pixel_values_flat: Annotated[torch.Tensor, TensorShape("bnp", 3, "h", "w")]
    num_patches: Annotated[torch.Tensor, TensorShape("bn")]


class NanoNemotronVLImageEmbeddingInputs(TensorSchema):
    """
    Dimensions:
        - n: Number of images
        - f: Total image feature size
        - h: Hidden size (must match the hidden size of language model backbone)
    """

    type: Literal["image_embeds"]
    data: Annotated[torch.Tensor | list[torch.Tensor], TensorShape("n", "f", "h")]
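
# Illustrative shape example for the schemas above (hypothetical numbers, not
# executed anywhere in this module): if a batch contains 2 images and
# dynamic_preprocess splits each into 12 tiles plus a thumbnail, the
# pixel-value variant carries
#     pixel_values_flat: shape (26, 3, 512, 512)  # 2 * (12 + 1) tiles
#     num_patches:       tensor([13, 13])
# while the embedding variant carries one precomputed feature tensor per image
# whose last dimension is the language model's hidden size.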
Literal["image_embeds"] data: Annotated[torch.Tensor | list[torch.Tensor], TensorShape("n", "f", "h")] NanoNemotronVLImageInputs: TypeAlias = ( NanoNemotronVLImagePixelInputs | NanoNemotronVLImageEmbeddingInputs ) class NanoNemotronVLVideoPixelInputs(TensorSchema): """ Dimensions: - bvf: Batch size * number of videos * num_frames - bn: Batch size * number of videos - f: Number of frames - c: Number of channels (3) - h: Height of each video frame - w: Width of each video frame """ type: Literal["pixel_values_videos"] pixel_values_flat: Annotated[torch.Tensor, TensorShape("bvf", 3, "h", "w")] num_patches: Annotated[torch.Tensor, TensorShape("bn")] frames_indices: Annotated[torch.Tensor, TensorShape("bvf")] frame_duration_ms: Annotated[torch.Tensor, TensorShape("bn")] class NanoNemotronVLVideoEmbeddingInputs(TensorSchema): """ Dimensions: - n: Number of videos - f: Total video feature size - h: Hidden size (must match the hidden size of language model backbone) """ type: Literal["video_embeds"] data: Annotated[torch.Tensor | list[torch.Tensor], TensorShape("n", "f", "h")] NanoNemotronVLVideoInputs: TypeAlias = ( NanoNemotronVLVideoPixelInputs | NanoNemotronVLVideoEmbeddingInputs ) def dynamic_preprocess( image, *, image_size=512, max_num_tiles=12, use_thumbnail=True, idx=0 ): orig_width, orig_height = image.size target_ratios = get_internvl_target_ratios(1, max_num_tiles) blocks, target_width, target_height = calculate_internvl_targets( orig_width=orig_width, orig_height=orig_height, target_ratios=target_ratios, image_size=image_size, use_thumbnail=False, ) # resize the image resized_img = image.resize((target_width, target_height)) processed_images = [] for i in range(blocks): box = ( (i % (target_width // image_size)) * image_size, (i // (target_width // image_size)) * image_size, ((i % (target_width // image_size)) + 1) * image_size, ((i // (target_width // image_size)) + 1) * image_size, ) # split the image split_img = resized_img.crop(box) processed_images.append(split_img) assert len(processed_images) == blocks if use_thumbnail and len(processed_images) != 1: thumbnail_img = image.resize((image_size, image_size)) processed_images.append(thumbnail_img) processed_images = [ img.convert("RGB") if img.mode != "RGB" else img for img in processed_images ] processed_images = [ T.Resize((image_size, image_size), interpolation=T.InterpolationMode.BICUBIC)( img ) for img in processed_images ] processed_images = [T.ToTensor()(img) for img in processed_images] return processed_images def image_to_pixel_values( image: Image.Image, *, input_size: int, max_num: int, use_thumbnail: bool, idx: int, ) -> torch.Tensor: images = dynamic_preprocess( image, image_size=input_size, max_num_tiles=max_num, use_thumbnail=use_thumbnail, idx=idx, ) pixel_values = torch.stack(images) return pixel_values def video_to_pixel_values( video: npt.NDArray, *, input_size: int, max_num_tiles: int = 1, use_thumbnail: bool, ) -> torch.Tensor: assert max_num_tiles == 1, "Video modality always uses one tile" # Convert each frame to a single resized tile tensor consistent # with image path frames_tensors: list[torch.Tensor] = [] for frame in video: pil_frame = dynamic_preprocess( Image.fromarray(frame, mode="RGB"), image_size=input_size, max_num_tiles=max_num_tiles, use_thumbnail=use_thumbnail, idx=0, ) # dynamic_preprocess returns tensors already; take the single tile assert len(pil_frame) >= 1 frames_tensors.append(pil_frame[-1]) return torch.stack(frames_tensors) def input_conditioner(x, norm_mean, norm_std): return (x - 

def calculate_timestamps(
    indices: list[int] | torch.Tensor,
    frame_duration_ms: int,
):
    if not isinstance(indices, list):
        indices = indices.tolist()
    timestamps = [int(i) * frame_duration_ms / 1000.0 for i in indices]
    return timestamps


class BaseNanoNemotronVLProcessor(ABC):
    """
    This model doesn't define its own HF processor, so we implement our own here.

    The code to insert image tokens is based on:
    https://huggingface.co/OpenGVLab/InternVL2-1B/blob/main/modeling_internvl_chat.py#L252
    """

    def __init__(
        self,
        config: PretrainedConfig,
        tokenizer: TokenizerLike,
        *args,
        max_num_tiles: int | None = None,
        **kwargs,
    ) -> None:
        super().__init__()

        self.config = config
        self.tokenizer = tokenizer
        self.max_num_tiles = max_num_tiles or DEFAULT_NUM_TILES

        image_size: int = config.force_image_size
        patch_size: int = config.patch_size
        self.num_image_token = int(
            (image_size // patch_size) ** 2 * (config.downsample_ratio**2)
        )
        self.image_size = image_size
        self.use_thumbnail: bool = config.use_thumbnail
        self.norm_mean = torch.Tensor(config.norm_mean).reshape(1, 3, 1, 1)
        self.norm_std = torch.Tensor(config.norm_std).reshape(1, 3, 1, 1)

    @property
    @abstractmethod
    def image_token_id(self) -> int:
        raise NotImplementedError

    @abstractmethod
    def get_image_repl(
        self,
        feature_size: int,
        num_patches: int | None,
    ) -> PromptUpdateDetails[str]:
        raise NotImplementedError

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        max_num_tiles: int,
    ) -> int:
        target_ratios = get_internvl_target_ratios(1, max_num_tiles)

        num_patches, _, _ = calculate_internvl_targets(
            orig_width=image_width,
            orig_height=image_height,
            target_ratios=target_ratios,
            image_size=self.image_size,
            use_thumbnail=self.use_thumbnail,
        )

        return num_patches * self.num_image_token

    def _images_to_pixel_values_lst(
        self,
        images: list[Image.Image],
        max_num_tiles: int,
    ) -> list[torch.Tensor]:
        return [
            image_to_pixel_values(
                image,
                input_size=self.image_size,
                max_num=max_num_tiles,
                use_thumbnail=self.use_thumbnail,
                idx=idx,
            )
            for idx, image in enumerate(images)
        ]

    def _preprocess_image(
        self,
        text: list[str],
        images: list[Image.Image],
        max_num_tiles: int,
    ) -> tuple[list[str], dict[str, torch.Tensor]]:
        if len(images) == 0:
            image_inputs = {}
        else:
            pixel_values_lst = self._images_to_pixel_values_lst(images, max_num_tiles)
            image_inputs = {
                "pixel_values_flat": input_conditioner(
                    torch.cat(pixel_values_lst), self.norm_mean, self.norm_std
                ),
                "image_num_patches": torch.tensor(
                    [len(item) for item in pixel_values_lst]
                ),
            }

            assert len(text) == 1, (
                "hf_processor is called on the output of get_dummy_text, "
                "which should be a single string"
            )
            parts = [x for x in re.split(r"(<image>)", text[0]) if x]
            assert parts.count("<image>") == len(pixel_values_lst), (
                "the number of <image> tokens in the text should be the "
                "same as the number of images"
            )
            for i, pixel_values in enumerate(pixel_values_lst):
                num_patches = pixel_values.shape[0]
                feature_size = num_patches * self.num_image_token
                image_repl = self.get_image_repl(feature_size, num_patches)
                parts[i] = parts[i].replace("<image>", image_repl.full)
            text = ["".join(parts)]
        return text, image_inputs

    def _make_batch_input(self, input_item: Any | list[Any] | None = None):
        if input_item is None:
            input_item = []
        if not isinstance(input_item, list):
            input_item = [input_item]
        return input_item
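
    # Illustrative call pattern (hypothetical prompt and variable names): the
    # processor mimics an HF processor, expanding each "<image>" placeholder
    # and tokenizing the result, e.g.
    #     out = processor("<image>\nDescribe the image.", images=[pil_image])
    #     # `out` is a BatchFeature holding input_ids plus pixel_values_flat
    #     # and image_num_patches when images are present.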

    def __call__(
        self,
        text: str | list[str] | None = None,
        images: Image.Image | list[Image.Image] | None = None,
        return_tensors: str | TensorType | None = None,
        max_num_tiles: int | None = None,
    ) -> BatchFeature:
        # Use the default if not provided
        if max_num_tiles is None:
            max_num_tiles = self.max_num_tiles

        text, images = [self._make_batch_input(x) for x in (text, images)]

        text, image_inputs = self._preprocess_image(
            text=text,
            images=images,
            max_num_tiles=max_num_tiles,
        )

        text_inputs = self.tokenizer(text, add_special_tokens=False)

        combined_outputs = {**text_inputs, **image_inputs}
        return BatchFeature(combined_outputs, tensor_type=return_tensors)


class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
    """
    HF processor with extended video processing logic.

    The video processing code is adapted from the video example at:
    https://huggingface.co/OpenGVLab/InternVL3-1B#inference-with-transformers
    """

    def __init__(
        self,
        config: PretrainedConfig,
        tokenizer: TokenizerLike,
        *,
        max_num_tiles: int | None = None,
        min_dynamic_patch: int | None = None,
        max_dynamic_patch: int | None = None,
        dynamic_image_size: bool | None = None,
        video_token: str | None = None,
        video_pruning_rate: float | None = None,
    ) -> None:
        super().__init__(
            config=config,
            tokenizer=tokenizer,
            max_num_tiles=max_num_tiles,
            min_dynamic_patch=min_dynamic_patch,
            max_dynamic_patch=max_dynamic_patch,
            dynamic_image_size=dynamic_image_size,
        )
        # Add the extra video token for video processing
        self.video_token = video_token
        self.video_pruning_rate = video_pruning_rate

        # Pre-tokenize special tokens for video processing
        # to avoid repeated tokenization
        self._img_start_token_ids = tokenizer.encode(
            IMG_START, add_special_tokens=False
        )
        self._img_end_token_ids = tokenizer.encode(IMG_END, add_special_tokens=False)
        self._img_context_token_ids = tokenizer.encode(
            IMG_CONTEXT, add_special_tokens=False
        )

    @property
    def supports_video(self) -> bool:
        return self.video_token_id is not None

    @property
    def video_token_id(self) -> int | None:
        if self.video_token is None:
            return None
        return self.tokenizer.get_vocab().get(self.video_token, None)

    @property
    def image_token_id(self) -> int:
        return self.tokenizer.convert_tokens_to_ids(IMG_CONTEXT)

    def _videos_to_pixel_values_lst(
        self,
        videos: list[npt.NDArray],
        max_num_tiles: int,
        dynamic_image_size: bool | None = None,
    ) -> list[torch.Tensor]:
        return [
            video_to_pixel_values(
                video,
                input_size=self.image_size,
                max_num_tiles=max_num_tiles,
                use_thumbnail=self.use_thumbnail,
            )
            for video in videos
        ]
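
    # Rough EVS (video token pruning) accounting used in _preprocess_video below
    # (illustrative, hypothetical numbers): with 16 frames, 256 tokens per frame
    # and video_pruning_rate=0.75, compute_retained_tokens_count keeps roughly
    # 16 * 256 * (1 - 0.75) = 1024 tokens for the whole video; the exact count
    # is whatever that helper returns. The retained budget is then folded into
    # per-frame placeholder counts before building the prompt replacement.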

    def _preprocess_video(
        self,
        text: list[str],
        videos: list[tuple[npt.NDArray, dict[str, Any]]],
        max_num_tiles: int,
        dynamic_image_size: bool | None = None,
    ):
        if len(videos) == 0 or not self.supports_video:
            video_inputs = {}
        else:
            videos_lst = [v[0] for v in videos]
            video_metadata_lst = [v[1] for v in videos]
            pixel_values_lst_video = self._videos_to_pixel_values_lst(
                videos_lst,
                max_num_tiles=max_num_tiles,
                dynamic_image_size=dynamic_image_size,
            )

            # We use the frame duration in milliseconds (as an integer) to keep
            # timestamp calculation consistent. At preprocessing time the fps
            # parameter is given in fp32, while at inference it is bf16, which
            # leads to inaccurate timestamp calculation and makes the timestamp
            # values differ. In rare cases this causes a mismatching number of
            # output tokens for the tokenized frame prefixes.
            frame_duration_ms_lst = [
                int(1000.0 / metadata["fps"]) for metadata in video_metadata_lst
            ]
            frames_indices_lst = [
                metadata["frames_indices"] for metadata in video_metadata_lst
            ]

            video_inputs = {
                "pixel_values_flat_video": input_conditioner(
                    torch.cat(pixel_values_lst_video), self.norm_mean, self.norm_std
                ),
                "video_num_patches": torch.tensor(
                    [len(item) for item in pixel_values_lst_video]
                ),
                "frames_indices": frames_indices_lst,
                "frame_duration_ms": torch.tensor(frame_duration_ms_lst),
            }

            image_size: int = self.config.force_image_size
            patch_size: int = self.config.patch_size
            downsample_ratio = self.config.downsample_ratio
            tokens_in_single_frame = int(
                (image_size * image_size // patch_size**2) * (downsample_ratio**2)
            )

            for pixel_values, video_metadata, frames_indices, frame_duration_ms in zip(
                pixel_values_lst_video,
                video_metadata_lst,
                frames_indices_lst,
                frame_duration_ms_lst,
            ):
                num_frames = pixel_values.shape[0]

                if (
                    self.video_pruning_rate is not None
                    and self.video_pruning_rate > 0.0
                ):
                    # Start of EVS-specific code
                    num_tokens = compute_retained_tokens_count(
                        tokens_per_frame=tokens_in_single_frame,
                        num_frames=num_frames,
                        q=self.video_pruning_rate,
                    )

                    # Here we just need placeholders that won't actually be
                    # replaced - we just need to make sure the total number of
                    # tokens is correct, so assign all tokens to the first frame.
                    tokens_per_frame = [num_tokens] + [0] * (num_frames - 1)
                    # End of EVS-specific code
                else:
                    tokens_per_frame = [tokens_in_single_frame] * num_frames

                video_repl = self.get_video_repl(
                    tokens_per_frame=tokens_per_frame,
                    frames_indices=frames_indices,
                    frame_duration_ms=frame_duration_ms,
                    tokenizer=self.tokenizer,
                    img_start_token_ids=self._img_start_token_ids,
                    img_end_token_ids=self._img_end_token_ids,
                    img_context_token_ids=self._img_context_token_ids,
                )

                # video_repl.full is a list of token IDs;
                # convert the token IDs back to text for the HF processor flow
                video_repl_text = self.tokenizer.decode(
                    video_repl.full, skip_special_tokens=False
                )
                text = [t.replace("