mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 20:15:01 +08:00
[VLM] Add Nemotron-Nano-VL-8B-V1 support (#20349)
Signed-off-by: Kyle Huang <kylhuang@nvidia.com> Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
This commit is contained in:
parent
5a7fb3ab9e
commit
4ef00b5cac
@ -95,7 +95,7 @@ WORKDIR /workspace/vllm
|
||||
RUN --mount=type=bind,src=requirements/test.in,target=requirements/test.in \
|
||||
cp requirements/test.in requirements/cpu-test.in && \
|
||||
sed -i '/mamba_ssm/d' requirements/cpu-test.in && \
|
||||
sed -i 's/torch==.*/torch==2.6.0/g' requirements/cpu-test.in && \
|
||||
sed -i 's/^torch==.*/torch==2.6.0/g' requirements/cpu-test.in && \
|
||||
sed -i 's/torchaudio.*/torchaudio/g' requirements/cpu-test.in && \
|
||||
sed -i 's/torchvision.*/torchvision/g' requirements/cpu-test.in && \
|
||||
uv pip compile requirements/cpu-test.in -o requirements/cpu-test.txt --index-strategy unsafe-best-match --torch-backend cpu
|
||||
|
||||
@ -581,6 +581,7 @@ Specified using `--task generate`.
|
||||
| `KeyeForConditionalGeneration` | Keye-VL-8B-Preview | T + I<sup>E+</sup> + V<sup>E+</sup> | `Kwai-Keye/Keye-VL-8B-Preview` | | | ✅︎ |
|
||||
| `KimiVLForConditionalGeneration` | Kimi-VL-A3B-Instruct, Kimi-VL-A3B-Thinking | T + I<sup>+</sup> | `moonshotai/Kimi-VL-A3B-Instruct`, `moonshotai/Kimi-VL-A3B-Thinking` | | | ✅︎ |
|
||||
| `Llama4ForConditionalGeneration` | Llama 4 | T + I<sup>+</sup> | `meta-llama/Llama-4-Scout-17B-16E-Instruct`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct`, etc. | | ✅︎ | ✅︎ |
|
||||
| `Llama_Nemotron_Nano_VL` | Llama Nemotron Nano VL | T + I<sup>E+</sup> | `nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1` | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `LlavaForConditionalGeneration` | LLaVA-1.5, Pixtral (HF Transformers) | T + I<sup>E+</sup> | `llava-hf/llava-1.5-7b-hf`, `TIGER-Lab/Mantis-8B-siglip-llama3` (see note), `mistral-community/pixtral-12b`, etc. | | ✅︎ | ✅︎ |
|
||||
| `LlavaNextForConditionalGeneration` | LLaVA-NeXT | T + I<sup>E+</sup> | `llava-hf/llava-v1.6-mistral-7b-hf`, `llava-hf/llava-v1.6-vicuna-7b-hf`, etc. | | ✅︎ | ✅︎ |
|
||||
| `LlavaNextVideoForConditionalGeneration` | LLaVA-NeXT-Video | T + V | `llava-hf/LLaVA-NeXT-Video-7B-hf`, etc. | | ✅︎ | ✅︎ |
|
||||
|
||||
@ -429,6 +429,44 @@ def run_internvl(questions: list[str], modality: str) -> ModelRequestData:
|
||||
)
|
||||
|
||||
|
||||
# Nemontron_VL
|
||||
def run_nemotron_vl(questions: list[str], modality: str) -> ModelRequestData:
|
||||
model_name = "nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1"
|
||||
|
||||
engine_args = EngineArgs(
|
||||
model=model_name,
|
||||
trust_remote_code=True,
|
||||
max_model_len=8192,
|
||||
limit_mm_per_prompt={modality: 1},
|
||||
)
|
||||
|
||||
assert modality == "image"
|
||||
placeholder = "<image>"
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
|
||||
messages = [
|
||||
[{"role": "user", "content": f"{placeholder}\n{question}"}]
|
||||
for question in questions
|
||||
]
|
||||
prompts = tokenizer.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
|
||||
# Stop tokens for InternVL
|
||||
# models variants may have different stop tokens
|
||||
# please refer to the model card for the correct "stop words":
|
||||
# https://huggingface.co/OpenGVLab/InternVL2-2B/blob/main/conversation.py
|
||||
stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>"]
|
||||
stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
|
||||
stop_token_ids = [token_id for token_id in stop_token_ids if token_id is not None]
|
||||
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
prompts=prompts,
|
||||
stop_token_ids=stop_token_ids,
|
||||
)
|
||||
|
||||
|
||||
# Keye-VL
|
||||
def run_keye_vl(questions: list[str], modality: str) -> ModelRequestData:
|
||||
model_name = "Kwai-Keye/Keye-VL-8B-Preview"
|
||||
@ -1186,6 +1224,7 @@ model_example_map = {
|
||||
"h2ovl_chat": run_h2ovl,
|
||||
"idefics3": run_idefics3,
|
||||
"internvl_chat": run_internvl,
|
||||
"nemotron_vl": run_nemotron_vl,
|
||||
"keye_vl": run_keye_vl,
|
||||
"kimi_vl": run_kimi_vl,
|
||||
"llava": run_llava,
|
||||
|
||||
@ -30,6 +30,7 @@ mamba_ssm # required for plamo2 test
|
||||
matplotlib # required for qwen-vl test
|
||||
mistral_common[opencv] >= 1.8.0 # required for voxtral test
|
||||
num2words # required for smolvlm test
|
||||
open_clip_torch==2.32.0 # Required for nemotron_vl test
|
||||
opencv-python-headless >= 4.11.0 # required for video test
|
||||
datamodel_code_generator # required for minicpm3 test
|
||||
lm-eval[api]==0.4.8 # required for model evaluation test
|
||||
|
||||
@ -174,6 +174,8 @@ fsspec==2024.9.0
|
||||
# fastparquet
|
||||
# huggingface-hub
|
||||
# torch
|
||||
ftfy==6.3.1
|
||||
# via open-clip-torch
|
||||
genai-perf==0.0.8
|
||||
# via -r requirements/test.in
|
||||
genson==1.3.0
|
||||
@ -208,6 +210,7 @@ huggingface-hub==0.33.0
|
||||
# accelerate
|
||||
# datasets
|
||||
# evaluate
|
||||
# open-clip-torch
|
||||
# peft
|
||||
# sentence-transformers
|
||||
# timm
|
||||
@ -414,6 +417,8 @@ nvidia-nvjitlink-cu12==12.8.61
|
||||
# torch
|
||||
nvidia-nvtx-cu12==12.8.55
|
||||
# via torch
|
||||
open-clip-torch==2.32.0
|
||||
# via -r requirements/test.in
|
||||
opencensus==0.11.4
|
||||
# via ray
|
||||
opencensus-context==0.1.3
|
||||
@ -615,6 +620,7 @@ referencing==0.35.1
|
||||
regex==2024.9.11
|
||||
# via
|
||||
# nltk
|
||||
# open-clip-torch
|
||||
# sacrebleu
|
||||
# tiktoken
|
||||
# transformers
|
||||
@ -665,6 +671,7 @@ sacrebleu==2.4.3
|
||||
safetensors==0.4.5
|
||||
# via
|
||||
# accelerate
|
||||
# open-clip-torch
|
||||
# peft
|
||||
# timm
|
||||
# transformers
|
||||
@ -753,7 +760,9 @@ tiktoken==0.7.0
|
||||
# lm-eval
|
||||
# mistral-common
|
||||
timm==1.0.11
|
||||
# via -r requirements/test.in
|
||||
# via
|
||||
# -r requirements/test.in
|
||||
# open-clip-torch
|
||||
tokenizers==0.21.1
|
||||
# via
|
||||
# -r requirements/test.in
|
||||
@ -772,6 +781,7 @@ torch==2.7.1+cu128
|
||||
# lm-eval
|
||||
# mamba-ssm
|
||||
# mteb
|
||||
# open-clip-torch
|
||||
# peft
|
||||
# runai-model-streamer
|
||||
# sentence-transformers
|
||||
@ -789,6 +799,7 @@ torchaudio==2.7.1+cu128
|
||||
torchvision==0.22.1+cu128
|
||||
# via
|
||||
# -r requirements/test.in
|
||||
# open-clip-torch
|
||||
# timm
|
||||
tqdm==4.66.6
|
||||
# via
|
||||
@ -798,6 +809,7 @@ tqdm==4.66.6
|
||||
# lm-eval
|
||||
# mteb
|
||||
# nltk
|
||||
# open-clip-torch
|
||||
# peft
|
||||
# pqdm
|
||||
# sentence-transformers
|
||||
@ -863,6 +875,8 @@ virtualenv==20.31.2
|
||||
# via ray
|
||||
vocos==0.1.0
|
||||
# via -r requirements/test.in
|
||||
wcwidth==0.2.13
|
||||
# via ftfy
|
||||
webcolors==24.11.1
|
||||
# via jsonschema
|
||||
werkzeug==3.1.3
|
||||
|
||||
@ -291,6 +291,7 @@ def _test_processing_correctness_one(
|
||||
"allenai/Molmo-7B-D-0924",
|
||||
"allenai/Molmo-7B-O-0924",
|
||||
"nvidia/NVLM-D-72B",
|
||||
"nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1",
|
||||
"AIDC-AI/Ovis1.6-Gemma2-9B",
|
||||
"AIDC-AI/Ovis1.6-Llama3.2-3B",
|
||||
"AIDC-AI/Ovis2-1B",
|
||||
|
||||
134
tests/models/multimodal/processing/test_nemotron_vl.py
Normal file
134
tests/models/multimodal/processing/test_nemotron_vl.py
Normal file
@ -0,0 +1,134 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Tests for Nemotron-Nano-VL's multimodal preprocessing kwargs."""
|
||||
from collections.abc import Mapping
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
from PIL import Image
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.image import rescale_image_size
|
||||
from vllm.multimodal.processing import BaseMultiModalProcessor
|
||||
|
||||
from ....conftest import ImageTestAssets
|
||||
from ...utils import build_model_context
|
||||
|
||||
|
||||
def _get_expected_num_patches(
|
||||
config: PretrainedConfig,
|
||||
image: Image.Image,
|
||||
num_imgs: int,
|
||||
min_num: int,
|
||||
max_num: int,
|
||||
):
|
||||
from vllm.model_executor.models.internvl import (
|
||||
calculate_internvl_targets, get_internvl_target_ratios)
|
||||
|
||||
width, height = image.size
|
||||
|
||||
blocks, _, _ = calculate_internvl_targets(
|
||||
orig_width=width,
|
||||
orig_height=height,
|
||||
target_ratios=get_internvl_target_ratios(
|
||||
min_num,
|
||||
max_num,
|
||||
),
|
||||
image_size=config.force_image_size,
|
||||
use_thumbnail=False,
|
||||
)
|
||||
expected_num_patches = blocks
|
||||
|
||||
if config.use_thumbnail and expected_num_patches > 1:
|
||||
expected_num_patches += 1
|
||||
|
||||
return expected_num_patches
|
||||
|
||||
|
||||
def _run_check(
|
||||
processor: BaseMultiModalProcessor,
|
||||
images: list[Image.Image],
|
||||
min_num: int,
|
||||
max_num: int,
|
||||
mm_processor_kwargs: Mapping[str, object],
|
||||
):
|
||||
tokenizer = processor.info.get_tokenizer()
|
||||
config = processor.info.get_hf_config()
|
||||
image_processor = processor.info.get_image_processor()
|
||||
|
||||
config.use_thumbnail = image_processor.use_thumbnail
|
||||
prompt = "<image>" * len(images)
|
||||
mm_data = {"image": images}
|
||||
|
||||
total_expected_num_patches = sum(
|
||||
_get_expected_num_patches(config, image, len(images), min_num, max_num)
|
||||
for image in images)
|
||||
print(total_expected_num_patches)
|
||||
processed_inputs = processor.apply(prompt, mm_data, mm_processor_kwargs)
|
||||
|
||||
# Ensure we have the right number of placeholders per num_crops size
|
||||
image_token_id = tokenizer.convert_tokens_to_ids("<image>")
|
||||
img_tok_count = processed_inputs["prompt_token_ids"].count(image_token_id)
|
||||
pixel_shape = processed_inputs["mm_kwargs"]["pixel_values_flat"].shape
|
||||
print("Image token count:", img_tok_count, "Pixel shape:", pixel_shape)
|
||||
assert img_tok_count == 256 * total_expected_num_patches
|
||||
assert pixel_shape[0] == total_expected_num_patches
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_id",
|
||||
["nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1"])
|
||||
@pytest.mark.parametrize(
|
||||
"size_factors",
|
||||
[
|
||||
# Single-scale
|
||||
[1.0],
|
||||
# Single-scale, batched
|
||||
[1.0, 1.0, 1.0],
|
||||
# Multi-scale
|
||||
[0.25, 0.5, 1.0],
|
||||
[4.0, 2.0, 1.0],
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
("min_dynamic_patch", "max_dynamic_patch"),
|
||||
[(1, 1), (1, 2), (1, 4), (1, 8), (2, 4), (4, 8)],
|
||||
)
|
||||
@pytest.mark.parametrize("dynamic_image_size", [True, False])
|
||||
@pytest.mark.parametrize("kwargs_on_init", [True, False])
|
||||
def test_processor_override(
|
||||
model_id: str,
|
||||
image_assets: ImageTestAssets,
|
||||
size_factors: list[int],
|
||||
min_dynamic_patch: int,
|
||||
max_dynamic_patch: int,
|
||||
dynamic_image_size: Optional[bool],
|
||||
kwargs_on_init: bool,
|
||||
):
|
||||
mm_processor_kwargs = {
|
||||
"min_dynamic_patch": min_dynamic_patch,
|
||||
"max_dynamic_patch": max_dynamic_patch,
|
||||
"dynamic_image_size": dynamic_image_size,
|
||||
}
|
||||
|
||||
ctx = build_model_context(
|
||||
model_id,
|
||||
mm_processor_kwargs=mm_processor_kwargs if kwargs_on_init else None,
|
||||
limit_mm_per_prompt={"image": len(size_factors)},
|
||||
)
|
||||
processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config)
|
||||
hf_processor_mm_kwargs = {} if kwargs_on_init else mm_processor_kwargs
|
||||
|
||||
min_num = min_dynamic_patch if dynamic_image_size else 1
|
||||
max_num = max_dynamic_patch if dynamic_image_size else 1
|
||||
|
||||
_run_check(
|
||||
processor,
|
||||
[
|
||||
rescale_image_size(image_assets[0].pil_image, f)
|
||||
for f in size_factors
|
||||
],
|
||||
min_num,
|
||||
max_num,
|
||||
hf_processor_mm_kwargs,
|
||||
)
|
||||
@ -401,6 +401,8 @@ _MULTIMODAL_EXAMPLE_MODELS = {
|
||||
trust_remote_code=True),
|
||||
"NVLM_D": _HfExamplesInfo("nvidia/NVLM-D-72B",
|
||||
trust_remote_code=True),
|
||||
"Llama_Nemotron_Nano_VL" : _HfExamplesInfo("nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1", # noqa: E501
|
||||
trust_remote_code=True),
|
||||
"PaliGemmaForConditionalGeneration": _HfExamplesInfo("google/paligemma-3b-mix-224", # noqa: E501
|
||||
extras={"v2": "google/paligemma2-3b-ft-docci-448"}), # noqa: E501
|
||||
"Phi3VForCausalLM": _HfExamplesInfo("microsoft/Phi-3-vision-128k-instruct",
|
||||
|
||||
505
vllm/model_executor/models/nemotron_vl.py
Normal file
505
vllm/model_executor/models/nemotron_vl.py
Normal file
@ -0,0 +1,505 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# adapted from https://huggingface.co/OpenGVLab/InternVL2-4B/blob/main/modeling_internvl_chat.py
|
||||
# --------------------------------------------------------
|
||||
# InternVL
|
||||
# Copyright (c) 2023 OpenGVLab
|
||||
# Licensed under The MIT License [see LICENSE for details]
|
||||
# --------------------------------------------------------
|
||||
from abc import ABC
|
||||
from collections.abc import Iterable
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from PIL import Image
|
||||
from transformers import AutoModel, PretrainedConfig
|
||||
from transformers.image_processing_utils_fast import BaseImageProcessorFast
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.quantization.awq import AWQConfig
|
||||
from vllm.model_executor.models.internvl import (
|
||||
BaseInternVLDummyInputsBuilder, BaseInternVLMultiModalProcessor,
|
||||
BaseInternVLProcessingInfo, InternVLImageEmbeddingInputs,
|
||||
InternVLImageInputs, InternVLImagePixelInputs, InternVLProcessor)
|
||||
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import NestedTensors
|
||||
from vllm.multimodal.processing import PromptUpdateDetails
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.transformers_utils.processor import (
|
||||
cached_image_processor_from_config)
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
|
||||
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
|
||||
SupportsMultiModal, SupportsPP)
|
||||
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
|
||||
maybe_prefix, merge_multimodal_embeddings)
|
||||
|
||||
IMG_START = '<img>'
|
||||
IMG_END = '</img>'
|
||||
IMG_CONTEXT = '<image>'
|
||||
|
||||
|
||||
class NemotronVLProcessor(InternVLProcessor):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
tokenizer: AnyTokenizer,
|
||||
image_processor: BaseImageProcessorFast,
|
||||
*,
|
||||
min_dynamic_patch: Optional[int] = None,
|
||||
max_dynamic_patch: Optional[int] = None,
|
||||
dynamic_image_size: Optional[bool] = None,
|
||||
) -> None:
|
||||
ABC.__init__(self)
|
||||
self.config = config
|
||||
self.tokenizer = tokenizer
|
||||
self.image_processor = image_processor
|
||||
image_size: int = config.force_image_size
|
||||
patch_size: int = config.patch_size
|
||||
|
||||
if min_dynamic_patch is None:
|
||||
min_dynamic_patch = 1
|
||||
assert isinstance(min_dynamic_patch, int)
|
||||
|
||||
if max_dynamic_patch is None:
|
||||
max_dynamic_patch = self.image_processor.max_num_tiles
|
||||
assert isinstance(max_dynamic_patch, int)
|
||||
|
||||
if dynamic_image_size is None:
|
||||
dynamic_image_size = True
|
||||
assert isinstance(dynamic_image_size, bool)
|
||||
|
||||
self.num_image_token = int(
|
||||
(image_size // patch_size)**2 * (config.downsample_ratio**2))
|
||||
self.image_size = image_size
|
||||
self.min_dynamic_patch = min_dynamic_patch
|
||||
self.max_dynamic_patch = max_dynamic_patch
|
||||
self.dynamic_image_size = dynamic_image_size
|
||||
self.use_thumbnail: bool = self.image_processor.use_thumbnail
|
||||
|
||||
@property
|
||||
def image_token_id(self) -> int:
|
||||
return self.tokenizer.get_vocab()[IMG_CONTEXT]
|
||||
|
||||
def _preprocess_image(
|
||||
self,
|
||||
text: list[str],
|
||||
images: list[Image.Image],
|
||||
min_dynamic_patch: Optional[int] = None,
|
||||
max_dynamic_patch: Optional[int] = None,
|
||||
dynamic_image_size: Optional[bool] = None,
|
||||
) -> tuple[list[str], dict[str, torch.Tensor]]:
|
||||
if len(images) == 0:
|
||||
image_inputs = {}
|
||||
else:
|
||||
pixel_values_lst = self._images_to_pixel_values_lst(
|
||||
images,
|
||||
min_dynamic_patch=min_dynamic_patch,
|
||||
max_dynamic_patch=max_dynamic_patch,
|
||||
dynamic_image_size=dynamic_image_size,
|
||||
)
|
||||
image_inputs: dict[str, NestedTensors] = {
|
||||
"pixel_values_flat":
|
||||
torch.cat(pixel_values_lst),
|
||||
"image_num_patches":
|
||||
torch.tensor([len(item) for item in pixel_values_lst]),
|
||||
}
|
||||
|
||||
for pixel_values in pixel_values_lst:
|
||||
num_patches = pixel_values.shape[0]
|
||||
feature_size = num_patches * self.num_image_token
|
||||
image_repl = self.get_image_repl(feature_size, num_patches)
|
||||
NVL_IMAGE_CONTEXT = image_repl.full.replace(
|
||||
"<image>", "<NVL_IMG_CONTEXT>")
|
||||
text = [
|
||||
t.replace('<image>', NVL_IMAGE_CONTEXT, 1) for t in text
|
||||
]
|
||||
text = [t.replace("<NVL_IMG_CONTEXT>", IMG_CONTEXT) for t in text]
|
||||
return text, image_inputs
|
||||
|
||||
def get_image_repl(
|
||||
self,
|
||||
feature_size: int,
|
||||
num_patches: Optional[int],
|
||||
) -> PromptUpdateDetails[str]:
|
||||
repl_features = IMG_CONTEXT * feature_size
|
||||
repl_full = IMG_START + repl_features + IMG_END
|
||||
|
||||
return PromptUpdateDetails.select_text(repl_full, IMG_CONTEXT)
|
||||
|
||||
|
||||
class NemotronVLProcessingInfo(BaseInternVLProcessingInfo):
|
||||
"""Processing info for Nemotron VL models."""
|
||||
|
||||
def get_hf_processor(
|
||||
self,
|
||||
*,
|
||||
min_dynamic_patch: Optional[int] = None,
|
||||
max_dynamic_patch: Optional[int] = None,
|
||||
dynamic_image_size: Optional[bool] = None,
|
||||
**kwargs: object,
|
||||
) -> NemotronVLProcessor:
|
||||
if min_dynamic_patch is not None:
|
||||
kwargs["min_dynamic_patch"] = min_dynamic_patch
|
||||
if max_dynamic_patch is not None:
|
||||
kwargs["max_dynamic_patch"] = max_dynamic_patch
|
||||
if dynamic_image_size is not None:
|
||||
kwargs["dynamic_image_size"] = dynamic_image_size
|
||||
|
||||
image_processor = self.get_image_processor()
|
||||
return self.ctx.init_processor(
|
||||
NemotronVLProcessor,
|
||||
config=self.get_hf_config(),
|
||||
tokenizer=self.get_tokenizer(),
|
||||
image_processor=image_processor,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def get_image_processor(
|
||||
self,
|
||||
**kwargs: object,
|
||||
):
|
||||
return cached_image_processor_from_config(
|
||||
self.ctx.model_config,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_processor(
|
||||
BaseInternVLMultiModalProcessor[NemotronVLProcessingInfo],
|
||||
info=NemotronVLProcessingInfo,
|
||||
dummy_inputs=BaseInternVLDummyInputsBuilder[NemotronVLProcessingInfo])
|
||||
class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
|
||||
SupportsLoRA):
|
||||
|
||||
@classmethod
|
||||
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
|
||||
if modality.startswith("image"):
|
||||
return "<image>"
|
||||
|
||||
raise ValueError("Only image modality is supported")
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
multimodal_config = vllm_config.model_config.multimodal_config
|
||||
|
||||
self.config = config
|
||||
self.multimodal_config = multimodal_config
|
||||
self._patch_quant_config(config, quant_config)
|
||||
|
||||
image_size = config.force_image_size or config.vision_config.image_size
|
||||
patch_size = config.vision_config.patch_size
|
||||
self.patch_size = patch_size
|
||||
self.num_image_token = int(
|
||||
(image_size // patch_size)**2 * (config.downsample_ratio**2))
|
||||
self.downsample_ratio = config.downsample_ratio
|
||||
self.ps_version = config.ps_version
|
||||
|
||||
self.llm_arch_name = config.text_config.architectures[0]
|
||||
self.vision_model = self._init_vision_model(
|
||||
config,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "vision_model"),
|
||||
)
|
||||
|
||||
self.language_model = init_vllm_registered_model(
|
||||
vllm_config=vllm_config,
|
||||
hf_config=config.text_config,
|
||||
prefix=maybe_prefix(prefix, "language_model"),
|
||||
)
|
||||
|
||||
self.mlp1 = self._init_mlp1(config)
|
||||
|
||||
self.img_context_token_id = None
|
||||
|
||||
self.visual_token_mask = None
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.language_model.make_empty_intermediate_tensors)
|
||||
|
||||
def _patch_quant_config(self, config: PretrainedConfig,
|
||||
quant_config: QuantizationConfig):
|
||||
# the awq models from OpenGVLab missing `modules_to_not_convert`
|
||||
# patch the quant_config to add `modules_to_not_convert` back
|
||||
if isinstance(quant_config, AWQConfig):
|
||||
text_config = config.text_config
|
||||
llm_quant_config = getattr(text_config, "quantization_config",
|
||||
None)
|
||||
if (not quant_config.modules_to_not_convert) and \
|
||||
(llm_quant_config is not None):
|
||||
quant_config.modules_to_not_convert.append("vision_model")
|
||||
|
||||
def _init_vision_model(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig],
|
||||
*,
|
||||
prefix: str,
|
||||
):
|
||||
return AutoModel.from_config(config.vision_config,
|
||||
trust_remote_code=True)
|
||||
|
||||
def _init_mlp1(self, config: PretrainedConfig) -> nn.Sequential:
|
||||
vit_hidden_size = config.vit_hidden_size
|
||||
vision_projection_hidden_size = config.projector_hidden_size
|
||||
llm_hidden_size = config.text_config.hidden_size
|
||||
|
||||
return nn.Sequential(
|
||||
nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio)**2,
|
||||
bias=True),
|
||||
nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio)**2,
|
||||
vision_projection_hidden_size,
|
||||
bias=True),
|
||||
nn.GELU(),
|
||||
nn.Linear(vision_projection_hidden_size, llm_hidden_size),
|
||||
)
|
||||
|
||||
def pixel_shuffle(self, x, scale_factor=0.5):
|
||||
n, w, h, c = x.size()
|
||||
# N, W, H, C --> N, W, H * scale, C // scale
|
||||
x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
|
||||
# N, W, H * scale, C // scale --> N, H * scale, W, C // scale
|
||||
x = x.permute(0, 2, 1, 3).contiguous()
|
||||
x = x.view(n, int(h * scale_factor), int(w * scale_factor),
|
||||
int(c / (scale_factor * scale_factor)))
|
||||
if self.ps_version == 'v1':
|
||||
pass
|
||||
else:
|
||||
x = x.permute(0, 2, 1, 3).contiguous()
|
||||
return x
|
||||
|
||||
def extract_feature(self, pixel_values: torch.Tensor) -> torch.Tensor:
|
||||
# https://huggingface.co/nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1/blob/main/modeling.py#L177
|
||||
vit_embeds = self.vision_model(x=pixel_values).features
|
||||
vit_embeds = vit_embeds.to(dtype=torch.bfloat16)
|
||||
|
||||
h = w = int(vit_embeds.shape[1]**0.5)
|
||||
vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
|
||||
vit_embeds = self.pixel_shuffle(vit_embeds,
|
||||
scale_factor=self.downsample_ratio)
|
||||
vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1,
|
||||
vit_embeds.shape[-1])
|
||||
vit_embeds = self.mlp1(vit_embeds)
|
||||
return vit_embeds
|
||||
|
||||
def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
#use force_image_size to get image_size
|
||||
h = w = self.config.force_image_size
|
||||
expected_dims = (3, h, w)
|
||||
|
||||
def _validate_shape(d: torch.Tensor):
|
||||
actual_dims = tuple(d.shape)
|
||||
|
||||
if actual_dims != expected_dims:
|
||||
expected_expr = str(expected_dims)
|
||||
raise ValueError(
|
||||
"The expected shape of pixel values per image per batch "
|
||||
f" per patch is {expected_expr}. "
|
||||
f"You supplied {tuple(d.shape)}.")
|
||||
|
||||
for d in data:
|
||||
_validate_shape(d)
|
||||
|
||||
return data
|
||||
|
||||
def _parse_and_validate_image_input(
|
||||
self, **kwargs: object) -> Optional[InternVLImageInputs]:
|
||||
pixel_values_flat = kwargs.pop("pixel_values_flat", None)
|
||||
image_num_patches = kwargs.pop("image_num_patches", None)
|
||||
image_embeds = kwargs.pop("image_embeds", None)
|
||||
|
||||
if pixel_values_flat is None and image_embeds is None:
|
||||
return None
|
||||
|
||||
if image_embeds is not None:
|
||||
if not isinstance(image_embeds, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of image embeddings. "
|
||||
f"Got type: {type(image_embeds)}")
|
||||
|
||||
return InternVLImageEmbeddingInputs(
|
||||
type="image_embeds",
|
||||
data=flatten_bn(image_embeds),
|
||||
)
|
||||
|
||||
image_token_id = kwargs["image_token_id"]
|
||||
assert isinstance(image_token_id, torch.Tensor)
|
||||
self.img_context_token_id = image_token_id.flatten().unique().item()
|
||||
|
||||
if pixel_values_flat is not None:
|
||||
if not isinstance(pixel_values_flat, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of pixel values. "
|
||||
f"Got type: {type(pixel_values_flat)}")
|
||||
|
||||
if not isinstance(image_num_patches, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of image_num_patches. "
|
||||
f"Got type: {type(image_num_patches)}")
|
||||
|
||||
pixel_values_flat = flatten_bn(pixel_values_flat, concat=True)
|
||||
image_num_patches = flatten_bn(image_num_patches, concat=True)
|
||||
|
||||
return InternVLImagePixelInputs(
|
||||
type="pixel_values",
|
||||
pixel_values_flat=self._validate_pixel_values(
|
||||
pixel_values_flat),
|
||||
num_patches=image_num_patches,
|
||||
)
|
||||
|
||||
raise AssertionError("This line should be unreachable.")
|
||||
|
||||
def _process_image_input(
|
||||
self,
|
||||
image_input: InternVLImageInputs,
|
||||
) -> tuple[torch.Tensor, ...]:
|
||||
if image_input["type"] == "image_embeds":
|
||||
return image_input["data"]
|
||||
|
||||
assert self.vision_model is not None
|
||||
|
||||
image_embeds = self.extract_feature(image_input["pixel_values_flat"])
|
||||
|
||||
num_patches = image_input["num_patches"]
|
||||
|
||||
# Only one image in the current batch
|
||||
if len(num_patches) == 1:
|
||||
return (image_embeds.view(-1,
|
||||
self.config.text_config.hidden_size), )
|
||||
|
||||
# NOTE: Image embeddings are split into separate tensors for each image
|
||||
# by the size of each embedding.
|
||||
feature_size = image_embeds.shape[1]
|
||||
image_embeds = image_embeds.view(-1,
|
||||
self.config.text_config.hidden_size)
|
||||
image_feature_sizes = [
|
||||
num_patches * feature_size for num_patches in num_patches
|
||||
]
|
||||
return image_embeds.split(image_feature_sizes)
|
||||
|
||||
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
|
||||
modalities = {}
|
||||
|
||||
# Preserve the order of modalities if there are multiple of them
|
||||
# from the order of kwargs.
|
||||
for input_key in kwargs:
|
||||
if input_key in ("pixel_values_flat",
|
||||
"image_embeds") and "images" not in modalities:
|
||||
modalities["images"] = self._parse_and_validate_image_input(
|
||||
**kwargs)
|
||||
|
||||
return modalities
|
||||
|
||||
def _set_visual_token_mask(self, input_ids: torch.Tensor) -> None:
|
||||
self.visual_token_mask = None
|
||||
|
||||
def get_language_model(self) -> torch.nn.Module:
|
||||
return self.language_model
|
||||
|
||||
def get_multimodal_embeddings(self,
|
||||
**kwargs: object) -> MultiModalEmbeddings:
|
||||
|
||||
modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
|
||||
if not modalities:
|
||||
return []
|
||||
|
||||
# The result multimodal_embeddings is tuple of tensors, with each
|
||||
# tensor correspoending to a multimodal data item (image).
|
||||
multimodal_embeddings: tuple[torch.Tensor, ...] = ()
|
||||
|
||||
# NOTE: It is important to iterate over the keys in this dictionary
|
||||
# to preserve the order of the modalities.
|
||||
for modality in modalities:
|
||||
if modality == "images":
|
||||
image_input = modalities["images"]
|
||||
vision_embeddings = self._process_image_input(image_input)
|
||||
multimodal_embeddings += vision_embeddings
|
||||
|
||||
return multimodal_embeddings
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
|
||||
) -> torch.Tensor:
|
||||
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
|
||||
if multimodal_embeddings is not None \
|
||||
and len(multimodal_embeddings) != 0:
|
||||
context_token_ids = [self.img_context_token_id]
|
||||
assert len(context_token_ids) >= 1
|
||||
self._set_visual_token_mask(input_ids)
|
||||
inputs_embeds = merge_multimodal_embeddings(
|
||||
input_ids,
|
||||
inputs_embeds,
|
||||
multimodal_embeddings,
|
||||
context_token_ids,
|
||||
)
|
||||
return inputs_embeds
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
**kwargs: object,
|
||||
) -> IntermediateTensors:
|
||||
|
||||
if intermediate_tensors is not None:
|
||||
input_ids = None
|
||||
inputs_embeds = None
|
||||
|
||||
# NOTE: In v1, inputs_embeds is always generated at model runner, this
|
||||
# condition is for v0 compatibility.
|
||||
elif inputs_embeds is None:
|
||||
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
|
||||
inputs_embeds = self.get_input_embeddings(input_ids,
|
||||
vision_embeddings)
|
||||
input_ids = None
|
||||
|
||||
forward_kwargs = {
|
||||
"input_ids": input_ids,
|
||||
"positions": positions,
|
||||
"intermediate_tensors": intermediate_tensors,
|
||||
"inputs_embeds": inputs_embeds,
|
||||
}
|
||||
|
||||
# Only required if the model is mono-architecture
|
||||
if self.visual_token_mask is not None:
|
||||
forward_kwargs.update(
|
||||
{"visual_token_mask": self.visual_token_mask})
|
||||
self.visual_token_mask = None
|
||||
|
||||
hidden_states = self.language_model.model(**forward_kwargs)
|
||||
return hidden_states
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[torch.Tensor]:
|
||||
return self.language_model.compute_logits(hidden_states,
|
||||
sampling_metadata)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
## Ignore registered_buffers
|
||||
## see https://huggingface.co/nvidia/C-RADIOv2-H/blob/main/input_conditioner.py#L28 # noqa: E501
|
||||
skip_substrs = ["norm_mean", "norm_std"]
|
||||
loader = AutoWeightsLoader(self, skip_substrs=skip_substrs)
|
||||
return loader.load_weights(weights)
|
||||
|
||||
def get_mm_mapping(self) -> MultiModelKeys:
|
||||
"""
|
||||
Get the module prefix in multimodal models
|
||||
"""
|
||||
return MultiModelKeys.from_string_field(
|
||||
language_model="language_model",
|
||||
connector="mlp1",
|
||||
tower_model="vision_model")
|
||||
@ -206,6 +206,7 @@ _MULTIMODAL_MODELS = {
|
||||
"SmolVLMForConditionalGeneration": ("smolvlm","SmolVLMForConditionalGeneration"), # noqa: E501
|
||||
"KeyeForConditionalGeneration": ("keye", "KeyeForConditionalGeneration"),
|
||||
"KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"), # noqa: E501
|
||||
"Llama_Nemotron_Nano_VL": ("nemotron_vl", "LlamaNemotronVLChatModel"),
|
||||
"LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
|
||||
"LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"), # noqa: E501
|
||||
"LlavaNextVideoForConditionalGeneration": ("llava_next_video", "LlavaNextVideoForConditionalGeneration"), # noqa: E501
|
||||
|
||||
@ -202,4 +202,4 @@ class NemotronConfig(PretrainedConfig):
|
||||
rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
|
||||
raise ValueError(
|
||||
"`rope_scaling`'s factor field must be a float > 1, got "
|
||||
f"{rope_scaling_factor}")
|
||||
f"{rope_scaling_factor}")
|
||||
Loading…
x
Reference in New Issue
Block a user