[Model] Support NVLM-D and fix QK Norm in InternViT (#9045)

Co-authored-by: Roger Wang <ywang@roblox.com>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Cyrus Leung 2024-10-07 19:55:12 +08:00 committed by GitHub
parent f19da64871
commit 151ef4efd2
12 changed files with 518 additions and 236 deletions
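
A quick way to exercise the new support: the sketch below mirrors the run_nvlm_d example added in this commit and feeds the chat-templated prompt to vLLM's multimodal generate call. The image path is a placeholder, and the context length / tensor-parallel settings should be adjusted to your hardware.

from PIL import Image
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams

model_name = "nvidia/NVLM-D-72B"
llm = LLM(model=model_name,
          trust_remote_code=True,
          max_model_len=4096,
          tensor_parallel_size=4)  # adjust as necessary to fit your GPUs

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
messages = [{"role": "user", "content": "<image>\nWhat is in this image?"}]
prompt = tokenizer.apply_chat_template(messages,
                                       tokenize=False,
                                       add_generation_prompt=True)

image = Image.open("example.jpg")  # placeholder image path
outputs = llm.generate({"prompt": prompt, "multi_modal_data": {"image": image}},
                       SamplingParams(max_tokens=64))
print(outputs[0].outputs[0].text)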

View File

@ -315,6 +315,9 @@ Multimodal Language Models
.. _supported_vlms:
Text Generation
---------------
.. list-table::
:widths: 25 25 25 25 5 5
:header-rows: 1
@ -384,7 +387,13 @@ Multimodal Language Models
- Image
- :code:`meta-llama/Llama-3.2-90B-Vision-Instruct`, :code:`meta-llama/Llama-3.2-11B-Vision`, etc.
-
-
* - :code:`NVLM_D_Model`
- NVLM-D 1.0
- Image\ :sup:`E+`
- :code:`nvidia/NVLM-D-72B`, etc.
-
- ✅︎
* - :code:`PaliGemmaForConditionalGeneration`
- PaliGemma
- Image\ :sup:`E`

View File

@ -18,7 +18,7 @@ from vllm.utils import FlexibleArgumentParser
# LLaVA-1.5
def run_llava(question, modality):
def run_llava(question: str, modality: str):
assert modality == "image"
prompt = f"USER: <image>\n{question}\nASSISTANT:"
@ -29,7 +29,7 @@ def run_llava(question, modality):
# LLaVA-1.6/LLaVA-NeXT
def run_llava_next(question, modality):
def run_llava_next(question: str, modality: str):
assert modality == "image"
prompt = f"[INST] <image>\n{question} [/INST]"
@ -40,7 +40,7 @@ def run_llava_next(question, modality):
# LlaVA-NeXT-Video
# Currently only support for video input
def run_llava_next_video(question, modality):
def run_llava_next_video(question: str, modality: str):
assert modality == "video"
prompt = f"USER: <video>\n{question} ASSISTANT:"
@ -50,7 +50,7 @@ def run_llava_next_video(question, modality):
# LLaVA-OneVision
def run_llava_onevision(question, modality):
def run_llava_onevision(question: str, modality: str):
if modality == "video":
prompt = f"<|im_start|>user <video>\n{question}<|im_end|> \
@ -67,7 +67,7 @@ def run_llava_onevision(question, modality):
# Fuyu
def run_fuyu(question, modality):
def run_fuyu(question: str, modality: str):
assert modality == "image"
prompt = f"{question}\n"
@ -77,7 +77,7 @@ def run_fuyu(question, modality):
# Phi-3-Vision
def run_phi3v(question, modality):
def run_phi3v(question: str, modality: str):
assert modality == "image"
prompt = f"<|user|>\n<|image_1|>\n{question}<|end|>\n<|assistant|>\n" # noqa: E501
@ -112,7 +112,7 @@ def run_phi3v(question, modality):
# PaliGemma
def run_paligemma(question, modality):
def run_paligemma(question: str, modality: str):
assert modality == "image"
# PaliGemma has special prompt format for VQA
@ -123,7 +123,7 @@ def run_paligemma(question, modality):
# Chameleon
def run_chameleon(question, modality):
def run_chameleon(question: str, modality: str):
assert modality == "image"
prompt = f"{question}<image>"
@ -133,7 +133,7 @@ def run_chameleon(question, modality):
# MiniCPM-V
def run_minicpmv(question, modality):
def run_minicpmv(question: str, modality: str):
assert modality == "image"
# 2.0
@ -176,7 +176,7 @@ def run_minicpmv(question, modality):
# InternVL
def run_internvl(question, modality):
def run_internvl(question: str, modality: str):
assert modality == "image"
model_name = "OpenGVLab/InternVL2-2B"
@ -203,8 +203,32 @@ def run_internvl(question, modality):
return llm, prompt, stop_token_ids
# NVLM-D
def run_nvlm_d(question: str, modality: str):
assert modality == "image"
model_name = "nvidia/NVLM-D-72B"
# Adjust this as necessary to fit in GPU
llm = LLM(
model=model_name,
trust_remote_code=True,
max_model_len=4096,
tensor_parallel_size=4,
)
tokenizer = AutoTokenizer.from_pretrained(model_name,
trust_remote_code=True)
messages = [{'role': 'user', 'content': f"<image>\n{question}"}]
prompt = tokenizer.apply_chat_template(messages,
tokenize=False,
add_generation_prompt=True)
stop_token_ids = None
return llm, prompt, stop_token_ids
# BLIP-2
def run_blip2(question, modality):
def run_blip2(question: str, modality: str):
assert modality == "image"
# BLIP-2 prompt format is inaccurate on HuggingFace model repository.
@ -216,7 +240,7 @@ def run_blip2(question, modality):
# Qwen
def run_qwen_vl(question, modality):
def run_qwen_vl(question: str, modality: str):
assert modality == "image"
llm = LLM(
@ -232,7 +256,7 @@ def run_qwen_vl(question, modality):
# Qwen2-VL
def run_qwen2_vl(question, modality):
def run_qwen2_vl(question: str, modality: str):
assert modality == "image"
model_name = "Qwen/Qwen2-VL-7B-Instruct"
@ -252,8 +276,8 @@ def run_qwen2_vl(question, modality):
return llm, prompt, stop_token_ids
# LLama
def run_mllama(question, modality):
# LLama 3.2
def run_mllama(question: str, modality: str):
assert modality == "image"
model_name = "meta-llama/Llama-3.2-11B-Vision-Instruct"
@ -287,6 +311,7 @@ model_example_map = {
"minicpmv": run_minicpmv,
"blip-2": run_blip2,
"internvl_chat": run_internvl,
"NVLM_D": run_nvlm_d,
"qwen_vl": run_qwen_vl,
"qwen2_vl": run_qwen2_vl,
"mllama": run_mllama,

View File

@ -144,6 +144,39 @@ def load_internvl(question: str, image_urls: List[str]) -> ModelRequestData:
)
def load_nvlm_d(question: str, image_urls: List[str]):
model_name = "nvidia/NVLM-D-72B"
# Adjust this as necessary to fit in GPU
llm = LLM(
model=model_name,
trust_remote_code=True,
max_model_len=8192,
tensor_parallel_size=4,
limit_mm_per_prompt={"image": len(image_urls)},
mm_processor_kwargs={"max_dynamic_patch": 4},
)
placeholders = "\n".join(f"Image-{i}: <image>\n"
for i, _ in enumerate(image_urls, start=1))
messages = [{'role': 'user', 'content': f"{placeholders}\n{question}"}]
tokenizer = AutoTokenizer.from_pretrained(model_name,
trust_remote_code=True)
prompt = tokenizer.apply_chat_template(messages,
tokenize=False,
add_generation_prompt=True)
stop_token_ids = None
return ModelRequestData(
llm=llm,
prompt=prompt,
stop_token_ids=stop_token_ids,
image_data=[fetch_image(url) for url in image_urls],
chat_template=None,
)
def load_qwen2_vl(question, image_urls: List[str]) -> ModelRequestData:
try:
from qwen_vl_utils import process_vision_info
@ -204,6 +237,7 @@ def load_qwen2_vl(question, image_urls: List[str]) -> ModelRequestData:
model_example_map = {
"phi3_v": load_phi3v,
"internvl_chat": load_internvl,
"NVLM_D": load_nvlm_d,
"qwen2_vl": load_qwen2_vl,
"qwen_vl_chat": load_qwenvl_chat,
}
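
The driver that consumes these loaders is not part of this diff; the following is a hedged paraphrase of how the assembled ModelRequestData is typically passed to vLLM, with a placeholder question and placeholder image URLs to replace with real ones.

from vllm import SamplingParams

question = "What do these images have in common?"
image_urls = ["https://example.com/a.jpg", "https://example.com/b.jpg"]

req_data = load_nvlm_d(question, image_urls)

sampling_params = SamplingParams(temperature=0.0,
                                 max_tokens=128,
                                 stop_token_ids=req_data.stop_token_ids)

# Images are supplied alongside the prompt under multi_modal_data.
outputs = req_data.llm.generate(
    {
        "prompt": req_data.prompt,
        "multi_modal_data": {"image": req_data.image_data},
    },
    sampling_params=sampling_params)

for o in outputs:
    print(o.outputs[0].text)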

View File

@ -157,7 +157,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
if model_type.startswith("llava"):
return self._cached_token_str(self._tokenizer,
hf_config.image_token_index)
if model_type in ("chameleon", "internvl_chat"):
if model_type in ("chameleon", "internvl_chat", "NVLM_D"):
return "<image>"
if model_type == "mllama":
return "<|image|>"

View File

@ -18,10 +18,16 @@ class RMSNorm(CustomOp):
self,
hidden_size: int,
eps: float = 1e-6,
var_hidden_size: Optional[int] = None,
) -> None:
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.hidden_size = hidden_size
self.variance_epsilon = eps
self.variance_size_override = (None if var_hidden_size == hidden_size
else var_hidden_size)
self.weight = nn.Parameter(torch.ones(hidden_size))
def forward_native(
self,
@ -35,7 +41,23 @@ class RMSNorm(CustomOp):
x = x + residual.to(torch.float32)
residual = x.to(orig_dtype)
variance = x.pow(2).mean(dim=-1, keepdim=True)
hidden_size = x.shape[-1]
if hidden_size != self.hidden_size:
raise ValueError("Expected hidden_size to be "
f"{self.hidden_size}, but found: {hidden_size}")
if self.variance_size_override is None:
x_var = x
else:
if hidden_size < self.variance_size_override:
raise ValueError(
"Expected hidden_size to be at least "
f"{self.variance_size_override}, but found: {hidden_size}")
x_var = x[:, :, :self.variance_size_override]
variance = x_var.pow(2).mean(dim=-1, keepdim=True)
x = x * torch.rsqrt(variance + self.variance_epsilon)
x = x.to(orig_dtype) * self.weight
if residual is None:
@ -48,6 +70,9 @@ class RMSNorm(CustomOp):
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
if self.variance_size_override is not None:
return self.forward_native(x, residual)
from vllm import _custom_ops as ops
if residual is not None:
@ -72,6 +97,9 @@ class RMSNorm(CustomOp):
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
if self.variance_size_override is not None:
return self.forward_native(x, residual)
from vllm._ipex_ops import ipex_ops as ops
if residual is not None:
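
For reference, a self-contained sketch of the new variance-size override (plain PyTorch, hypothetical helper name, no vLLM imports): the variance is computed over only the leading var_hidden_size channels, while the learned scale still covers the full, possibly padded, hidden dimension.

from typing import Optional

import torch


def rms_norm_with_var_override(x: torch.Tensor,
                               weight: torch.Tensor,
                               eps: float = 1e-6,
                               var_hidden_size: Optional[int] = None) -> torch.Tensor:
    # Variance over the leading `var_hidden_size` channels only, so padded
    # dummy-head channels cannot skew the statistics.
    x_var = x if var_hidden_size is None else x[..., :var_hidden_size]
    variance = x_var.float().pow(2).mean(dim=-1, keepdim=True)
    # The scale is still applied across the full (padded) hidden dimension.
    return (x.float() * torch.rsqrt(variance + eps)).to(x.dtype) * weight


# Illustrative InternViT-6B-like sizes: 25 real heads * 128 dims = 3200 real
# channels, padded with 7 dummy heads to 32 * 128 = 4096 channels.
x = torch.randn(1, 4, 4096)
out = rms_norm_with_var_override(x, torch.ones(4096), var_hidden_size=3200)
print(out.shape)  # torch.Size([1, 4, 4096])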

View File

@ -4,6 +4,7 @@
# Copyright (c) 2023 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
from functools import partial
from typing import Iterable, Optional, Tuple
import torch
@ -11,7 +12,10 @@ import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
split_tensor_along_last_dim,
tensor_model_parallel_all_gather)
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@ -54,7 +58,7 @@ class InternVisionEmbeddings(nn.Module):
self.position_embedding = nn.Parameter(
torch.randn(1, self.num_positions, self.embed_dim))
def _get_pos_embed(self, pos_embed, H, W):
def _get_pos_embed(self, pos_embed: torch.Tensor, H: int, W: int):
target_dtype = pos_embed.dtype
pos_embed = pos_embed.float().reshape(
1, self.image_size // self.patch_size,
@ -63,9 +67,21 @@ class InternVisionEmbeddings(nn.Module):
size=(H, W),
mode='bicubic',
align_corners=False)
pos_embed = pos_embed.reshape(1, -1, H * W).permute(0, 2,
1).to(target_dtype)
return pos_embed
return pos_embed.reshape(1, -1, H * W).permute(0, 2,
1).to(target_dtype)
def _get_position_embedding(self, H: int, W: int) -> torch.Tensor:
position_embedding = self.position_embedding
if self.num_patches == H * W:
return position_embedding
return torch.cat(
[
position_embedding[:, :1, :],
self._get_pos_embed(position_embedding[:, 1:, :], H, W),
],
dim=1,
)
def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
target_dtype = self.patch_embedding.weight.dtype
@ -76,12 +92,7 @@ class InternVisionEmbeddings(nn.Module):
class_embeds = self.class_embedding.expand(batch_size, 1,
-1).to(target_dtype)
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
position_embedding = torch.cat([
self.position_embedding[:, :1, :],
self._get_pos_embed(self.position_embedding[:, 1:, :], height,
width)
],
dim=1)
position_embedding = self._get_position_embedding(height, width)
embeddings = embeddings + position_embedding.to(target_dtype)
return embeddings
@ -93,8 +104,11 @@ class InternParallelAttention(nn.Module):
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
):
*,
num_dummy_heads: int = 0,
) -> None:
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
@ -105,11 +119,19 @@ class InternParallelAttention(nn.Module):
f'(got `embed_dim`: {self.embed_dim} and `num_heads`:'
f' {self.num_heads}).')
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
# Additional dummy heads are used to enable TP for common GPU counts.
self.dummy_dim = (num_dummy_heads + self.num_heads) * self.head_dim
self.num_heads_per_partition = divide(num_dummy_heads + self.num_heads,
self.tp_size)
self.scale = self.head_dim**-0.5
self.qkv = QKVParallelLinear(
self.embed_dim,
self.head_dim,
self.num_heads,
num_dummy_heads + self.num_heads,
bias=config.qkv_bias,
quant_config=quant_config,
)
@ -117,34 +139,44 @@ class InternParallelAttention(nn.Module):
self.qk_normalization = config.qk_normalization
if self.qk_normalization:
self.q_norm = RMSNorm(self.embed_dim, eps=config.layer_norm_eps)
self.k_norm = RMSNorm(self.embed_dim, eps=config.layer_norm_eps)
self.q_norm = RMSNorm(self.dummy_dim,
eps=config.layer_norm_eps,
var_hidden_size=self.embed_dim)
self.k_norm = RMSNorm(self.dummy_dim,
eps=config.layer_norm_eps,
var_hidden_size=self.embed_dim)
self.proj = RowParallelLinear(
self.embed_dim,
self.dummy_dim,
self.embed_dim,
quant_config=quant_config,
)
self.tp_size = get_tensor_model_parallel_world_size()
self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
def _apply_qk_norm(self, q: torch.Tensor, k: torch.Tensor):
if self.tp_size > 1:
q = tensor_model_parallel_all_gather(q.contiguous())
k = tensor_model_parallel_all_gather(k.contiguous())
q = self.q_norm.forward_native(q)
k = self.k_norm.forward_native(k)
if self.tp_size > 1:
splitter = partial(split_tensor_along_last_dim,
num_partitions=self.tp_size)
q = splitter(q)[self.tp_rank]
k = splitter(k)[self.tp_rank]
return q, k
def forward(self, x):
B, N, C = x.shape
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, N, _ = x.shape
qkv, _ = self.qkv(x)
q, k, v = qkv.chunk(3, dim=-1)
if self.qk_normalization:
q, k = self._apply_qk_norm(q, k)
q = q.view(B, N, self.num_heads_per_partition, self.head_dim)
k = k.view(B, N, self.num_heads_per_partition, self.head_dim)
v = v.view(B, N, self.num_heads_per_partition, self.head_dim)
if self.qk_normalization:
B_, N_, H_, D_ = q.shape
q = self.q_norm.forward_native(q.flatten(-2,
-1)).view(B_, N_, H_, D_)
k = self.k_norm.forward_native(k.flatten(-2,
-1)).view(B_, N_, H_, D_)
x = xops.memory_efficient_attention_forward(q, k, v, scale=self.scale)
x = x.view(B, N, -1)
@ -155,8 +187,14 @@ class InternParallelAttention(nn.Module):
class InternSdpaAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config: PretrainedConfig):
def __init__(
self,
config: PretrainedConfig,
*,
num_dummy_heads: int = 0,
) -> None:
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
@ -167,20 +205,27 @@ class InternSdpaAttention(nn.Module):
f'(got `embed_dim`: {self.embed_dim} and `num_heads`:'
f' {self.num_heads}).')
# Additional dummy heads are used to enable TP for common GPU counts.
self.dummy_dim = (num_dummy_heads + self.num_heads) * self.head_dim
self.scale = self.head_dim**-0.5
self.qkv = nn.Linear(self.embed_dim,
3 * self.embed_dim,
3 * self.dummy_dim,
bias=config.qkv_bias)
self.qk_normalization = config.qk_normalization
if self.qk_normalization:
self.q_norm = RMSNorm(self.embed_dim, eps=config.layer_norm_eps)
self.k_norm = RMSNorm(self.embed_dim, eps=config.layer_norm_eps)
self.q_norm = RMSNorm(self.dummy_dim,
eps=config.layer_norm_eps,
var_hidden_size=self.embed_dim)
self.k_norm = RMSNorm(self.dummy_dim,
eps=config.layer_norm_eps,
var_hidden_size=self.embed_dim)
self.proj = nn.Linear(self.embed_dim, self.embed_dim)
self.proj = nn.Linear(self.dummy_dim, self.embed_dim)
def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, N, C = x.shape
qkv = self.qkv(x)
q, k, v = qkv.chunk(3, dim=-1)
@ -233,22 +278,23 @@ class InternMLP(nn.Module):
class InternVisionEncoderLayer(nn.Module):
def __init__(self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None):
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
*,
num_dummy_heads: int = 0,
) -> None:
super().__init__()
self.embed_dim = config.hidden_size
self.intermediate_size = config.intermediate_size
self.norm_type = config.norm_type
# fallback to sdpa attention if tp unavailable
tp_size = get_tensor_model_parallel_world_size()
num_heads = config.num_attention_heads
if USE_XFORMERS_OPS and num_heads % tp_size == 0:
self.attn = InternParallelAttention(config,
quant_config=quant_config)
else:
self.attn = InternSdpaAttention(config)
self.attn = self._init_attn(config,
quant_config,
num_dummy_heads=num_dummy_heads)
self.mlp = InternMLP(config, quant_config=quant_config)
self.norm1 = NORM2FN[self.norm_type](self.embed_dim,
eps=config.layer_norm_eps)
@ -260,6 +306,24 @@ class InternVisionEncoderLayer(nn.Module):
self.ls2 = nn.Parameter(config.initializer_factor *
torch.ones(self.embed_dim))
def _init_attn(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig],
*,
num_dummy_heads: int,
):
# fallback to sdpa attention if tp unavailable
tp_size = get_tensor_model_parallel_world_size()
num_heads = config.num_attention_heads
if USE_XFORMERS_OPS and (num_heads + num_dummy_heads) % tp_size == 0:
return InternParallelAttention(config,
quant_config=quant_config,
num_dummy_heads=num_dummy_heads)
return InternSdpaAttention(config, num_dummy_heads=num_dummy_heads)
def forward(
self,
hidden_states: torch.Tensor,
@ -275,19 +339,27 @@ class InternVisionEncoderLayer(nn.Module):
class InternVisionEncoder(nn.Module):
def __init__(self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
num_hidden_layers_override: Optional[int] = None):
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
*,
num_hidden_layers_override: Optional[int] = None,
num_dummy_heads: int = 0,
):
super().__init__()
self.config = config
if num_hidden_layers_override is None:
num_hidden_layers = config.num_hidden_layers
else:
num_hidden_layers = num_hidden_layers_override
self.layers = nn.ModuleList([
InternVisionEncoderLayer(config=config, quant_config=quant_config)
InternVisionEncoderLayer(config,
quant_config,
num_dummy_heads=num_dummy_heads)
for _ in range(num_hidden_layers)
])
@ -302,35 +374,25 @@ class InternVisionEncoder(nn.Module):
class InternVisionModel(nn.Module):
def __init__(self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
num_hidden_layers_override: Optional[int] = None):
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
*,
num_hidden_layers_override: Optional[int] = None,
num_dummy_heads: int = 0,
):
super().__init__()
self.config = config
self.embeddings = InternVisionEmbeddings(config)
self.encoder = InternVisionEncoder(
config=config,
quant_config=quant_config,
num_hidden_layers_override=num_hidden_layers_override)
def resize_pos_embeddings(self, old_size, new_size, patch_size):
pos_emb = self.embeddings.position_embedding
_, num_positions, embed_dim = pos_emb.shape
cls_emb = pos_emb[:, :1, :]
pos_emb = pos_emb[:, 1:, :].reshape(1, old_size // patch_size,
old_size // patch_size,
-1).permute(0, 3, 1, 2)
pos_emb = F.interpolate(pos_emb.float(),
size=new_size // patch_size,
mode='bicubic',
align_corners=False)
pos_emb = pos_emb.to(cls_emb.dtype).reshape(1, embed_dim,
-1).permute(0, 2, 1)
pos_emb = torch.cat([cls_emb, pos_emb], dim=1)
self.embeddings.position_embedding = nn.Parameter(pos_emb)
self.embeddings.image_size = new_size
num_hidden_layers_override=num_hidden_layers_override,
num_dummy_heads=num_dummy_heads,
)
def get_input_embeddings(self):
return self.embeddings
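
To make the tensor-parallel behavior concrete, here is a single-process sketch of the gather -> norm -> split flow in _apply_qk_norm. Plain torch.cat and chunk stand in for tensor_model_parallel_all_gather and split_tensor_along_last_dim, and the InternViT-6B-like sizes are used purely for illustration.

import torch

# 25 real heads plus 7 dummy heads of dim 128 (illustrative sizes).
num_heads, num_dummy_heads, head_dim, tp_size = 25, 7, 128, 8
total_heads = num_heads + num_dummy_heads          # 32, now divisible by tp_size
heads_per_rank = total_heads // tp_size            # 4 heads per TP rank
embed_dim = num_heads * head_dim                   # 3200, variance is taken here
dummy_dim = total_heads * head_dim                 # 4096, q/k and norm weights live here


def qk_norm_full_dim(shards, norm_weight, eps=1e-6):
    # Stand-in for tensor_model_parallel_all_gather: concat the per-rank shards.
    full = torch.cat(shards, dim=-1)                                  # (B, N, dummy_dim)
    # RMSNorm with variance over the real embed_dim only (the layernorm.py change).
    variance = full[..., :embed_dim].pow(2).mean(dim=-1, keepdim=True)
    full = full * torch.rsqrt(variance + eps) * norm_weight
    # Stand-in for split_tensor_along_last_dim: each rank keeps its own slice.
    return list(full.chunk(tp_size, dim=-1))


q_shards = [torch.randn(1, 16, heads_per_rank * head_dim) for _ in range(tp_size)]
q_shards = qk_norm_full_dim(q_shards, torch.ones(dummy_dim))
print(q_shards[0].shape)  # torch.Size([1, 16, 512])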

View File

@ -237,130 +237,173 @@ def get_max_internvl_image_size(ctx: InputContext,
return width, height
def input_processor_for_internvl(ctx: InputContext,
llm_inputs: LLMInputs,
*,
max_dynamic_patch: Optional[int] = None):
multi_modal_data = llm_inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data:
return llm_inputs
class InternVLInputPipeline:
model_config = ctx.model_config
hf_config = ctx.get_hf_config()
def __init__(
self,
img_start_token: str,
img_end_token: str,
img_context_token: str,
) -> None:
super().__init__()
image_data = multi_modal_data["image"]
num_patches = get_internvl_num_patches(hf_config)
num_blocks_calculator = calculate_num_blocks_wrapper(
hf_config, max_dynamic_patch)
if isinstance(image_data, Image.Image):
width, height = image_data.size
num_blocks, _, _ = num_blocks_calculator(width, height)
image_feature_size = [num_blocks * num_patches]
elif is_list_of(image_data, Image.Image):
image_feature_size = []
for image in image_data:
width, height = image.size
self.img_start_token = img_start_token
self.img_end_token = img_end_token
self.img_context_token = img_context_token
def _create_image_prompt(self, feature_size: int, num_patches: int) -> str:
return (self.img_start_token + self.img_context_token * feature_size +
self.img_end_token)
def _expand_image_prompt(
self,
prompt: str,
feature_sizes: List[int],
num_patches: int,
) -> str:
image_idx = sorted(
map(int, re.findall(r"Image-(\d+): <image>\n", prompt)))
new_prompt = prompt
for idx, feature_size in enumerate(feature_sizes, start=1):
image_prompt = self._create_image_prompt(feature_size, num_patches)
if not image_idx:
image_prompt = f"Image-{idx}: {image_prompt}"
new_prompt = new_prompt.replace('<image>', image_prompt, 1)
return new_prompt
def input_processor(
self,
ctx: InputContext,
llm_inputs: LLMInputs,
*,
max_dynamic_patch: Optional[int] = None,
) -> LLMInputs:
multi_modal_data = llm_inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data:
return llm_inputs
model_config = ctx.model_config
hf_config = ctx.get_hf_config()
image_data = multi_modal_data["image"]
num_patches = get_internvl_num_patches(hf_config)
num_blocks_calculator = calculate_num_blocks_wrapper(
hf_config, max_dynamic_patch)
if isinstance(image_data, Image.Image):
width, height = image_data.size
num_blocks, _, _ = num_blocks_calculator(width, height)
image_feature_size.append(num_blocks * num_patches)
elif isinstance(image_data, torch.Tensor):
num_images, image_feature_size, hidden_size = image_data.shape
else:
raise TypeError(f"Invalid image type: {type(image_data)}")
image_feature_sizes = [num_blocks * num_patches]
elif is_list_of(image_data, Image.Image):
image_feature_sizes = []
for image in image_data:
width, height = image.size
num_blocks, _, _ = num_blocks_calculator(width, height)
image_feature_sizes.append(num_blocks * num_patches)
elif isinstance(image_data, torch.Tensor):
num_images, image_feature_size, hidden_size = image_data.shape
image_feature_sizes = [image_feature_size]
else:
raise TypeError(f"Invalid image type: {type(image_data)}")
tokenizer = cached_get_tokenizer(
model_config.tokenizer,
trust_remote_code=model_config.trust_remote_code)
tokenizer = cached_get_tokenizer(
model_config.tokenizer,
trust_remote_code=model_config.trust_remote_code)
prompt = llm_inputs.get("prompt")
prompt_token_ids = llm_inputs["prompt_token_ids"]
if prompt is None:
prompt = tokenizer.decode(prompt_token_ids)
prompt = llm_inputs.get("prompt")
prompt_token_ids = llm_inputs["prompt_token_ids"]
if prompt is None:
prompt = tokenizer.decode(prompt_token_ids)
new_prompt = prompt
image_idx = sorted(map(int, re.findall(r"Image-(\d+): <image>\n", prompt)))
for idx, feature_size in enumerate(image_feature_size, start=1):
image_prompt = IMG_START + IMG_CONTEXT * feature_size + IMG_END
if not image_idx:
image_prompt = f"Image-{idx}: {image_prompt}"
new_prompt = new_prompt.replace('<image>', image_prompt, 1)
new_prompt_token_ids = tokenizer.encode(new_prompt)
new_prompt = self._expand_image_prompt(prompt, image_feature_sizes,
num_patches)
new_prompt_token_ids = tokenizer.encode(new_prompt)
return LLMInputs(prompt=prompt,
prompt_token_ids=new_prompt_token_ids,
multi_modal_data=multi_modal_data)
return LLMInputs(prompt=prompt,
prompt_token_ids=new_prompt_token_ids,
multi_modal_data=multi_modal_data)
def input_mapper(
self,
ctx: InputContext,
data: object,
*,
max_dynamic_patch: Optional[int] = None,
):
hf_config = ctx.get_hf_config()
image_pixel_values_mapper = image_to_pixel_values_wrapper(
hf_config, max_dynamic_patch)
if isinstance(data, Image.Image):
data = image_pixel_values_mapper(data)
# Add an N dimension for number of images per prompt (currently 1).
data = data.unsqueeze(0)
elif is_list_of(data, Image.Image):
# we can't stack here because images may have different num_patches
data = [image_pixel_values_mapper(img) for img in data]
model_config = ctx.model_config
tokenizer = cached_get_tokenizer(
model_config.tokenizer,
trust_remote_code=model_config.trust_remote_code)
image_token_id = tokenizer.encode(self.img_context_token,
add_special_tokens=False,
return_tensors="pt")[0]
return MultiModalInputs({
"pixel_values": data,
"image_token_id": image_token_id
})
def dummy_data(
self,
ctx: InputContext,
seq_len: int,
mm_counts: Mapping[str, int],
*,
max_dynamic_patch: Optional[int] = None,
):
num_images = mm_counts["image"]
hf_config = ctx.get_hf_config()
image_feature_size = get_max_internvl_image_tokens(
ctx, max_dynamic_patch=max_dynamic_patch)
model_config = ctx.model_config
tokenizer = cached_get_tokenizer(
model_config.tokenizer,
trust_remote_code=model_config.trust_remote_code)
seq_data = dummy_seq_data_for_clip(
hf_config.vision_config,
seq_len,
num_images,
image_token_id=tokenizer.encode(self.img_context_token,
add_special_tokens=False)[0],
image_feature_size_override=image_feature_size,
)
max_image_width, max_image_height = get_max_internvl_image_size(
ctx, max_dynamic_patch=max_dynamic_patch)
mm_data = dummy_image_for_clip(
hf_config.vision_config,
num_images,
image_width_override=max_image_width,
image_height_override=max_image_height,
)
return seq_data, mm_data
def input_mapper_for_internvl(ctx: InputContext,
data: object,
*,
max_dynamic_patch: Optional[int] = None):
hf_config = ctx.get_hf_config()
image_pixel_values_mapper = image_to_pixel_values_wrapper(
hf_config, max_dynamic_patch)
if isinstance(data, Image.Image):
data = image_pixel_values_mapper(data)
# Add an N dimension for number of images per prompt (currently 1).
data = data.unsqueeze(0)
elif is_list_of(data, Image.Image):
# we can't stack here because the images may have different num_patches
data = [image_pixel_values_mapper(img) for img in data]
model_config = ctx.model_config
tokenizer = cached_get_tokenizer(
model_config.tokenizer,
trust_remote_code=model_config.trust_remote_code)
image_token_id = tokenizer.encode(IMG_CONTEXT,
add_special_tokens=False,
return_tensors="pt")[0]
return MultiModalInputs({
"pixel_values": data,
"image_token_id": image_token_id
})
input_pipeline = InternVLInputPipeline(IMG_START, IMG_END, IMG_CONTEXT)
def dummy_data_for_internvl(ctx: InputContext,
seq_len: int,
mm_counts: Mapping[str, int],
*,
max_dynamic_patch: Optional[int] = None):
num_images = mm_counts["image"]
hf_config = ctx.get_hf_config()
image_feature_size = get_max_internvl_image_tokens(
ctx, max_dynamic_patch=max_dynamic_patch)
model_config = ctx.model_config
tokenizer = cached_get_tokenizer(
model_config.tokenizer,
trust_remote_code=model_config.trust_remote_code)
seq_data = dummy_seq_data_for_clip(
hf_config.vision_config,
seq_len,
num_images,
image_token_id=tokenizer.encode(IMG_CONTEXT,
add_special_tokens=False)[0],
image_feature_size_override=image_feature_size,
)
max_image_width, max_image_height = get_max_internvl_image_size(
ctx, max_dynamic_patch=max_dynamic_patch)
mm_data = dummy_image_for_clip(
hf_config.vision_config,
num_images,
image_width_override=max_image_width,
image_height_override=max_image_height,
)
return seq_data, mm_data
@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_internvl)
@MULTIMODAL_REGISTRY.register_image_input_mapper(input_pipeline.input_mapper)
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_internvl_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_internvl)
@INPUT_REGISTRY.register_input_processor(input_processor_for_internvl)
@INPUT_REGISTRY.register_dummy_data(input_pipeline.dummy_data)
@INPUT_REGISTRY.register_input_processor(input_pipeline.input_processor)
class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
def __init__(self,
@ -388,20 +431,12 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
+ vision_feature_layer + 1
else:
num_hidden_layers = vision_feature_layer + 1
self.vision_model = InternVisionModel(
config.vision_config, num_hidden_layers_override=num_hidden_layers)
self.vision_model = self._init_vision_model(config, num_hidden_layers)
self.language_model = init_vllm_registered_model(
config.text_config, cache_config, quant_config)
vit_hidden_size = config.vision_config.hidden_size
llm_hidden_size = config.text_config.hidden_size
self.mlp1 = nn.Sequential(
nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio)**2),
nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio)**2,
llm_hidden_size), nn.GELU(),
nn.Linear(llm_hidden_size, llm_hidden_size))
self.mlp1 = self._init_mlp1(config)
self.img_context_token_id = None
self.make_empty_intermediate_tensors = (
@ -414,6 +449,23 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
return Sampler()
def _init_vision_model(self, config: PretrainedConfig,
num_hidden_layers: int):
return InternVisionModel(config.vision_config,
num_hidden_layers_override=num_hidden_layers)
def _init_mlp1(self, config: PretrainedConfig) -> nn.Sequential:
vit_hidden_size = config.vision_config.hidden_size
llm_hidden_size = config.text_config.hidden_size
return nn.Sequential(
nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio)**2),
nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio)**2,
llm_hidden_size),
nn.GELU(),
nn.Linear(llm_hidden_size, llm_hidden_size),
)
def pixel_shuffle(self, x, scale_factor=0.5):
n, w, h, c = x.size()
# N, W, H, C --> N, W, H * scale, C // scale
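
The prompt-expansion step performed by the refactored InternVLInputPipeline can be illustrated with a tiny standalone re-implementation. The token strings are the InternVL placeholders from this file; the helper name is made up for the sketch.

import re
from typing import List

IMG_START, IMG_END, IMG_CONTEXT = "<img>", "</img>", "<IMG_CONTEXT>"


def expand_image_prompt(prompt: str, feature_sizes: List[int]) -> str:
    # Mirrors InternVLInputPipeline._expand_image_prompt for the base model.
    explicit_idx = re.findall(r"Image-(\d+): <image>\n", prompt)
    new_prompt = prompt
    for idx, feature_size in enumerate(feature_sizes, start=1):
        image_prompt = IMG_START + IMG_CONTEXT * feature_size + IMG_END
        if not explicit_idx:
            # Prompts without explicit "Image-N:" tags get one added per image.
            image_prompt = f"Image-{idx}: {image_prompt}"
        new_prompt = new_prompt.replace("<image>", image_prompt, 1)
    return new_prompt


print(expand_image_prompt("<image>\nWhat is shown here?", feature_sizes=[4]))
# Image-1: <img><IMG_CONTEXT><IMG_CONTEXT><IMG_CONTEXT><IMG_CONTEXT></img>
# What is shown here?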

View File

@ -0,0 +1,64 @@
# adapted from https://huggingface.co/nvidia/NVLM-D-72B/blob/main/modeling_nvlm_d.py
# --------------------------------------------------------
# NVLM-D
# Copyright (c) 2024 NVIDIA
# Licensed under Apache 2.0 License [see LICENSE for details]
# --------------------------------------------------------
import torch.nn as nn
from transformers import PretrainedConfig
from vllm.inputs import INPUT_REGISTRY
from vllm.multimodal import MULTIMODAL_REGISTRY
from .intern_vit import InternVisionModel
from .internvl import (InternVLChatModel, InternVLInputPipeline,
get_max_internvl_image_tokens)
IMG_START = '<|vision_start|>'
IMG_END = '<|vision_end|>'
IMG_CONTEXT = '<|vision_pad|>'
class NVLMInputPipeline(InternVLInputPipeline):
def _create_image_prompt(self, feature_size: int, num_patches: int) -> str:
tile_pos_identifiers = ([f"<tile_{i}>"
for i in range(1, num_patches)] +
["<tile_global_thumbnail>"])
context_size = feature_size // num_patches
return '<Image>' + ''.join(
tile_pos_identifier + self.img_context_token * context_size
for tile_pos_identifier in tile_pos_identifiers) + '</Image>'
input_pipeline = NVLMInputPipeline(IMG_START, IMG_END, IMG_CONTEXT)
@MULTIMODAL_REGISTRY.register_image_input_mapper(input_pipeline.input_mapper)
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_internvl_image_tokens)
@INPUT_REGISTRY.register_dummy_data(input_pipeline.dummy_data)
@INPUT_REGISTRY.register_input_processor(input_pipeline.input_processor)
class NVLM_D_Model(InternVLChatModel):
def _init_mlp1(self, config: PretrainedConfig) -> nn.Sequential:
vit_hidden_size = config.vision_config.hidden_size
llm_intermediate_size = config.text_config.intermediate_size
llm_hidden_size = config.text_config.hidden_size
return nn.Sequential(
nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio)**2),
nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio)**2,
llm_intermediate_size,
bias=False),
nn.GELU(),
nn.Linear(llm_intermediate_size, llm_hidden_size, bias=False),
)
def _init_vision_model(self, config: PretrainedConfig,
num_hidden_layers: int):
# We added additional dummy heads to the original num of heads to make
# the number of heads divisible by 8.
return InternVisionModel(config.vision_config,
num_hidden_layers_override=num_hidden_layers,
num_dummy_heads=7)
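
For a concrete picture of the overridden _create_image_prompt, here is a toy-sized sketch: two dynamic tiles plus the global thumbnail, with two context tokens per tile instead of the production per-tile count.

IMG_CONTEXT = "<|vision_pad|>"


def create_nvlm_image_prompt(feature_size: int, num_patches: int) -> str:
    # Mirrors NVLMInputPipeline._create_image_prompt above.
    tile_ids = ([f"<tile_{i}>" for i in range(1, num_patches)] +
                ["<tile_global_thumbnail>"])
    context_size = feature_size // num_patches
    return ("<Image>" +
            "".join(t + IMG_CONTEXT * context_size for t in tile_ids) +
            "</Image>")


print(create_nvlm_image_prompt(feature_size=6, num_patches=3))
# Prints one line:
# <Image><tile_1><|vision_pad|><|vision_pad|><tile_2><|vision_pad|><|vision_pad|><tile_global_thumbnail><|vision_pad|><|vision_pad|></Image>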

View File

@ -16,6 +16,7 @@ from .interfaces_base import is_embedding_model, is_text_generation_model
logger = init_logger(__name__)
# yapf: disable
_TEXT_GENERATION_MODELS = {
# [Decoder-only]
"AquilaModel": ("llama", "LlamaForCausalLM"),
@ -68,8 +69,6 @@ _TEXT_GENERATION_MODELS = {
"PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
"Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
"Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
"Qwen2VLForConditionalGeneration":
("qwen2_vl", "Qwen2VLForConditionalGeneration"),
"RWForCausalLM": ("falcon", "FalconForCausalLM"),
"StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
"StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
@ -88,32 +87,25 @@ _EMBEDDING_MODELS = {
}
_MULTIMODAL_MODELS = {
"Blip2ForConditionalGeneration":
("blip2", "Blip2ForConditionalGeneration"),
"ChameleonForConditionalGeneration":
("chameleon", "ChameleonForConditionalGeneration"),
# [Decoder-only]
"Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
"ChameleonForConditionalGeneration": ("chameleon", "ChameleonForConditionalGeneration"), # noqa: E501
"FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
"InternVLChatModel": ("internvl", "InternVLChatModel"),
"LlavaForConditionalGeneration": ("llava",
"LlavaForConditionalGeneration"),
"LlavaNextForConditionalGeneration": ("llava_next",
"LlavaNextForConditionalGeneration"),
"LlavaNextVideoForConditionalGeneration":
("llava_next_video", "LlavaNextVideoForConditionalGeneration"),
"LlavaOnevisionForConditionalGeneration":
("llava_onevision", "LlavaOnevisionForConditionalGeneration"),
"LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
"LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"), # noqa: E501
"LlavaNextVideoForConditionalGeneration": ("llava_next_video", "LlavaNextVideoForConditionalGeneration"), # noqa: E501
"LlavaOnevisionForConditionalGeneration": ("llava_onevision", "LlavaOnevisionForConditionalGeneration"), # noqa: E501
"MiniCPMV": ("minicpmv", "MiniCPMV"),
"PaliGemmaForConditionalGeneration": ("paligemma",
"PaliGemmaForConditionalGeneration"),
"NVLM_D": ("nvlm_d", "NVLM_D_Model"),
"PaliGemmaForConditionalGeneration": ("paligemma", "PaliGemmaForConditionalGeneration"), # noqa: E501
"Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
"PixtralForConditionalGeneration": ("pixtral",
"PixtralForConditionalGeneration"),
"PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"), # noqa: E501
"QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
"Qwen2VLForConditionalGeneration": ("qwen2_vl",
"Qwen2VLForConditionalGeneration"),
"Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"), # noqa: E501
"UltravoxModel": ("ultravox", "UltravoxModel"),
"MllamaForConditionalGeneration": ("mllama",
"MllamaForConditionalGeneration"),
# [Encoder-decoder]
"MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"), # noqa: E501
}
_SPECULATIVE_DECODING_MODELS = {
@ -121,6 +113,7 @@ _SPECULATIVE_DECODING_MODELS = {
"MedusaModel": ("medusa", "Medusa"),
"MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
}
# yapf: enable
_MODELS = {
**_TEXT_GENERATION_MODELS,

View File

@ -22,9 +22,9 @@ from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig,
InternVLChatConfig, JAISConfig,
MedusaConfig, MllamaConfig,
MLPSpeculatorConfig, MPTConfig,
NemotronConfig, Qwen2VLConfig,
RWConfig, SolarConfig,
UltravoxConfig)
NemotronConfig, NVLM_D_Config,
Qwen2VLConfig, RWConfig,
SolarConfig, UltravoxConfig)
# yapf: enable
from vllm.transformers_utils.utils import check_gguf_file
@ -54,6 +54,7 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
"exaone": ExaoneConfig,
"internvl_chat": InternVLChatConfig,
"nemotron": NemotronConfig,
"NVLM_D": NVLM_D_Config,
"solar": SolarConfig,
"ultravox": UltravoxConfig,
"qwen2_vl": Qwen2VLConfig,

View File

@ -13,6 +13,7 @@ from vllm.transformers_utils.configs.mllama import MllamaConfig
from vllm.transformers_utils.configs.mlp_speculator import MLPSpeculatorConfig
from vllm.transformers_utils.configs.mpt import MPTConfig
from vllm.transformers_utils.configs.nemotron import NemotronConfig
from vllm.transformers_utils.configs.nvlm_d import NVLM_D_Config
from vllm.transformers_utils.configs.qwen2vl import (Qwen2VLConfig,
Qwen2VLVisionConfig)
from vllm.transformers_utils.configs.solar import SolarConfig
@ -31,6 +32,7 @@ __all__ = [
"MllamaConfig",
"MLPSpeculatorConfig",
"NemotronConfig",
"NVLM_D_Config",
"SolarConfig",
"UltravoxConfig",
"Qwen2VLConfig",

View File

@ -0,0 +1,12 @@
# Adapted from
# https://huggingface.co/nvidia/NVLM-D-72B/blob/main/configuration_nvlm_d.py
# --------------------------------------------------------
# NVLM-D
# Copyright (c) 2024 NVIDIA
# Licensed under Apache 2.0 License [see LICENSE for details]
# --------------------------------------------------------
from .internvl import InternVLChatConfig
class NVLM_D_Config(InternVLChatConfig):
model_type = 'NVLM_D'