From 72d30108a0fe06a1ba14d8645dda830a9ab01791 Mon Sep 17 00:00:00 2001
From: danielafrimi <45691845+danielafrimi@users.noreply.github.com>
Date: Wed, 10 Sep 2025 16:10:06 +0300
Subject: [PATCH] Support for NemotronH Nano VLM (#23644)

Signed-off-by: Daniel Afrimi
---
 tests/models/registry.py                      |    3 +
 vllm/config/__init__.py                       |    2 +-
 .../model_executor/models/nano_nemotron_vl.py | 1395 +++++++++++++++++
 vllm/model_executor/models/registry.py        |    1 +
 4 files changed, 1400 insertions(+), 1 deletion(-)
 create mode 100644 vllm/model_executor/models/nano_nemotron_vl.py

diff --git a/tests/models/registry.py b/tests/models/registry.py
index a5e83bc11f144..c80f045d98743 100644
--- a/tests/models/registry.py
+++ b/tests/models/registry.py
@@ -515,6 +515,9 @@ _MULTIMODAL_EXAMPLE_MODELS = {
                                               trust_remote_code=True),
     "Llama_Nemotron_Nano_VL" : _HfExamplesInfo("nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1",  # noqa: E501
                                               trust_remote_code=True),
+    "NemotronH_Nano_VL": _HfExamplesInfo("nano_vl_dummy",
+                                         is_available_online=False,
+                                         trust_remote_code=True),
     "Ovis": _HfExamplesInfo("AIDC-AI/Ovis2-1B", trust_remote_code=True,
                             max_transformers_version="4.53",
                             transformers_version_reason="HF model is not compatible",  # noqa: E501
diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py
index 7422527a6854b..4bab06a98cb21 100644
--- a/vllm/config/__init__.py
+++ b/vllm/config/__init__.py
@@ -1552,7 +1552,7 @@ class ModelConfig:
                     for bc in block_configs[start:end])
             else:
                 # Hybrid model Jamba
-                layers_block_type_value = getattr(self.hf_config,
+                layers_block_type_value = getattr(self.hf_text_config,
                                                   "layers_block_type", None)
                 if layers_block_type_value is not None:
                     if hasattr(self.hf_text_config,
diff --git a/vllm/model_executor/models/nano_nemotron_vl.py b/vllm/model_executor/models/nano_nemotron_vl.py
new file mode 100644
index 0000000000000..21765a483b8e0
--- /dev/null
+++ b/vllm/model_executor/models/nano_nemotron_vl.py
@@ -0,0 +1,1395 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+# --------------------------------------------------------
+# Adapted from
+# https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/internvl.py
+# under Apache-2.0 License
+# LICENSE is in root directory.
+# --------------------------------------------------------
+
+import copy
+import warnings
+from abc import ABC, abstractmethod
+from collections.abc import Iterable, Mapping, Sequence
+from typing import Annotated, Any, Literal, Optional, TypedDict, TypeVar, Union
+
+import numpy.typing as npt
+import torch
+import torch.nn as nn
+import torchvision.transforms as T
+from PIL import Image
+from transformers import (AutoModel, BatchEncoding, BatchFeature,
+                          PretrainedConfig, TensorType)
+
+from vllm.config import VllmConfig
+from vllm.model_executor.layers.activation import ReLUSquaredActivation
+from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.models.interfaces import (HasInnerState, IsHybrid,
+                                                   MultiModalEmbeddings,
+                                                   SupportsMultiModal)
+from vllm.model_executor.models.internvl import (calculate_internvl_targets,
+                                                 get_internvl_target_ratios)
+from vllm.model_executor.models.module_mapping import MultiModelKeys
+from vllm.model_executor.models.nemotron_h import NemotronHForCausalLM
+from vllm.model_executor.models.utils import (flatten_bn,
+                                              init_vllm_registered_model,
+                                              maybe_prefix,
+                                              merge_multimodal_embeddings)
+from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
+                                    MultiModalKwargs, MultiModalKwargsItems,
+                                    NestedTensors)
+from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
+                                   ImageSize, MultiModalDataItems)
+from vllm.multimodal.processing import (BaseMultiModalProcessor,
+                                        BaseProcessingInfo, PromptReplacement,
+                                        PromptUpdate, PromptUpdateDetails)
+from vllm.multimodal.profiling import BaseDummyInputsBuilder
+from vllm.sequence import IntermediateTensors
+from vllm.transformers_utils.tokenizer import AnyTokenizer
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
+
+# Configure PIL to handle large images without warnings.
+# This prevents DecompressionBombWarning for legitimate large images.
+Image.MAX_IMAGE_PIXELS = None  # Disable the limit entirely
+# Alternative: set a specific higher limit
+# Image.MAX_IMAGE_PIXELS = 300000000  # ~300M pixels
+
+IMG_START = "<img>"
+IMG_END = "</img>"
+IMG_CONTEXT = "<image>"
+
+# Profiling
+MAX_FRAMES = 16
+
+
+class NanoNemotronVLImagePixelInputs(TypedDict):
+    type: Literal["pixel_values"]
+    pixel_values_flat: torch.Tensor
+    """
+    Shape:
+    `(batch_size * num_images * (1 + num_patches), num_channels, height, width)`
+    """
+
+    num_patches: torch.Tensor
+    """Shape: `(batch_size * num_images)`"""
+
+
+class NanoNemotronVLImageEmbeddinInputs(TypedDict):
+    type: Literal["image_embeds"]
+    data: Union[torch.Tensor, list[torch.Tensor]]
+    """
+    A tensor of shape `(num_images, total_image_feature_size, hidden_size)`
+    or a list of tensors of shape `(total_image_feature_size, hidden_size)`.
+
+    `hidden_size` must match the hidden size of the language model backbone.
+    """
+
+
+NanoNemotronVLImageInputs = Union[NanoNemotronVLImagePixelInputs,
+                                  NanoNemotronVLImageEmbeddinInputs]
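+
+# NOTE: Given the flattening convention above, two images tiled into, e.g.,
+# 3 and 7 patches (thumbnails included) produce a pixel_values_flat with
+# 10 rows along the first dimension and num_patches == tensor([3, 7]).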
+ """ + + +NanoNemotronVLImageInputs = Union[NanoNemotronVLImagePixelInputs, + NanoNemotronVLImageEmbeddinInputs] + + +class NanoNemotronVLVideoPixelInputs(TensorSchema): + """ + Dimensions: + - bvf: Batch size * number of videos * num_frames + - bn: Batch size * number of images + - c: Number of channels (3) + - h: Height of each video frame + - w: Width of each video frame + """ + type: Literal["pixel_values_videos"] + pixel_values_flat: Annotated[torch.Tensor, TensorShape("bvf", 3, "h", "w")] + num_patches: Annotated[torch.Tensor, TensorShape("bn")] + + +class NanoNemotronVLVideoEmbeddingInputs(TensorSchema): + """ + Dimensions: + - n: Number of videos + - f: Total video feature size + - h: Hidden size (must match the hidden size of language model backbone) + """ + type: Literal["video_embeds"] + data: Annotated[Union[torch.Tensor, list[torch.Tensor]], + TensorShape("n", "f", "h")] + + +NanoNemotronVLVideoInputs = Union[NanoNemotronVLVideoPixelInputs, + NanoNemotronVLVideoEmbeddingInputs] + + +def input_conditioner(x, norm_mean, norm_std): + y = (x - norm_mean) / norm_std + return y + + +def dynamic_preprocess(image, + *, + image_size=512, + max_num_tiles=12, + use_thumbnail=True, + idx=0): + orig_width, orig_height = image.size + + target_ratios = get_internvl_target_ratios(1, max_num_tiles) + + blocks, target_width, target_height = calculate_internvl_targets( + orig_width=orig_width, + orig_height=orig_height, + target_ratios=target_ratios, + image_size=image_size, + use_thumbnail=False) + # resize the image + resized_img = image.resize((target_width, target_height)) + processed_images = [] + for i in range(blocks): + box = ( + (i % (target_width // image_size)) * image_size, + (i // (target_width // image_size)) * image_size, + ((i % (target_width // image_size)) + 1) * image_size, + ((i // (target_width // image_size)) + 1) * image_size, + ) + # split the image + split_img = resized_img.crop(box) + processed_images.append(split_img) + assert len(processed_images) == blocks + if use_thumbnail and len(processed_images) != 1: + thumbnail_img = image.resize((image_size, image_size)) + processed_images.append(thumbnail_img) + + processed_images = [ + img.convert("RGB") if img.mode != "RGB" else img + for img in processed_images + ] + processed_images = [ + T.Resize((image_size, image_size), + interpolation=T.InterpolationMode.BICUBIC)(img) + for img in processed_images + ] + processed_images = [T.ToTensor()(img) for img in processed_images] + return processed_images + + +def image_to_pixel_values( + image: Image.Image, + *, + input_size: int, + max_num: int, + use_thumbnail: bool, + idx: int, +) -> torch.Tensor: + images = dynamic_preprocess( + image, + image_size=input_size, + max_num_tiles=max_num, + use_thumbnail=use_thumbnail, + idx=idx, + ) + + pixel_values = torch.stack(images) + return pixel_values + + +def video_to_pixel_values( + video: npt.NDArray, + *, + input_size: int, + max_num_tiles: int = 1, + use_thumbnail: bool, +) -> torch.Tensor: + # Convert each frame to a single resized tile tensor consistent + # with image path + frames_tensors: list[torch.Tensor] = [] + for frame in video: + pil_frame = dynamic_preprocess( + Image.fromarray(frame, mode="RGB"), + image_size=input_size, + max_num_tiles=max_num_tiles, + use_thumbnail=use_thumbnail, + idx=0, + ) + # dynamic_preprocess returns tensors already; take the single tile + assert len(pil_frame) >= 1 + frames_tensors.append(pil_frame[0]) + + return torch.stack(frames_tensors) + + +class BaseNanoNemotronVLProcessor(ABC): + """ 
+
+
+class BaseNanoNemotronVLProcessor(ABC):
+    """
+    This model doesn't define its own HF processor,
+    so we implement our own here.
+
+    The code to insert image tokens is based on:
+    https://huggingface.co/OpenGVLab/InternVL2-1B/blob/main/modeling_internvl_chat.py#L252
+    """
+
+    def __init__(self, config: PretrainedConfig, tokenizer: AnyTokenizer,
+                 *args, **kwargs) -> None:
+        super().__init__()
+
+        self.config = config
+        self.tokenizer = tokenizer
+
+        image_size: int = config.force_image_size
+        patch_size: int = config.patch_size
+
+        self.num_image_token = int(
+            (image_size // patch_size)**2 * (config.downsample_ratio**2))
+        self.image_size = image_size
+        self.use_thumbnail: bool = config.use_thumbnail
+        self.norm_mean = torch.Tensor(config.norm_mean).reshape(1, 3, 1, 1)
+        self.norm_std = torch.Tensor(config.norm_std).reshape(1, 3, 1, 1)
+
+    @property
+    @abstractmethod
+    def image_token_id(self) -> int:
+        raise NotImplementedError
+
+    @abstractmethod
+    def get_image_repl(
+        self,
+        feature_size: int,
+        num_patches: Optional[int],
+    ) -> PromptUpdateDetails[str]:
+        raise NotImplementedError
+
+    def get_num_image_tokens(
+        self,
+        *,
+        image_width: int,
+        image_height: int,
+        max_num_tiles: int,
+    ) -> int:
+        target_ratios = get_internvl_target_ratios(1, max_num_tiles)
+
+        num_patches, _, _ = calculate_internvl_targets(
+            orig_width=image_width,
+            orig_height=image_height,
+            target_ratios=target_ratios,
+            image_size=self.image_size,
+            use_thumbnail=self.use_thumbnail,
+        )
+
+        return num_patches * self.num_image_token
+
+    def _images_to_pixel_values_lst(
+        self,
+        images: list[Image.Image],
+        max_num_tiles: int,
+    ) -> list[torch.Tensor]:
+        return [
+            image_to_pixel_values(
+                image,
+                input_size=self.image_size,
+                max_num=max_num_tiles,
+                use_thumbnail=self.use_thumbnail,
+                idx=idx,
+            ) for idx, image in enumerate(images)
+        ]
+
+    def _preprocess_image(
+        self,
+        text: list[str],
+        images: list[Image.Image],
+        max_num_tiles: int,
+    ) -> tuple[list[str], dict[str, torch.Tensor]]:
+        if len(images) == 0:
+            image_inputs = {}
+        else:
+            pixel_values_lst = self._images_to_pixel_values_lst(
+                images, max_num_tiles)
+            image_inputs: dict[str, NestedTensors] = {
+                "pixel_values_flat":
+                input_conditioner(torch.cat(pixel_values_lst), self.norm_mean,
+                                  self.norm_std),
+                "image_num_patches":
+                torch.tensor([len(item) for item in pixel_values_lst]),
+            }
+
+            for pixel_values in pixel_values_lst:
+                num_patches = pixel_values.shape[0]
+                feature_size = num_patches * self.num_image_token
+                image_repl = self.get_image_repl(feature_size, num_patches)
+                text = [t.replace('<image>', image_repl.full, 1) for t in text]
+        return text, image_inputs
+
+    def _make_batch_input(self,
+                          input_item: Optional[Union[Any, list[Any]]] = None):
+        if input_item is None:
+            input_item = []
+        if not isinstance(input_item, list):
+            input_item = [input_item]
+        return input_item
+
+    def __call__(
+        self,
+        text: Optional[Union[str, list[str]]] = None,
+        images: Optional[Union[Image.Image, list[Image.Image]]] = None,
+        return_tensors: Optional[Union[str, TensorType]] = None,
+        max_num_tiles: Optional[int] = None,
+    ) -> Mapping[str, NestedTensors]:
+        # Use default if not provided
+        if max_num_tiles is None:
+            max_num_tiles = 12
+
+        text, images = [self._make_batch_input(x) for x in (text, images)]
+
+        text, image_inputs = self._preprocess_image(
+            text=text,
+            images=images,
+            max_num_tiles=max_num_tiles,
+        )
+
+        text_inputs = self.tokenizer(text, add_special_tokens=False)
+
+        return {
+            **BatchEncoding(text_inputs, tensor_type=return_tensors),
+            **image_inputs,
+        }
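+
+
+# Usage sketch (illustrative; "<video>" is an assumed token here, in practice
+# it comes from the model's tokenizer/config):
+#     processor = NanoNemotronVLProcessor(config, tokenizer,
+#                                         video_token="<video>")
+#     out = processor(text="<image>\nDescribe this image.",
+#                     images=[pil_image], return_tensors="pt")
+# "out" contains the tokenized prompt (with "<image>" expanded into image
+# context tokens) plus "pixel_values_flat" and "image_num_patches".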
+
+
+class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
+    """
+    HF Processor with extended video processing logic.
+    Code for video processing is adapted from the video example:
+    https://huggingface.co/OpenGVLab/InternVL3-1B#inference-with-transformers
+    """
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        tokenizer: AnyTokenizer,
+        *,
+        min_dynamic_patch: Optional[int] = None,
+        max_dynamic_patch: Optional[int] = None,
+        dynamic_image_size: Optional[bool] = None,
+        video_token: Optional[str] = None,
+    ) -> None:
+        super().__init__(
+            config=config,
+            tokenizer=tokenizer,
+            min_dynamic_patch=min_dynamic_patch,
+            max_dynamic_patch=max_dynamic_patch,
+            dynamic_image_size=dynamic_image_size,
+        )
+        # add extra video token for video processing
+        self.video_token = video_token
+
+    @property
+    def supports_video(self) -> bool:
+        return self.video_token_id is not None
+
+    @property
+    def video_token_id(self) -> Optional[int]:
+        if self.video_token is None:
+            return None
+        return self.tokenizer.get_vocab().get(self.video_token, None)
+
+    @property
+    def image_token_id(self) -> int:
+        return self.tokenizer.convert_tokens_to_ids(IMG_CONTEXT)
+
+    def _videos_to_pixel_values_lst(
+        self,
+        videos: list[npt.NDArray],
+        max_num_tiles: int,
+        dynamic_image_size: Optional[bool] = None,
+    ) -> list[torch.Tensor]:
+        return [
+            video_to_pixel_values(
+                video,
+                input_size=self.image_size,
+                max_num_tiles=max_num_tiles,
+                use_thumbnail=self.use_thumbnail,
+            ) for video in videos
+        ]
+
+    def _preprocess_video(
+        self,
+        text: list[str],
+        videos: list[npt.NDArray],
+        max_num_tiles: int,
+        dynamic_image_size: Optional[bool] = None,
+    ):
+        if len(videos) == 0 or not self.supports_video:
+            video_inputs = {}
+        else:
+            pixel_values_lst_video = self._videos_to_pixel_values_lst(
+                videos,
+                max_num_tiles=max_num_tiles,
+                dynamic_image_size=dynamic_image_size,
+            )
+
+            video_inputs: dict[str, NestedTensors] = {
+                "pixel_values_flat_video":
+                input_conditioner(torch.cat(pixel_values_lst_video),
+                                  self.norm_mean, self.norm_std),
+                "video_num_patches":
+                torch.tensor([len(item) for item in pixel_values_lst_video]),
+            }
+
+            for pixel_values in pixel_values_lst_video:
+                num_patches = pixel_values.shape[0]
+
+                video_repl = self.get_video_repl(self.num_image_token,
+                                                 num_patches, self.video_token)
+                text = [t.replace('