diff --git a/docker/Dockerfile.cpu b/docker/Dockerfile.cpu
index 5da2c9467bfc6..982c1ddf27438 100644
--- a/docker/Dockerfile.cpu
+++ b/docker/Dockerfile.cpu
@@ -95,7 +95,7 @@ WORKDIR /workspace/vllm
RUN --mount=type=bind,src=requirements/test.in,target=requirements/test.in \
cp requirements/test.in requirements/cpu-test.in && \
sed -i '/mamba_ssm/d' requirements/cpu-test.in && \
- sed -i 's/torch==.*/torch==2.6.0/g' requirements/cpu-test.in && \
+ sed -i 's/^torch==.*/torch==2.6.0/g' requirements/cpu-test.in && \
sed -i 's/torchaudio.*/torchaudio/g' requirements/cpu-test.in && \
sed -i 's/torchvision.*/torchvision/g' requirements/cpu-test.in && \
uv pip compile requirements/cpu-test.in -o requirements/cpu-test.txt --index-strategy unsafe-best-match --torch-backend cpu
diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md
index 42afaeac0e8ee..55c6e3d7f5553 100644
--- a/docs/models/supported_models.md
+++ b/docs/models/supported_models.md
@@ -581,6 +581,7 @@ Specified using `--task generate`.
| `KeyeForConditionalGeneration` | Keye-VL-8B-Preview | T + IE+ + VE+ | `Kwai-Keye/Keye-VL-8B-Preview` | | | ✅︎ |
| `KimiVLForConditionalGeneration` | Kimi-VL-A3B-Instruct, Kimi-VL-A3B-Thinking | T + I+ | `moonshotai/Kimi-VL-A3B-Instruct`, `moonshotai/Kimi-VL-A3B-Thinking` | | | ✅︎ |
| `Llama4ForConditionalGeneration` | Llama 4 | T + I+ | `meta-llama/Llama-4-Scout-17B-16E-Instruct`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct`, etc. | | ✅︎ | ✅︎ |
+| `Llama_Nemotron_Nano_VL` | Llama Nemotron Nano VL | T + IE+ | `nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1` | ✅︎ | ✅︎ | ✅︎ |
| `LlavaForConditionalGeneration` | LLaVA-1.5, Pixtral (HF Transformers) | T + IE+ | `llava-hf/llava-1.5-7b-hf`, `TIGER-Lab/Mantis-8B-siglip-llama3` (see note), `mistral-community/pixtral-12b`, etc. | | ✅︎ | ✅︎ |
| `LlavaNextForConditionalGeneration` | LLaVA-NeXT | T + IE+ | `llava-hf/llava-v1.6-mistral-7b-hf`, `llava-hf/llava-v1.6-vicuna-7b-hf`, etc. | | ✅︎ | ✅︎ |
| `LlavaNextVideoForConditionalGeneration` | LLaVA-NeXT-Video | T + V | `llava-hf/LLaVA-NeXT-Video-7B-hf`, etc. | | ✅︎ | ✅︎ |
diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py
index 5bd75a78f2c4b..e4811c023377f 100644
--- a/examples/offline_inference/vision_language.py
+++ b/examples/offline_inference/vision_language.py
@@ -429,6 +429,44 @@ def run_internvl(questions: list[str], modality: str) -> ModelRequestData:
)
+# Nemontron_VL
+def run_nemotron_vl(questions: list[str], modality: str) -> ModelRequestData:
+ model_name = "nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1"
+
+ engine_args = EngineArgs(
+ model=model_name,
+ trust_remote_code=True,
+ max_model_len=8192,
+ limit_mm_per_prompt={modality: 1},
+ )
+
+ assert modality == "image"
+ placeholder = ""
+
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+ messages = [
+ [{"role": "user", "content": f"{placeholder}\n{question}"}]
+ for question in questions
+ ]
+ prompts = tokenizer.apply_chat_template(
+ messages, tokenize=False, add_generation_prompt=True
+ )
+
+ # Stop tokens for InternVL
+ # models variants may have different stop tokens
+ # please refer to the model card for the correct "stop words":
+ # https://huggingface.co/OpenGVLab/InternVL2-2B/blob/main/conversation.py
+ stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>"]
+ stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
+ stop_token_ids = [token_id for token_id in stop_token_ids if token_id is not None]
+
+ return ModelRequestData(
+ engine_args=engine_args,
+ prompts=prompts,
+ stop_token_ids=stop_token_ids,
+ )
+
+
# Keye-VL
def run_keye_vl(questions: list[str], modality: str) -> ModelRequestData:
model_name = "Kwai-Keye/Keye-VL-8B-Preview"
@@ -1186,6 +1224,7 @@ model_example_map = {
"h2ovl_chat": run_h2ovl,
"idefics3": run_idefics3,
"internvl_chat": run_internvl,
+ "nemotron_vl": run_nemotron_vl,
"keye_vl": run_keye_vl,
"kimi_vl": run_kimi_vl,
"llava": run_llava,
diff --git a/requirements/test.in b/requirements/test.in
index e8715afaf4f61..c6c68891d6a6a 100644
--- a/requirements/test.in
+++ b/requirements/test.in
@@ -30,6 +30,7 @@ mamba_ssm # required for plamo2 test
matplotlib # required for qwen-vl test
mistral_common[opencv] >= 1.8.0 # required for voxtral test
num2words # required for smolvlm test
+open_clip_torch==2.32.0 # Required for nemotron_vl test
opencv-python-headless >= 4.11.0 # required for video test
datamodel_code_generator # required for minicpm3 test
lm-eval[api]==0.4.8 # required for model evaluation test
diff --git a/requirements/test.txt b/requirements/test.txt
index 90d8f8ff0bc8b..aadbab03f6fc8 100644
--- a/requirements/test.txt
+++ b/requirements/test.txt
@@ -174,6 +174,8 @@ fsspec==2024.9.0
# fastparquet
# huggingface-hub
# torch
+ftfy==6.3.1
+ # via open-clip-torch
genai-perf==0.0.8
# via -r requirements/test.in
genson==1.3.0
@@ -208,6 +210,7 @@ huggingface-hub==0.33.0
# accelerate
# datasets
# evaluate
+ # open-clip-torch
# peft
# sentence-transformers
# timm
@@ -414,6 +417,8 @@ nvidia-nvjitlink-cu12==12.8.61
# torch
nvidia-nvtx-cu12==12.8.55
# via torch
+open-clip-torch==2.32.0
+ # via -r requirements/test.in
opencensus==0.11.4
# via ray
opencensus-context==0.1.3
@@ -615,6 +620,7 @@ referencing==0.35.1
regex==2024.9.11
# via
# nltk
+ # open-clip-torch
# sacrebleu
# tiktoken
# transformers
@@ -665,6 +671,7 @@ sacrebleu==2.4.3
safetensors==0.4.5
# via
# accelerate
+ # open-clip-torch
# peft
# timm
# transformers
@@ -753,7 +760,9 @@ tiktoken==0.7.0
# lm-eval
# mistral-common
timm==1.0.11
- # via -r requirements/test.in
+ # via
+ # -r requirements/test.in
+ # open-clip-torch
tokenizers==0.21.1
# via
# -r requirements/test.in
@@ -772,6 +781,7 @@ torch==2.7.1+cu128
# lm-eval
# mamba-ssm
# mteb
+ # open-clip-torch
# peft
# runai-model-streamer
# sentence-transformers
@@ -789,6 +799,7 @@ torchaudio==2.7.1+cu128
torchvision==0.22.1+cu128
# via
# -r requirements/test.in
+ # open-clip-torch
# timm
tqdm==4.66.6
# via
@@ -798,6 +809,7 @@ tqdm==4.66.6
# lm-eval
# mteb
# nltk
+ # open-clip-torch
# peft
# pqdm
# sentence-transformers
@@ -863,6 +875,8 @@ virtualenv==20.31.2
# via ray
vocos==0.1.0
# via -r requirements/test.in
+wcwidth==0.2.13
+ # via ftfy
webcolors==24.11.1
# via jsonschema
werkzeug==3.1.3
diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py
index ab21941fae973..fd5842523178f 100644
--- a/tests/models/multimodal/processing/test_common.py
+++ b/tests/models/multimodal/processing/test_common.py
@@ -291,6 +291,7 @@ def _test_processing_correctness_one(
"allenai/Molmo-7B-D-0924",
"allenai/Molmo-7B-O-0924",
"nvidia/NVLM-D-72B",
+ "nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1",
"AIDC-AI/Ovis1.6-Gemma2-9B",
"AIDC-AI/Ovis1.6-Llama3.2-3B",
"AIDC-AI/Ovis2-1B",
diff --git a/tests/models/multimodal/processing/test_nemotron_vl.py b/tests/models/multimodal/processing/test_nemotron_vl.py
new file mode 100644
index 0000000000000..3ce88bc427f5a
--- /dev/null
+++ b/tests/models/multimodal/processing/test_nemotron_vl.py
@@ -0,0 +1,134 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Tests for Nemotron-Nano-VL's multimodal preprocessing kwargs."""
+from collections.abc import Mapping
+from typing import Optional
+
+import pytest
+from PIL import Image
+from transformers import PretrainedConfig
+
+from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.image import rescale_image_size
+from vllm.multimodal.processing import BaseMultiModalProcessor
+
+from ....conftest import ImageTestAssets
+from ...utils import build_model_context
+
+
+def _get_expected_num_patches(
+ config: PretrainedConfig,
+ image: Image.Image,
+ num_imgs: int,
+ min_num: int,
+ max_num: int,
+):
+ from vllm.model_executor.models.internvl import (
+ calculate_internvl_targets, get_internvl_target_ratios)
+
+ width, height = image.size
+
+ blocks, _, _ = calculate_internvl_targets(
+ orig_width=width,
+ orig_height=height,
+ target_ratios=get_internvl_target_ratios(
+ min_num,
+ max_num,
+ ),
+ image_size=config.force_image_size,
+ use_thumbnail=False,
+ )
+ expected_num_patches = blocks
+
+ if config.use_thumbnail and expected_num_patches > 1:
+ expected_num_patches += 1
+
+ return expected_num_patches
+
+
+def _run_check(
+ processor: BaseMultiModalProcessor,
+ images: list[Image.Image],
+ min_num: int,
+ max_num: int,
+ mm_processor_kwargs: Mapping[str, object],
+):
+ tokenizer = processor.info.get_tokenizer()
+ config = processor.info.get_hf_config()
+ image_processor = processor.info.get_image_processor()
+
+ config.use_thumbnail = image_processor.use_thumbnail
+ prompt = "" * len(images)
+ mm_data = {"image": images}
+
+ total_expected_num_patches = sum(
+ _get_expected_num_patches(config, image, len(images), min_num, max_num)
+ for image in images)
+ print(total_expected_num_patches)
+ processed_inputs = processor.apply(prompt, mm_data, mm_processor_kwargs)
+
+ # Ensure we have the right number of placeholders per num_crops size
+ image_token_id = tokenizer.convert_tokens_to_ids("")
+ img_tok_count = processed_inputs["prompt_token_ids"].count(image_token_id)
+ pixel_shape = processed_inputs["mm_kwargs"]["pixel_values_flat"].shape
+ print("Image token count:", img_tok_count, "Pixel shape:", pixel_shape)
+ assert img_tok_count == 256 * total_expected_num_patches
+ assert pixel_shape[0] == total_expected_num_patches
+
+
+@pytest.mark.parametrize("model_id",
+ ["nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1"])
+@pytest.mark.parametrize(
+ "size_factors",
+ [
+ # Single-scale
+ [1.0],
+ # Single-scale, batched
+ [1.0, 1.0, 1.0],
+ # Multi-scale
+ [0.25, 0.5, 1.0],
+ [4.0, 2.0, 1.0],
+ ],
+)
+@pytest.mark.parametrize(
+ ("min_dynamic_patch", "max_dynamic_patch"),
+ [(1, 1), (1, 2), (1, 4), (1, 8), (2, 4), (4, 8)],
+)
+@pytest.mark.parametrize("dynamic_image_size", [True, False])
+@pytest.mark.parametrize("kwargs_on_init", [True, False])
+def test_processor_override(
+ model_id: str,
+ image_assets: ImageTestAssets,
+ size_factors: list[int],
+ min_dynamic_patch: int,
+ max_dynamic_patch: int,
+ dynamic_image_size: Optional[bool],
+ kwargs_on_init: bool,
+):
+ mm_processor_kwargs = {
+ "min_dynamic_patch": min_dynamic_patch,
+ "max_dynamic_patch": max_dynamic_patch,
+ "dynamic_image_size": dynamic_image_size,
+ }
+
+ ctx = build_model_context(
+ model_id,
+ mm_processor_kwargs=mm_processor_kwargs if kwargs_on_init else None,
+ limit_mm_per_prompt={"image": len(size_factors)},
+ )
+ processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config)
+ hf_processor_mm_kwargs = {} if kwargs_on_init else mm_processor_kwargs
+
+ min_num = min_dynamic_patch if dynamic_image_size else 1
+ max_num = max_dynamic_patch if dynamic_image_size else 1
+
+ _run_check(
+ processor,
+ [
+ rescale_image_size(image_assets[0].pil_image, f)
+ for f in size_factors
+ ],
+ min_num,
+ max_num,
+ hf_processor_mm_kwargs,
+ )
diff --git a/tests/models/registry.py b/tests/models/registry.py
index d2e70e291df3e..2adfa859a1c79 100644
--- a/tests/models/registry.py
+++ b/tests/models/registry.py
@@ -401,6 +401,8 @@ _MULTIMODAL_EXAMPLE_MODELS = {
trust_remote_code=True),
"NVLM_D": _HfExamplesInfo("nvidia/NVLM-D-72B",
trust_remote_code=True),
+ "Llama_Nemotron_Nano_VL" : _HfExamplesInfo("nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1", # noqa: E501
+ trust_remote_code=True),
"PaliGemmaForConditionalGeneration": _HfExamplesInfo("google/paligemma-3b-mix-224", # noqa: E501
extras={"v2": "google/paligemma2-3b-ft-docci-448"}), # noqa: E501
"Phi3VForCausalLM": _HfExamplesInfo("microsoft/Phi-3-vision-128k-instruct",
diff --git a/vllm/model_executor/models/nemotron_vl.py b/vllm/model_executor/models/nemotron_vl.py
new file mode 100644
index 0000000000000..5d0513d707413
--- /dev/null
+++ b/vllm/model_executor/models/nemotron_vl.py
@@ -0,0 +1,505 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+# adapted from https://huggingface.co/OpenGVLab/InternVL2-4B/blob/main/modeling_internvl_chat.py
+# --------------------------------------------------------
+# InternVL
+# Copyright (c) 2023 OpenGVLab
+# Licensed under The MIT License [see LICENSE for details]
+# --------------------------------------------------------
+from abc import ABC
+from collections.abc import Iterable
+from typing import Optional
+
+import torch
+import torch.nn as nn
+from PIL import Image
+from transformers import AutoModel, PretrainedConfig
+from transformers.image_processing_utils_fast import BaseImageProcessorFast
+
+from vllm.config import VllmConfig
+from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.quantization.awq import AWQConfig
+from vllm.model_executor.models.internvl import (
+ BaseInternVLDummyInputsBuilder, BaseInternVLMultiModalProcessor,
+ BaseInternVLProcessingInfo, InternVLImageEmbeddingInputs,
+ InternVLImageInputs, InternVLImagePixelInputs, InternVLProcessor)
+from vllm.model_executor.models.module_mapping import MultiModelKeys
+from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.inputs import NestedTensors
+from vllm.multimodal.processing import PromptUpdateDetails
+from vllm.sequence import IntermediateTensors
+from vllm.transformers_utils.processor import (
+ cached_image_processor_from_config)
+from vllm.transformers_utils.tokenizer import AnyTokenizer
+
+from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
+ SupportsMultiModal, SupportsPP)
+from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
+ maybe_prefix, merge_multimodal_embeddings)
+
+IMG_START = '
'
+IMG_END = ''
+IMG_CONTEXT = ''
+
+
+class NemotronVLProcessor(InternVLProcessor):
+
+ def __init__(
+ self,
+ config: PretrainedConfig,
+ tokenizer: AnyTokenizer,
+ image_processor: BaseImageProcessorFast,
+ *,
+ min_dynamic_patch: Optional[int] = None,
+ max_dynamic_patch: Optional[int] = None,
+ dynamic_image_size: Optional[bool] = None,
+ ) -> None:
+ ABC.__init__(self)
+ self.config = config
+ self.tokenizer = tokenizer
+ self.image_processor = image_processor
+ image_size: int = config.force_image_size
+ patch_size: int = config.patch_size
+
+ if min_dynamic_patch is None:
+ min_dynamic_patch = 1
+ assert isinstance(min_dynamic_patch, int)
+
+ if max_dynamic_patch is None:
+ max_dynamic_patch = self.image_processor.max_num_tiles
+ assert isinstance(max_dynamic_patch, int)
+
+ if dynamic_image_size is None:
+ dynamic_image_size = True
+ assert isinstance(dynamic_image_size, bool)
+
+ self.num_image_token = int(
+ (image_size // patch_size)**2 * (config.downsample_ratio**2))
+ self.image_size = image_size
+ self.min_dynamic_patch = min_dynamic_patch
+ self.max_dynamic_patch = max_dynamic_patch
+ self.dynamic_image_size = dynamic_image_size
+ self.use_thumbnail: bool = self.image_processor.use_thumbnail
+
+ @property
+ def image_token_id(self) -> int:
+ return self.tokenizer.get_vocab()[IMG_CONTEXT]
+
+ def _preprocess_image(
+ self,
+ text: list[str],
+ images: list[Image.Image],
+ min_dynamic_patch: Optional[int] = None,
+ max_dynamic_patch: Optional[int] = None,
+ dynamic_image_size: Optional[bool] = None,
+ ) -> tuple[list[str], dict[str, torch.Tensor]]:
+ if len(images) == 0:
+ image_inputs = {}
+ else:
+ pixel_values_lst = self._images_to_pixel_values_lst(
+ images,
+ min_dynamic_patch=min_dynamic_patch,
+ max_dynamic_patch=max_dynamic_patch,
+ dynamic_image_size=dynamic_image_size,
+ )
+ image_inputs: dict[str, NestedTensors] = {
+ "pixel_values_flat":
+ torch.cat(pixel_values_lst),
+ "image_num_patches":
+ torch.tensor([len(item) for item in pixel_values_lst]),
+ }
+
+ for pixel_values in pixel_values_lst:
+ num_patches = pixel_values.shape[0]
+ feature_size = num_patches * self.num_image_token
+ image_repl = self.get_image_repl(feature_size, num_patches)
+ NVL_IMAGE_CONTEXT = image_repl.full.replace(
+ "", "")
+ text = [
+ t.replace('', NVL_IMAGE_CONTEXT, 1) for t in text
+ ]
+ text = [t.replace("", IMG_CONTEXT) for t in text]
+ return text, image_inputs
+
+ def get_image_repl(
+ self,
+ feature_size: int,
+ num_patches: Optional[int],
+ ) -> PromptUpdateDetails[str]:
+ repl_features = IMG_CONTEXT * feature_size
+ repl_full = IMG_START + repl_features + IMG_END
+
+ return PromptUpdateDetails.select_text(repl_full, IMG_CONTEXT)
+
+
+class NemotronVLProcessingInfo(BaseInternVLProcessingInfo):
+ """Processing info for Nemotron VL models."""
+
+ def get_hf_processor(
+ self,
+ *,
+ min_dynamic_patch: Optional[int] = None,
+ max_dynamic_patch: Optional[int] = None,
+ dynamic_image_size: Optional[bool] = None,
+ **kwargs: object,
+ ) -> NemotronVLProcessor:
+ if min_dynamic_patch is not None:
+ kwargs["min_dynamic_patch"] = min_dynamic_patch
+ if max_dynamic_patch is not None:
+ kwargs["max_dynamic_patch"] = max_dynamic_patch
+ if dynamic_image_size is not None:
+ kwargs["dynamic_image_size"] = dynamic_image_size
+
+ image_processor = self.get_image_processor()
+ return self.ctx.init_processor(
+ NemotronVLProcessor,
+ config=self.get_hf_config(),
+ tokenizer=self.get_tokenizer(),
+ image_processor=image_processor,
+ **kwargs,
+ )
+
+ def get_image_processor(
+ self,
+ **kwargs: object,
+ ):
+ return cached_image_processor_from_config(
+ self.ctx.model_config,
+ **kwargs,
+ )
+
+
+@MULTIMODAL_REGISTRY.register_processor(
+ BaseInternVLMultiModalProcessor[NemotronVLProcessingInfo],
+ info=NemotronVLProcessingInfo,
+ dummy_inputs=BaseInternVLDummyInputsBuilder[NemotronVLProcessingInfo])
+class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
+ SupportsLoRA):
+
+ @classmethod
+ def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
+ if modality.startswith("image"):
+ return ""
+
+ raise ValueError("Only image modality is supported")
+
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
+ super().__init__()
+
+ config = vllm_config.model_config.hf_config
+ quant_config = vllm_config.quant_config
+ multimodal_config = vllm_config.model_config.multimodal_config
+
+ self.config = config
+ self.multimodal_config = multimodal_config
+ self._patch_quant_config(config, quant_config)
+
+ image_size = config.force_image_size or config.vision_config.image_size
+ patch_size = config.vision_config.patch_size
+ self.patch_size = patch_size
+ self.num_image_token = int(
+ (image_size // patch_size)**2 * (config.downsample_ratio**2))
+ self.downsample_ratio = config.downsample_ratio
+ self.ps_version = config.ps_version
+
+ self.llm_arch_name = config.text_config.architectures[0]
+ self.vision_model = self._init_vision_model(
+ config,
+ quant_config=quant_config,
+ prefix=maybe_prefix(prefix, "vision_model"),
+ )
+
+ self.language_model = init_vllm_registered_model(
+ vllm_config=vllm_config,
+ hf_config=config.text_config,
+ prefix=maybe_prefix(prefix, "language_model"),
+ )
+
+ self.mlp1 = self._init_mlp1(config)
+
+ self.img_context_token_id = None
+
+ self.visual_token_mask = None
+ self.make_empty_intermediate_tensors = (
+ self.language_model.make_empty_intermediate_tensors)
+
+ def _patch_quant_config(self, config: PretrainedConfig,
+ quant_config: QuantizationConfig):
+ # the awq models from OpenGVLab missing `modules_to_not_convert`
+ # patch the quant_config to add `modules_to_not_convert` back
+ if isinstance(quant_config, AWQConfig):
+ text_config = config.text_config
+ llm_quant_config = getattr(text_config, "quantization_config",
+ None)
+ if (not quant_config.modules_to_not_convert) and \
+ (llm_quant_config is not None):
+ quant_config.modules_to_not_convert.append("vision_model")
+
+ def _init_vision_model(
+ self,
+ config: PretrainedConfig,
+ quant_config: Optional[QuantizationConfig],
+ *,
+ prefix: str,
+ ):
+ return AutoModel.from_config(config.vision_config,
+ trust_remote_code=True)
+
+ def _init_mlp1(self, config: PretrainedConfig) -> nn.Sequential:
+ vit_hidden_size = config.vit_hidden_size
+ vision_projection_hidden_size = config.projector_hidden_size
+ llm_hidden_size = config.text_config.hidden_size
+
+ return nn.Sequential(
+ nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio)**2,
+ bias=True),
+ nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio)**2,
+ vision_projection_hidden_size,
+ bias=True),
+ nn.GELU(),
+ nn.Linear(vision_projection_hidden_size, llm_hidden_size),
+ )
+
+ def pixel_shuffle(self, x, scale_factor=0.5):
+ n, w, h, c = x.size()
+ # N, W, H, C --> N, W, H * scale, C // scale
+ x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
+ # N, W, H * scale, C // scale --> N, H * scale, W, C // scale
+ x = x.permute(0, 2, 1, 3).contiguous()
+ x = x.view(n, int(h * scale_factor), int(w * scale_factor),
+ int(c / (scale_factor * scale_factor)))
+ if self.ps_version == 'v1':
+ pass
+ else:
+ x = x.permute(0, 2, 1, 3).contiguous()
+ return x
+
+ def extract_feature(self, pixel_values: torch.Tensor) -> torch.Tensor:
+ # https://huggingface.co/nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1/blob/main/modeling.py#L177
+ vit_embeds = self.vision_model(x=pixel_values).features
+ vit_embeds = vit_embeds.to(dtype=torch.bfloat16)
+
+ h = w = int(vit_embeds.shape[1]**0.5)
+ vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
+ vit_embeds = self.pixel_shuffle(vit_embeds,
+ scale_factor=self.downsample_ratio)
+ vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1,
+ vit_embeds.shape[-1])
+ vit_embeds = self.mlp1(vit_embeds)
+ return vit_embeds
+
+ def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
+
+ #use force_image_size to get image_size
+ h = w = self.config.force_image_size
+ expected_dims = (3, h, w)
+
+ def _validate_shape(d: torch.Tensor):
+ actual_dims = tuple(d.shape)
+
+ if actual_dims != expected_dims:
+ expected_expr = str(expected_dims)
+ raise ValueError(
+ "The expected shape of pixel values per image per batch "
+ f" per patch is {expected_expr}. "
+ f"You supplied {tuple(d.shape)}.")
+
+ for d in data:
+ _validate_shape(d)
+
+ return data
+
+ def _parse_and_validate_image_input(
+ self, **kwargs: object) -> Optional[InternVLImageInputs]:
+ pixel_values_flat = kwargs.pop("pixel_values_flat", None)
+ image_num_patches = kwargs.pop("image_num_patches", None)
+ image_embeds = kwargs.pop("image_embeds", None)
+
+ if pixel_values_flat is None and image_embeds is None:
+ return None
+
+ if image_embeds is not None:
+ if not isinstance(image_embeds, (torch.Tensor, list)):
+ raise ValueError("Incorrect type of image embeddings. "
+ f"Got type: {type(image_embeds)}")
+
+ return InternVLImageEmbeddingInputs(
+ type="image_embeds",
+ data=flatten_bn(image_embeds),
+ )
+
+ image_token_id = kwargs["image_token_id"]
+ assert isinstance(image_token_id, torch.Tensor)
+ self.img_context_token_id = image_token_id.flatten().unique().item()
+
+ if pixel_values_flat is not None:
+ if not isinstance(pixel_values_flat, (torch.Tensor, list)):
+ raise ValueError("Incorrect type of pixel values. "
+ f"Got type: {type(pixel_values_flat)}")
+
+ if not isinstance(image_num_patches, (torch.Tensor, list)):
+ raise ValueError("Incorrect type of image_num_patches. "
+ f"Got type: {type(image_num_patches)}")
+
+ pixel_values_flat = flatten_bn(pixel_values_flat, concat=True)
+ image_num_patches = flatten_bn(image_num_patches, concat=True)
+
+ return InternVLImagePixelInputs(
+ type="pixel_values",
+ pixel_values_flat=self._validate_pixel_values(
+ pixel_values_flat),
+ num_patches=image_num_patches,
+ )
+
+ raise AssertionError("This line should be unreachable.")
+
+ def _process_image_input(
+ self,
+ image_input: InternVLImageInputs,
+ ) -> tuple[torch.Tensor, ...]:
+ if image_input["type"] == "image_embeds":
+ return image_input["data"]
+
+ assert self.vision_model is not None
+
+ image_embeds = self.extract_feature(image_input["pixel_values_flat"])
+
+ num_patches = image_input["num_patches"]
+
+ # Only one image in the current batch
+ if len(num_patches) == 1:
+ return (image_embeds.view(-1,
+ self.config.text_config.hidden_size), )
+
+ # NOTE: Image embeddings are split into separate tensors for each image
+ # by the size of each embedding.
+ feature_size = image_embeds.shape[1]
+ image_embeds = image_embeds.view(-1,
+ self.config.text_config.hidden_size)
+ image_feature_sizes = [
+ num_patches * feature_size for num_patches in num_patches
+ ]
+ return image_embeds.split(image_feature_sizes)
+
+ def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
+ modalities = {}
+
+ # Preserve the order of modalities if there are multiple of them
+ # from the order of kwargs.
+ for input_key in kwargs:
+ if input_key in ("pixel_values_flat",
+ "image_embeds") and "images" not in modalities:
+ modalities["images"] = self._parse_and_validate_image_input(
+ **kwargs)
+
+ return modalities
+
+ def _set_visual_token_mask(self, input_ids: torch.Tensor) -> None:
+ self.visual_token_mask = None
+
+ def get_language_model(self) -> torch.nn.Module:
+ return self.language_model
+
+ def get_multimodal_embeddings(self,
+ **kwargs: object) -> MultiModalEmbeddings:
+
+ modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
+ if not modalities:
+ return []
+
+ # The result multimodal_embeddings is tuple of tensors, with each
+ # tensor correspoending to a multimodal data item (image).
+ multimodal_embeddings: tuple[torch.Tensor, ...] = ()
+
+ # NOTE: It is important to iterate over the keys in this dictionary
+ # to preserve the order of the modalities.
+ for modality in modalities:
+ if modality == "images":
+ image_input = modalities["images"]
+ vision_embeddings = self._process_image_input(image_input)
+ multimodal_embeddings += vision_embeddings
+
+ return multimodal_embeddings
+
+ def get_input_embeddings(
+ self,
+ input_ids: torch.Tensor,
+ multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
+ ) -> torch.Tensor:
+ inputs_embeds = self.language_model.get_input_embeddings(input_ids)
+ if multimodal_embeddings is not None \
+ and len(multimodal_embeddings) != 0:
+ context_token_ids = [self.img_context_token_id]
+ assert len(context_token_ids) >= 1
+ self._set_visual_token_mask(input_ids)
+ inputs_embeds = merge_multimodal_embeddings(
+ input_ids,
+ inputs_embeds,
+ multimodal_embeddings,
+ context_token_ids,
+ )
+ return inputs_embeds
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ intermediate_tensors: Optional[IntermediateTensors] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ **kwargs: object,
+ ) -> IntermediateTensors:
+
+ if intermediate_tensors is not None:
+ input_ids = None
+ inputs_embeds = None
+
+ # NOTE: In v1, inputs_embeds is always generated at model runner, this
+ # condition is for v0 compatibility.
+ elif inputs_embeds is None:
+ vision_embeddings = self.get_multimodal_embeddings(**kwargs)
+ inputs_embeds = self.get_input_embeddings(input_ids,
+ vision_embeddings)
+ input_ids = None
+
+ forward_kwargs = {
+ "input_ids": input_ids,
+ "positions": positions,
+ "intermediate_tensors": intermediate_tensors,
+ "inputs_embeds": inputs_embeds,
+ }
+
+ # Only required if the model is mono-architecture
+ if self.visual_token_mask is not None:
+ forward_kwargs.update(
+ {"visual_token_mask": self.visual_token_mask})
+ self.visual_token_mask = None
+
+ hidden_states = self.language_model.model(**forward_kwargs)
+ return hidden_states
+
+ def compute_logits(
+ self,
+ hidden_states: torch.Tensor,
+ sampling_metadata: SamplingMetadata,
+ ) -> Optional[torch.Tensor]:
+ return self.language_model.compute_logits(hidden_states,
+ sampling_metadata)
+
+ def load_weights(self, weights: Iterable[tuple[str,
+ torch.Tensor]]) -> set[str]:
+ ## Ignore registered_buffers
+ ## see https://huggingface.co/nvidia/C-RADIOv2-H/blob/main/input_conditioner.py#L28 # noqa: E501
+ skip_substrs = ["norm_mean", "norm_std"]
+ loader = AutoWeightsLoader(self, skip_substrs=skip_substrs)
+ return loader.load_weights(weights)
+
+ def get_mm_mapping(self) -> MultiModelKeys:
+ """
+ Get the module prefix in multimodal models
+ """
+ return MultiModelKeys.from_string_field(
+ language_model="language_model",
+ connector="mlp1",
+ tower_model="vision_model")
diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
index bc936500bdc89..52fdb91089167 100644
--- a/vllm/model_executor/models/registry.py
+++ b/vllm/model_executor/models/registry.py
@@ -206,6 +206,7 @@ _MULTIMODAL_MODELS = {
"SmolVLMForConditionalGeneration": ("smolvlm","SmolVLMForConditionalGeneration"), # noqa: E501
"KeyeForConditionalGeneration": ("keye", "KeyeForConditionalGeneration"),
"KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"), # noqa: E501
+ "Llama_Nemotron_Nano_VL": ("nemotron_vl", "LlamaNemotronVLChatModel"),
"LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
"LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"), # noqa: E501
"LlavaNextVideoForConditionalGeneration": ("llava_next_video", "LlavaNextVideoForConditionalGeneration"), # noqa: E501
diff --git a/vllm/transformers_utils/configs/nemotron.py b/vllm/transformers_utils/configs/nemotron.py
index d65b572dc7f22..9a7243b1262c0 100644
--- a/vllm/transformers_utils/configs/nemotron.py
+++ b/vllm/transformers_utils/configs/nemotron.py
@@ -202,4 +202,4 @@ class NemotronConfig(PretrainedConfig):
rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
raise ValueError(
"`rope_scaling`'s factor field must be a float > 1, got "
- f"{rope_scaling_factor}")
+ f"{rope_scaling_factor}")
\ No newline at end of file