diff --git a/tests/models/registry.py b/tests/models/registry.py
index a5e83bc11f144..c80f045d98743 100644
--- a/tests/models/registry.py
+++ b/tests/models/registry.py
@@ -515,6 +515,9 @@ _MULTIMODAL_EXAMPLE_MODELS = {
trust_remote_code=True),
"Llama_Nemotron_Nano_VL" : _HfExamplesInfo("nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1", # noqa: E501
trust_remote_code=True),
+ "NemotronH_Nano_VL": _HfExamplesInfo("nano_vl_dummy",
+ is_available_online=False,
+ trust_remote_code=True),
"Ovis": _HfExamplesInfo("AIDC-AI/Ovis2-1B", trust_remote_code=True,
max_transformers_version="4.53",
transformers_version_reason="HF model is not compatible", # noqa: E501
diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py
index 7422527a6854b..4bab06a98cb21 100644
--- a/vllm/config/__init__.py
+++ b/vllm/config/__init__.py
@@ -1552,7 +1552,7 @@ class ModelConfig:
for bc in block_configs[start:end])
else:
# Hybrid model Jamba
- layers_block_type_value = getattr(self.hf_config,
+ layers_block_type_value = getattr(self.hf_text_config,
"layers_block_type", None)
if layers_block_type_value is not None:
if hasattr(self.hf_text_config,
diff --git a/vllm/model_executor/models/nano_nemotron_vl.py b/vllm/model_executor/models/nano_nemotron_vl.py
new file mode 100644
index 0000000000000..21765a483b8e0
--- /dev/null
+++ b/vllm/model_executor/models/nano_nemotron_vl.py
@@ -0,0 +1,1395 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+# --------------------------------------------------------
+# Adapted from
+# https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/internvl.py
+# under Apache-2.0 License
+# LICENSE is in root directory.
+# --------------------------------------------------------
+
+import copy
+import warnings
+from abc import ABC, abstractmethod
+from collections.abc import Iterable, Mapping, Sequence
+from typing import Annotated, Any, Literal, Optional, TypedDict, TypeVar, Union
+
+import numpy.typing as npt
+import torch
+import torch.nn as nn
+import torchvision.transforms as T
+from PIL import Image
+from transformers import (AutoModel, BatchEncoding, BatchFeature,
+ PretrainedConfig, TensorType)
+
+from vllm.config import VllmConfig
+from vllm.model_executor.layers.activation import ReLUSquaredActivation
+from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.models.interfaces import (HasInnerState, IsHybrid,
+ MultiModalEmbeddings,
+ SupportsMultiModal)
+from vllm.model_executor.models.internvl import (calculate_internvl_targets,
+ get_internvl_target_ratios)
+from vllm.model_executor.models.module_mapping import MultiModelKeys
+from vllm.model_executor.models.nemotron_h import NemotronHForCausalLM
+from vllm.model_executor.models.utils import (flatten_bn,
+ init_vllm_registered_model,
+ maybe_prefix,
+ merge_multimodal_embeddings)
+from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
+ MultiModalKwargs, MultiModalKwargsItems,
+ NestedTensors)
+from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
+ ImageSize, MultiModalDataItems)
+from vllm.multimodal.processing import (BaseMultiModalProcessor,
+ BaseProcessingInfo, PromptReplacement,
+ PromptUpdate, PromptUpdateDetails)
+from vllm.multimodal.profiling import BaseDummyInputsBuilder
+from vllm.sequence import IntermediateTensors
+from vllm.transformers_utils.tokenizer import AnyTokenizer
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
+
+# Configure PIL to handle large images without warnings
+# This prevents DecompressionBombWarning for legitimate large images
+Image.MAX_IMAGE_PIXELS = None # Disable the limit entirely
+# Alternative: Set a specific higher limit
+# Image.MAX_IMAGE_PIXELS = 300000000 # ~300M pixels
+
+IMG_START = "
"
+IMG_END = ""
+IMG_CONTEXT = ""
+
+# Profiling
+MAX_FRAMES = 16
+
+
+class NanoNemotronVLImagePixelInputs(TypedDict):
+ type: Literal["pixel_values"]
+ pixel_values_flat: torch.Tensor
+ """
+ Shape:
+ `(batch_size * num_images * (1 + num_patches), num_channels, height, width)`
+ """
+
+ num_patches: torch.Tensor
+ """Shape: `(batch_size * num_images)`"""
+
+
+class NanoNemotronVLImageEmbeddinInputs(TypedDict):
+ type: Literal["image_embeds"]
+ data: Union[torch.Tensor, list[torch.Tensor]]
+ """
+ A tensor of shape `(num_images, total_image_feature_size, hidden_size)`
+ or a list of tensors of shape `(total_image_feature_size, hidden_size)`
+
+    `hidden_size` must match the hidden size of the language model backbone.
+ """
+
+
+NanoNemotronVLImageInputs = Union[NanoNemotronVLImagePixelInputs,
+ NanoNemotronVLImageEmbeddinInputs]
+
+
+class NanoNemotronVLVideoPixelInputs(TensorSchema):
+ """
+ Dimensions:
+ - bvf: Batch size * number of videos * num_frames
+        - bn: Batch size * number of videos
+ - c: Number of channels (3)
+ - h: Height of each video frame
+ - w: Width of each video frame
+ """
+ type: Literal["pixel_values_videos"]
+ pixel_values_flat: Annotated[torch.Tensor, TensorShape("bvf", 3, "h", "w")]
+ num_patches: Annotated[torch.Tensor, TensorShape("bn")]
+
+
+class NanoNemotronVLVideoEmbeddingInputs(TensorSchema):
+ """
+ Dimensions:
+ - n: Number of videos
+ - f: Total video feature size
+        - h: Hidden size (must match the language model's hidden size)
+ """
+ type: Literal["video_embeds"]
+ data: Annotated[Union[torch.Tensor, list[torch.Tensor]],
+ TensorShape("n", "f", "h")]
+
+
+NanoNemotronVLVideoInputs = Union[NanoNemotronVLVideoPixelInputs,
+ NanoNemotronVLVideoEmbeddingInputs]
+
+
+def input_conditioner(x, norm_mean, norm_std):
+ y = (x - norm_mean) / norm_std
+ return y
+
+
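+# InternVL-style dynamic tiling: pick the tile grid whose aspect ratio best
+# matches the input, resize the image to that grid, crop it into fixed-size
+# tiles, and optionally append a whole-image thumbnail tile.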
+def dynamic_preprocess(image,
+ *,
+ image_size=512,
+ max_num_tiles=12,
+ use_thumbnail=True,
+ idx=0):
+ orig_width, orig_height = image.size
+
+ target_ratios = get_internvl_target_ratios(1, max_num_tiles)
+
+ blocks, target_width, target_height = calculate_internvl_targets(
+ orig_width=orig_width,
+ orig_height=orig_height,
+ target_ratios=target_ratios,
+ image_size=image_size,
+ use_thumbnail=False)
+ # resize the image
+ resized_img = image.resize((target_width, target_height))
+ processed_images = []
+ for i in range(blocks):
+ box = (
+ (i % (target_width // image_size)) * image_size,
+ (i // (target_width // image_size)) * image_size,
+ ((i % (target_width // image_size)) + 1) * image_size,
+ ((i // (target_width // image_size)) + 1) * image_size,
+ )
+ # split the image
+ split_img = resized_img.crop(box)
+ processed_images.append(split_img)
+ assert len(processed_images) == blocks
+ if use_thumbnail and len(processed_images) != 1:
+ thumbnail_img = image.resize((image_size, image_size))
+ processed_images.append(thumbnail_img)
+
+ processed_images = [
+ img.convert("RGB") if img.mode != "RGB" else img
+ for img in processed_images
+ ]
+ processed_images = [
+ T.Resize((image_size, image_size),
+ interpolation=T.InterpolationMode.BICUBIC)(img)
+ for img in processed_images
+ ]
+ processed_images = [T.ToTensor()(img) for img in processed_images]
+ return processed_images
+
+
+def image_to_pixel_values(
+ image: Image.Image,
+ *,
+ input_size: int,
+ max_num: int,
+ use_thumbnail: bool,
+ idx: int,
+) -> torch.Tensor:
+ images = dynamic_preprocess(
+ image,
+ image_size=input_size,
+ max_num_tiles=max_num,
+ use_thumbnail=use_thumbnail,
+ idx=idx,
+ )
+
+ pixel_values = torch.stack(images)
+ return pixel_values
+
+
+def video_to_pixel_values(
+ video: npt.NDArray,
+ *,
+ input_size: int,
+ max_num_tiles: int = 1,
+ use_thumbnail: bool,
+) -> torch.Tensor:
+    # Convert each frame to a single resized tile tensor, consistent with the
+    # image path
+ frames_tensors: list[torch.Tensor] = []
+ for frame in video:
+ pil_frame = dynamic_preprocess(
+ Image.fromarray(frame, mode="RGB"),
+ image_size=input_size,
+ max_num_tiles=max_num_tiles,
+ use_thumbnail=use_thumbnail,
+ idx=0,
+ )
+ # dynamic_preprocess returns tensors already; take the single tile
+ assert len(pil_frame) >= 1
+ frames_tensors.append(pil_frame[0])
+
+ return torch.stack(frames_tensors)
+
+
+class BaseNanoNemotronVLProcessor(ABC):
+ """
+ This model doesn't define its own HF processor,
+    so we implement our own here.
+
+ The code to insert image tokens is based on:
+ https://huggingface.co/OpenGVLab/InternVL2-1B/blob/main/modeling_internvl_chat.py#L252
+ """
+
+ def __init__(self, config: PretrainedConfig, tokenizer: AnyTokenizer,
+ *args, **kwargs) -> None:
+ super().__init__()
+
+ self.config = config
+ self.tokenizer = tokenizer
+
+ image_size: int = config.force_image_size
+ patch_size: int = config.patch_size
+
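+        # Tokens emitted per tile: (image_size / patch_size)^2 vision
+        # patches, scaled by downsample_ratio^2 (e.g. 512 / 16 -> 32x32
+        # patches, ratio 0.5 -> 16x16 = 256 tokens per tile).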
+ self.num_image_token = int(
+ (image_size // patch_size)**2 * (config.downsample_ratio**2))
+ self.image_size = image_size
+ self.use_thumbnail: bool = config.use_thumbnail
+ self.norm_mean = torch.Tensor(config.norm_mean).reshape(1, 3, 1, 1)
+ self.norm_std = torch.Tensor(config.norm_std).reshape(1, 3, 1, 1)
+
+ @property
+ @abstractmethod
+ def image_token_id(self) -> int:
+ raise NotImplementedError
+
+ @abstractmethod
+ def get_image_repl(
+ self,
+ feature_size: int,
+ num_patches: Optional[int],
+ ) -> PromptUpdateDetails[str]:
+ raise NotImplementedError
+
+ def get_num_image_tokens(
+ self,
+ *,
+ image_width: int,
+ image_height: int,
+ max_num_tiles: int,
+ ) -> int:
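+        # Number of tiles selected for this resolution (including the
+        # thumbnail tile, if enabled) times the per-tile token count.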
+ target_ratios = get_internvl_target_ratios(1, max_num_tiles)
+
+ num_patches, _, _ = calculate_internvl_targets(
+ orig_width=image_width,
+ orig_height=image_height,
+ target_ratios=target_ratios,
+ image_size=self.image_size,
+ use_thumbnail=self.use_thumbnail,
+ )
+
+ return num_patches * self.num_image_token
+
+ def _images_to_pixel_values_lst(
+ self,
+ images: list[Image.Image],
+ max_num_tiles: int,
+ ) -> list[torch.Tensor]:
+ return [
+ image_to_pixel_values(
+ image,
+ input_size=self.image_size,
+ max_num=max_num_tiles,
+ use_thumbnail=self.use_thumbnail,
+ idx=idx,
+ ) for idx, image in enumerate(images)
+ ]
+
+ def _preprocess_image(
+ self,
+ text: list[str],
+ images: list[Image.Image],
+ max_num_tiles: int,
+ ) -> tuple[list[str], dict[str, torch.Tensor]]:
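+        # Tile and normalize the images, then replace each `<image>`
+        # placeholder in the prompt with its full image token string.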
+ if len(images) == 0:
+ image_inputs = {}
+ else:
+ pixel_values_lst = self._images_to_pixel_values_lst(
+ images, max_num_tiles)
+ image_inputs: dict[str, NestedTensors] = {
+ "pixel_values_flat":
+ input_conditioner(torch.cat(pixel_values_lst), self.norm_mean,
+ self.norm_std),
+ "image_num_patches":
+ torch.tensor([len(item) for item in pixel_values_lst]),
+ }
+
+ for pixel_values in pixel_values_lst:
+ num_patches = pixel_values.shape[0]
+ feature_size = num_patches * self.num_image_token
+ image_repl = self.get_image_repl(feature_size, num_patches)
+                text = [t.replace('<image>', image_repl.full, 1) for t in text]
+ return text, image_inputs
+
+ def _make_batch_input(self,
+ input_item: Optional[Union[Any, list[Any]]] = None):
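+        # Normalize a single item (or None) into a list for uniform batching.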
+ if input_item is None:
+ input_item = []
+ if not isinstance(input_item, list):
+ input_item = [input_item]
+ return input_item
+
+ def __call__(
+ self,
+ text: Optional[Union[str, list[str]]] = None,
+ images: Optional[Union[Image.Image, list[Image.Image]]] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ max_num_tiles: Optional[int] = None,
+ ) -> Mapping[str, NestedTensors]:
+ # Use default if not provided
+ if max_num_tiles is None:
+ max_num_tiles = 12
+
+ text, images = [self._make_batch_input(x) for x in (text, images)]
+
+ text, image_inputs = self._preprocess_image(
+ text=text,
+ images=images,
+ max_num_tiles=max_num_tiles,
+ )
+
+ text_inputs = self.tokenizer(text, add_special_tokens=False)
+
+ return {
+ **BatchEncoding(text_inputs, tensor_type=return_tensors),
+ **image_inputs,
+ }
+
+
+class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
+ """
+ HF Processor with extended video processing logic.
+ Code for video processing is adapted from video example:
+ https://huggingface.co/OpenGVLab/InternVL3-1B#inference-with-transformers
+ """
+
+ def __init__(
+ self,
+ config: PretrainedConfig,
+ tokenizer: AnyTokenizer,
+ *,
+ min_dynamic_patch: Optional[int] = None,
+ max_dynamic_patch: Optional[int] = None,
+ dynamic_image_size: Optional[bool] = None,
+ video_token: Optional[str] = None,
+ ) -> None:
+ super().__init__(
+ config=config,
+ tokenizer=tokenizer,
+ min_dynamic_patch=min_dynamic_patch,
+ max_dynamic_patch=max_dynamic_patch,
+ dynamic_image_size=dynamic_image_size,
+ )
+ # add extra video token for video processing
+ self.video_token = video_token
+
+ @property
+ def supports_video(self) -> bool:
+ return self.video_token_id is not None
+
+ @property
+ def video_token_id(self) -> Optional[int]:
+ if self.video_token is None:
+ return None
+ return self.tokenizer.get_vocab().get(self.video_token, None)
+
+ @property
+ def image_token_id(self) -> int:
+ return self.tokenizer.convert_tokens_to_ids(IMG_CONTEXT)
+
+ def _videos_to_pixel_values_lst(
+ self,
+ videos: list[npt.NDArray],
+ max_num_tiles: int,
+ dynamic_image_size: Optional[bool] = None,
+ ) -> list[torch.Tensor]:
+
+ return [
+ video_to_pixel_values(
+ video,
+ input_size=self.image_size,
+ max_num_tiles=max_num_tiles,
+ use_thumbnail=self.use_thumbnail,
+ ) for video in videos
+ ]
+
+ def _preprocess_video(
+ self,
+ text: list[str],
+ videos: list[npt.NDArray],
+ max_num_tiles: int,
+ dynamic_image_size: Optional[bool] = None,
+ ):
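+        # Video counterpart of _preprocess_image: normalize the frame tensors
+        # and expand each `<video>` placeholder in the prompt.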
+ if len(videos) == 0 or not self.supports_video:
+ video_inputs = {}
+ else:
+ pixel_values_lst_video = self._videos_to_pixel_values_lst(
+ videos,
+ max_num_tiles=max_num_tiles,
+ dynamic_image_size=dynamic_image_size,
+ )
+
+ video_inputs: dict[str, NestedTensors] = {
+ "pixel_values_flat_video":
+ input_conditioner(torch.cat(pixel_values_lst_video),
+ self.norm_mean, self.norm_std),
+ "video_num_patches":
+ torch.tensor([len(item) for item in pixel_values_lst_video]),
+ }
+
+ for pixel_values in pixel_values_lst_video:
+ num_patches = pixel_values.shape[0]
+
+ video_repl = self.get_video_repl(self.num_image_token,
+ num_patches, self.video_token)
+                text = [t.replace('<video>