[Model] Support NVLM-D and fix QK Norm in InternViT (#9045)

Co-authored-by: Roger Wang <ywang@roblox.com>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>

parent f19da64871
commit 151ef4efd2
@@ -315,6 +315,9 @@ Multimodal Language Models

.. _supported_vlms:

Text Generation
---------------

.. list-table::
  :widths: 25 25 25 25 5 5
  :header-rows: 1

@@ -384,7 +387,13 @@ Multimodal Language Models
      - Image
      - :code:`meta-llama/Llama-3.2-90B-Vision-Instruct`, :code:`meta-llama/Llama-3.2-11B-Vision`, etc.
      -
      -
    * - :code:`NVLM_D_Model`
      - NVLM-D 1.0
      - Image\ :sup:`E+`
      - :code:`nvidia/NVLM-D-72B`, etc.
      -
      - ✅︎
    * - :code:`PaliGemmaForConditionalGeneration`
      - PaliGemma
      - Image\ :sup:`E`
@@ -18,7 +18,7 @@ from vllm.utils import FlexibleArgumentParser


# LLaVA-1.5
def run_llava(question, modality):
def run_llava(question: str, modality: str):
    assert modality == "image"

    prompt = f"USER: <image>\n{question}\nASSISTANT:"

@@ -29,7 +29,7 @@ def run_llava(question, modality):


# LLaVA-1.6/LLaVA-NeXT
def run_llava_next(question, modality):
def run_llava_next(question: str, modality: str):
    assert modality == "image"

    prompt = f"[INST] <image>\n{question} [/INST]"

@@ -40,7 +40,7 @@ def run_llava_next(question, modality):

# LlaVA-NeXT-Video
# Currently only support for video input
def run_llava_next_video(question, modality):
def run_llava_next_video(question: str, modality: str):
    assert modality == "video"

    prompt = f"USER: <video>\n{question} ASSISTANT:"

@@ -50,7 +50,7 @@ def run_llava_next_video(question, modality):


# LLaVA-OneVision
def run_llava_onevision(question, modality):
def run_llava_onevision(question: str, modality: str):

    if modality == "video":
        prompt = f"<|im_start|>user <video>\n{question}<|im_end|> \

@@ -67,7 +67,7 @@ def run_llava_onevision(question, modality):


# Fuyu
def run_fuyu(question, modality):
def run_fuyu(question: str, modality: str):
    assert modality == "image"

    prompt = f"{question}\n"

@@ -77,7 +77,7 @@ def run_fuyu(question, modality):


# Phi-3-Vision
def run_phi3v(question, modality):
def run_phi3v(question: str, modality: str):
    assert modality == "image"

    prompt = f"<|user|>\n<|image_1|>\n{question}<|end|>\n<|assistant|>\n"  # noqa: E501

@@ -112,7 +112,7 @@ def run_phi3v(question, modality):


# PaliGemma
def run_paligemma(question, modality):
def run_paligemma(question: str, modality: str):
    assert modality == "image"

    # PaliGemma has special prompt format for VQA

@@ -123,7 +123,7 @@ def run_paligemma(question, modality):


# Chameleon
def run_chameleon(question, modality):
def run_chameleon(question: str, modality: str):
    assert modality == "image"

    prompt = f"{question}<image>"

@@ -133,7 +133,7 @@ def run_chameleon(question, modality):


# MiniCPM-V
def run_minicpmv(question, modality):
def run_minicpmv(question: str, modality: str):
    assert modality == "image"

    # 2.0

@@ -176,7 +176,7 @@ def run_minicpmv(question, modality):


# InternVL
def run_internvl(question, modality):
def run_internvl(question: str, modality: str):
    assert modality == "image"

    model_name = "OpenGVLab/InternVL2-2B"
@@ -203,8 +203,32 @@ def run_internvl(question, modality):
    return llm, prompt, stop_token_ids


# NVLM-D
def run_nvlm_d(question: str, modality: str):
    assert modality == "image"

    model_name = "nvidia/NVLM-D-72B"

    # Adjust this as necessary to fit in GPU
    llm = LLM(
        model=model_name,
        trust_remote_code=True,
        max_model_len=4096,
        tensor_parallel_size=4,
    )

    tokenizer = AutoTokenizer.from_pretrained(model_name,
                                              trust_remote_code=True)
    messages = [{'role': 'user', 'content': f"<image>\n{question}"}]
    prompt = tokenizer.apply_chat_template(messages,
                                           tokenize=False,
                                           add_generation_prompt=True)
    stop_token_ids = None
    return llm, prompt, stop_token_ids


# BLIP-2
def run_blip2(question, modality):
def run_blip2(question: str, modality: str):
    assert modality == "image"

    # BLIP-2 prompt format is inaccurate on HuggingFace model repository.
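Note (not part of the diff): a minimal sketch of how a run_* helper such as run_nvlm_d above is consumed by the example script. The image path and sampling settings are hypothetical assumptions; the generate() call follows vLLM's dict-prompt multimodal interface.

from PIL import Image
from vllm import SamplingParams

# Hypothetical usage; assumes a local image file and the run_nvlm_d helper above.
llm, prompt, stop_token_ids = run_nvlm_d("What is shown here?", "image")
image = Image.open("example.jpg").convert("RGB")

sampling_params = SamplingParams(temperature=0.2,
                                 max_tokens=64,
                                 stop_token_ids=stop_token_ids)
outputs = llm.generate(
    {
        "prompt": prompt,
        "multi_modal_data": {"image": image},
    },
    sampling_params=sampling_params,
)
print(outputs[0].outputs[0].text)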
@@ -216,7 +240,7 @@ def run_blip2(question, modality):


# Qwen
def run_qwen_vl(question, modality):
def run_qwen_vl(question: str, modality: str):
    assert modality == "image"

    llm = LLM(

@@ -232,7 +256,7 @@ def run_qwen_vl(question, modality):


# Qwen2-VL
def run_qwen2_vl(question, modality):
def run_qwen2_vl(question: str, modality: str):
    assert modality == "image"

    model_name = "Qwen/Qwen2-VL-7B-Instruct"

@@ -252,8 +276,8 @@ def run_qwen2_vl(question, modality):
    return llm, prompt, stop_token_ids


# LLama
def run_mllama(question, modality):
# LLama 3.2
def run_mllama(question: str, modality: str):
    assert modality == "image"

    model_name = "meta-llama/Llama-3.2-11B-Vision-Instruct"

@@ -287,6 +311,7 @@ model_example_map = {
    "minicpmv": run_minicpmv,
    "blip-2": run_blip2,
    "internvl_chat": run_internvl,
    "NVLM_D": run_nvlm_d,
    "qwen_vl": run_qwen_vl,
    "qwen2_vl": run_qwen2_vl,
    "mllama": run_mllama,
@@ -144,6 +144,39 @@ def load_internvl(question: str, image_urls: List[str]) -> ModelRequestData:
    )


def load_nvlm_d(question: str, image_urls: List[str]):
    model_name = "nvidia/NVLM-D-72B"

    # Adjust this as necessary to fit in GPU
    llm = LLM(
        model=model_name,
        trust_remote_code=True,
        max_model_len=8192,
        tensor_parallel_size=4,
        limit_mm_per_prompt={"image": len(image_urls)},
        mm_processor_kwargs={"max_dynamic_patch": 4},
    )

    placeholders = "\n".join(f"Image-{i}: <image>\n"
                             for i, _ in enumerate(image_urls, start=1))
    messages = [{'role': 'user', 'content': f"{placeholders}\n{question}"}]

    tokenizer = AutoTokenizer.from_pretrained(model_name,
                                              trust_remote_code=True)
    prompt = tokenizer.apply_chat_template(messages,
                                           tokenize=False,
                                           add_generation_prompt=True)
    stop_token_ids = None

    return ModelRequestData(
        llm=llm,
        prompt=prompt,
        stop_token_ids=stop_token_ids,
        image_data=[fetch_image(url) for url in image_urls],
        chat_template=None,
    )


def load_qwen2_vl(question, image_urls: List[str]) -> ModelRequestData:
    try:
        from qwen_vl_utils import process_vision_info

@@ -204,6 +237,7 @@ def load_qwen2_vl(question, image_urls: List[str]) -> ModelRequestData:
model_example_map = {
    "phi3_v": load_phi3v,
    "internvl_chat": load_internvl,
    "NVLM_D": load_nvlm_d,
    "qwen2_vl": load_qwen2_vl,
    "qwen_vl_chat": load_qwenvl_chat,
}
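Note (not part of the diff): to make the multi-image prompt format concrete, this is what the placeholders string built in load_nvlm_d evaluates to for two hypothetical image URLs.

image_urls = ["https://example.com/a.jpg", "https://example.com/b.jpg"]  # hypothetical
placeholders = "\n".join(f"Image-{i}: <image>\n"
                         for i, _ in enumerate(image_urls, start=1))
print(repr(placeholders))
# 'Image-1: <image>\n\nImage-2: <image>\n'
question = "Which image shows a dog?"
user_content = f"{placeholders}\n{question}"  # passed to the chat template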
@@ -157,7 +157,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
        if model_type.startswith("llava"):
            return self._cached_token_str(self._tokenizer,
                                          hf_config.image_token_index)
        if model_type in ("chameleon", "internvl_chat"):
        if model_type in ("chameleon", "internvl_chat", "NVLM_D"):
            return "<image>"
        if model_type == "mllama":
            return "<|image|>"
@@ -18,10 +18,16 @@ class RMSNorm(CustomOp):
        self,
        hidden_size: int,
        eps: float = 1e-6,
        var_hidden_size: Optional[int] = None,
    ) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))

        self.hidden_size = hidden_size
        self.variance_epsilon = eps
        self.variance_size_override = (None if var_hidden_size == hidden_size
                                       else var_hidden_size)

        self.weight = nn.Parameter(torch.ones(hidden_size))

    def forward_native(
        self,

@@ -35,7 +41,23 @@ class RMSNorm(CustomOp):
            x = x + residual.to(torch.float32)
            residual = x.to(orig_dtype)

        variance = x.pow(2).mean(dim=-1, keepdim=True)
        hidden_size = x.shape[-1]
        if hidden_size != self.hidden_size:
            raise ValueError("Expected hidden_size to be "
                             f"{self.hidden_size}, but found: {hidden_size}")

        if self.variance_size_override is None:
            x_var = x
        else:
            if hidden_size < self.variance_size_override:
                raise ValueError(
                    "Expected hidden_size to be at least "
                    f"{self.variance_size_override}, but found: {hidden_size}")

            x_var = x[:, :, :self.variance_size_override]

        variance = x_var.pow(2).mean(dim=-1, keepdim=True)

        x = x * torch.rsqrt(variance + self.variance_epsilon)
        x = x.to(orig_dtype) * self.weight
        if residual is None:

@@ -48,6 +70,9 @@ class RMSNorm(CustomOp):
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        from vllm import _custom_ops as ops

        if residual is not None:

@@ -72,6 +97,9 @@ class RMSNorm(CustomOp):
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        from vllm._ipex_ops import ipex_ops as ops

        if residual is not None:
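Note (not part of the diff): the new var_hidden_size argument lets a norm layer carry weights for a padded hidden size while computing the RMS statistic over only the original channels. A minimal sketch of the behavior, with illustrative sizes assumed here (a 3200-dim projection padded to 4096 by dummy heads):

import torch
from vllm.model_executor.layers.layernorm import RMSNorm

# Assumed sizes for illustration only: weights cover 4096 channels, but the
# variance is still taken over the first 3200 (the real embed_dim).
norm = RMSNorm(4096, eps=1e-6, var_hidden_size=3200)
x = torch.randn(2, 16, 4096)
out = norm.forward_native(x)
print(out.shape)  # torch.Size([2, 16, 4096])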
@@ -4,6 +4,7 @@
# Copyright (c) 2023 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
from functools import partial
from typing import Iterable, Optional, Tuple

import torch

@@ -11,7 +12,10 @@ import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig

from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size,
                              split_tensor_along_last_dim,
                              tensor_model_parallel_all_gather)
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -54,7 +58,7 @@ class InternVisionEmbeddings(nn.Module):
        self.position_embedding = nn.Parameter(
            torch.randn(1, self.num_positions, self.embed_dim))

    def _get_pos_embed(self, pos_embed, H, W):
    def _get_pos_embed(self, pos_embed: torch.Tensor, H: int, W: int):
        target_dtype = pos_embed.dtype
        pos_embed = pos_embed.float().reshape(
            1, self.image_size // self.patch_size,

@@ -63,9 +67,21 @@ class InternVisionEmbeddings(nn.Module):
            size=(H, W),
            mode='bicubic',
            align_corners=False)
        pos_embed = pos_embed.reshape(1, -1, H * W).permute(0, 2,
                                                            1).to(target_dtype)
        return pos_embed
        return pos_embed.reshape(1, -1, H * W).permute(0, 2,
                                                       1).to(target_dtype)

    def _get_position_embedding(self, H: int, W: int) -> torch.Tensor:
        position_embedding = self.position_embedding
        if self.num_patches == H * W:
            return position_embedding

        return torch.cat(
            [
                position_embedding[:, :1, :],
                self._get_pos_embed(position_embedding[:, 1:, :], H, W),
            ],
            dim=1,
        )

    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
        target_dtype = self.patch_embedding.weight.dtype

@@ -76,12 +92,7 @@ class InternVisionEmbeddings(nn.Module):
        class_embeds = self.class_embedding.expand(batch_size, 1,
                                                   -1).to(target_dtype)
        embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
        position_embedding = torch.cat([
            self.position_embedding[:, :1, :],
            self._get_pos_embed(self.position_embedding[:, 1:, :], height,
                                width)
        ],
                                       dim=1)
        position_embedding = self._get_position_embedding(height, width)
        embeddings = embeddings + position_embedding.to(target_dtype)
        return embeddings
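Note (not part of the diff): _get_pos_embed resizes the learned position grid with bicubic interpolation when the input tiling differs from the pretraining grid. A self-contained sketch of the same operation with made-up sizes:

import torch
import torch.nn.functional as F

# Hypothetical sizes: pretrained on a 32x32 patch grid, resized to 16x24.
embed_dim, old_grid, (H, W) = 64, 32, (16, 24)
pos_embed = torch.randn(1, 1 + old_grid * old_grid, embed_dim)

cls_pos = pos_embed[:, :1, :]                       # class-token position kept as-is
grid_pos = pos_embed[:, 1:, :].reshape(1, old_grid, old_grid,
                                       embed_dim).permute(0, 3, 1, 2)
grid_pos = F.interpolate(grid_pos.float(), size=(H, W), mode='bicubic',
                         align_corners=False)
grid_pos = grid_pos.reshape(1, embed_dim, H * W).permute(0, 2, 1)
resized = torch.cat([cls_pos, grid_pos], dim=1)
print(resized.shape)  # torch.Size([1, 385, 64])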
@@ -93,8 +104,11 @@ class InternParallelAttention(nn.Module):
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        *,
        num_dummy_heads: int = 0,
    ) -> None:
        super().__init__()

        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads

@@ -105,11 +119,19 @@ class InternParallelAttention(nn.Module):
                f'(got `embed_dim`: {self.embed_dim} and `num_heads`:'
                f' {self.num_heads}).')

        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()

        # Additional dummy heads are used to enable TP for common GPU counts.
        self.dummy_dim = (num_dummy_heads + self.num_heads) * self.head_dim
        self.num_heads_per_partition = divide(num_dummy_heads + self.num_heads,
                                              self.tp_size)

        self.scale = self.head_dim**-0.5
        self.qkv = QKVParallelLinear(
            self.embed_dim,
            self.head_dim,
            self.num_heads,
            num_dummy_heads + self.num_heads,
            bias=config.qkv_bias,
            quant_config=quant_config,
        )

@@ -117,34 +139,44 @@ class InternParallelAttention(nn.Module):
        self.qk_normalization = config.qk_normalization

        if self.qk_normalization:
            self.q_norm = RMSNorm(self.embed_dim, eps=config.layer_norm_eps)
            self.k_norm = RMSNorm(self.embed_dim, eps=config.layer_norm_eps)
            self.q_norm = RMSNorm(self.dummy_dim,
                                  eps=config.layer_norm_eps,
                                  var_hidden_size=self.embed_dim)
            self.k_norm = RMSNorm(self.dummy_dim,
                                  eps=config.layer_norm_eps,
                                  var_hidden_size=self.embed_dim)

        self.proj = RowParallelLinear(
            self.embed_dim,
            self.dummy_dim,
            self.embed_dim,
            quant_config=quant_config,
        )

        self.tp_size = get_tensor_model_parallel_world_size()
        self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
    def _apply_qk_norm(self, q: torch.Tensor, k: torch.Tensor):
        if self.tp_size > 1:
            q = tensor_model_parallel_all_gather(q.contiguous())
            k = tensor_model_parallel_all_gather(k.contiguous())
        q = self.q_norm.forward_native(q)
        k = self.k_norm.forward_native(k)
        if self.tp_size > 1:
            splitter = partial(split_tensor_along_last_dim,
                               num_partitions=self.tp_size)
            q = splitter(q)[self.tp_rank]
            k = splitter(k)[self.tp_rank]
        return q, k

    def forward(self, x):
        B, N, C = x.shape
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, N, _ = x.shape
        qkv, _ = self.qkv(x)
        q, k, v = qkv.chunk(3, dim=-1)

        if self.qk_normalization:
            q, k = self._apply_qk_norm(q, k)

        q = q.view(B, N, self.num_heads_per_partition, self.head_dim)
        k = k.view(B, N, self.num_heads_per_partition, self.head_dim)
        v = v.view(B, N, self.num_heads_per_partition, self.head_dim)

        if self.qk_normalization:
            B_, N_, H_, D_ = q.shape
            q = self.q_norm.forward_native(q.flatten(-2,
                                                     -1)).view(B_, N_, H_, D_)
            k = self.k_norm.forward_native(k.flatten(-2,
                                                     -1)).view(B_, N_, H_, D_)

        x = xops.memory_efficient_attention_forward(q, k, v, scale=self.scale)
        x = x.view(B, N, -1)
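Note (not part of the diff): under tensor parallelism each rank only holds a slice of the heads, so the fixed QK norm all-gathers q/k, normalizes over the full padded width with variance restricted to the real embed_dim, and then splits back per rank. A single-process sketch of the same numerics (no distributed ops; the InternViT-6B-like sizes of 25 real heads plus 7 dummy heads of dim 128 are an assumption, not taken from the diff):

import torch

# Single-process emulation of the padded QK norm; illustrative sizes only.
head_dim, real_heads, dummy_heads = 128, 25, 7
embed_dim = real_heads * head_dim                   # 3200: variance is taken over this
dummy_dim = (real_heads + dummy_heads) * head_dim   # 4096: weights cover this

q = torch.randn(1, 10, dummy_dim)                   # q as if gathered from all TP ranks
weight = torch.ones(dummy_dim)

variance = q[..., :embed_dim].pow(2).mean(dim=-1, keepdim=True)
q_normed = q * torch.rsqrt(variance + 1e-6) * weight

# After normalization, each rank keeps only its slice of the heads:
tp_size, tp_rank = 4, 0
q_local = q_normed.chunk(tp_size, dim=-1)[tp_rank]
print(q_local.shape)  # torch.Size([1, 10, 1024])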
@@ -155,8 +187,14 @@ class InternParallelAttention(nn.Module):
class InternSdpaAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: PretrainedConfig):
    def __init__(
        self,
        config: PretrainedConfig,
        *,
        num_dummy_heads: int = 0,
    ) -> None:
        super().__init__()

        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads

@@ -167,20 +205,27 @@ class InternSdpaAttention(nn.Module):
                f'(got `embed_dim`: {self.embed_dim} and `num_heads`:'
                f' {self.num_heads}).')

        # Additional dummy heads are used to enable TP for common GPU counts.
        self.dummy_dim = (num_dummy_heads + self.num_heads) * self.head_dim

        self.scale = self.head_dim**-0.5
        self.qkv = nn.Linear(self.embed_dim,
                             3 * self.embed_dim,
                             3 * self.dummy_dim,
                             bias=config.qkv_bias)

        self.qk_normalization = config.qk_normalization

        if self.qk_normalization:
            self.q_norm = RMSNorm(self.embed_dim, eps=config.layer_norm_eps)
            self.k_norm = RMSNorm(self.embed_dim, eps=config.layer_norm_eps)
            self.q_norm = RMSNorm(self.dummy_dim,
                                  eps=config.layer_norm_eps,
                                  var_hidden_size=self.embed_dim)
            self.k_norm = RMSNorm(self.dummy_dim,
                                  eps=config.layer_norm_eps,
                                  var_hidden_size=self.embed_dim)

        self.proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.proj = nn.Linear(self.dummy_dim, self.embed_dim)

    def forward(self, x):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, N, C = x.shape
        qkv = self.qkv(x)
        q, k, v = qkv.chunk(3, dim=-1)
@@ -233,22 +278,23 @@ class InternMLP(nn.Module):

class InternVisionEncoderLayer(nn.Module):

    def __init__(self,
                 config: PretrainedConfig,
                 quant_config: Optional[QuantizationConfig] = None):
    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        *,
        num_dummy_heads: int = 0,
    ) -> None:
        super().__init__()

        self.embed_dim = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.norm_type = config.norm_type

        # fallback to sdpa attention if tp unavailable
        tp_size = get_tensor_model_parallel_world_size()
        num_heads = config.num_attention_heads
        if USE_XFORMERS_OPS and num_heads % tp_size == 0:
            self.attn = InternParallelAttention(config,
                                                quant_config=quant_config)
        else:
            self.attn = InternSdpaAttention(config)
        self.attn = self._init_attn(config,
                                    quant_config,
                                    num_dummy_heads=num_dummy_heads)

        self.mlp = InternMLP(config, quant_config=quant_config)
        self.norm1 = NORM2FN[self.norm_type](self.embed_dim,
                                             eps=config.layer_norm_eps)

@@ -260,6 +306,24 @@ class InternVisionEncoderLayer(nn.Module):
        self.ls2 = nn.Parameter(config.initializer_factor *
                                torch.ones(self.embed_dim))

    def _init_attn(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig],
        *,
        num_dummy_heads: int,
    ):
        # fallback to sdpa attention if tp unavailable
        tp_size = get_tensor_model_parallel_world_size()
        num_heads = config.num_attention_heads

        if USE_XFORMERS_OPS and (num_heads + num_dummy_heads) % tp_size == 0:
            return InternParallelAttention(config,
                                           quant_config=quant_config,
                                           num_dummy_heads=num_dummy_heads)

        return InternSdpaAttention(config, num_dummy_heads=num_dummy_heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
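Note (not part of the diff): the dummy heads exist purely so the padded head count divides common tensor-parallel sizes. Assuming InternViT-6B's 25 attention heads (a figure taken from its HF config, not from this diff), padding with 7 dummy heads gives 32 heads, which splits evenly across 2, 4, or 8 GPUs. A quick check:

# Quick divisibility check for the padded head count (illustrative only).
num_heads, num_dummy_heads = 25, 7
padded = num_heads + num_dummy_heads
for tp_size in (1, 2, 4, 8):
    assert padded % tp_size == 0
    print(tp_size, padded // tp_size)  # heads per TP rank: 32, 16, 8, 4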
@@ -275,19 +339,27 @@ class InternVisionEncoderLayer(nn.Module):

class InternVisionEncoder(nn.Module):

    def __init__(self,
                 config: PretrainedConfig,
                 quant_config: Optional[QuantizationConfig] = None,
                 num_hidden_layers_override: Optional[int] = None):
    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        *,
        num_hidden_layers_override: Optional[int] = None,
        num_dummy_heads: int = 0,
    ):
        super().__init__()

        self.config = config

        if num_hidden_layers_override is None:
            num_hidden_layers = config.num_hidden_layers
        else:
            num_hidden_layers = num_hidden_layers_override

        self.layers = nn.ModuleList([
            InternVisionEncoderLayer(config=config, quant_config=quant_config)
            InternVisionEncoderLayer(config,
                                     quant_config,
                                     num_dummy_heads=num_dummy_heads)
            for _ in range(num_hidden_layers)
        ])

@@ -302,35 +374,25 @@ class InternVisionEncoder(nn.Module):

class InternVisionModel(nn.Module):

    def __init__(self,
                 config: PretrainedConfig,
                 quant_config: Optional[QuantizationConfig] = None,
                 num_hidden_layers_override: Optional[int] = None):
    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        *,
        num_hidden_layers_override: Optional[int] = None,
        num_dummy_heads: int = 0,
    ):
        super().__init__()

        self.config = config

        self.embeddings = InternVisionEmbeddings(config)
        self.encoder = InternVisionEncoder(
            config=config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers_override)

    def resize_pos_embeddings(self, old_size, new_size, patch_size):
        pos_emb = self.embeddings.position_embedding
        _, num_positions, embed_dim = pos_emb.shape
        cls_emb = pos_emb[:, :1, :]
        pos_emb = pos_emb[:, 1:, :].reshape(1, old_size // patch_size,
                                            old_size // patch_size,
                                            -1).permute(0, 3, 1, 2)
        pos_emb = F.interpolate(pos_emb.float(),
                                size=new_size // patch_size,
                                mode='bicubic',
                                align_corners=False)
        pos_emb = pos_emb.to(cls_emb.dtype).reshape(1, embed_dim,
                                                    -1).permute(0, 2, 1)
        pos_emb = torch.cat([cls_emb, pos_emb], dim=1)
        self.embeddings.position_embedding = nn.Parameter(pos_emb)
        self.embeddings.image_size = new_size
            num_hidden_layers_override=num_hidden_layers_override,
            num_dummy_heads=num_dummy_heads,
        )

    def get_input_embeddings(self):
        return self.embeddings
@@ -237,130 +237,173 @@ def get_max_internvl_image_size(ctx: InputContext,
    return width, height


def input_processor_for_internvl(ctx: InputContext,
                                 llm_inputs: LLMInputs,
                                 *,
                                 max_dynamic_patch: Optional[int] = None):
    multi_modal_data = llm_inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return llm_inputs
class InternVLInputPipeline:

    model_config = ctx.model_config
    hf_config = ctx.get_hf_config()
    def __init__(
        self,
        img_start_token: str,
        img_end_token: str,
        img_context_token: str,
    ) -> None:
        super().__init__()

    image_data = multi_modal_data["image"]
    num_patches = get_internvl_num_patches(hf_config)
    num_blocks_calculator = calculate_num_blocks_wrapper(
        hf_config, max_dynamic_patch)
    if isinstance(image_data, Image.Image):
        width, height = image_data.size
        num_blocks, _, _ = num_blocks_calculator(width, height)
        image_feature_size = [num_blocks * num_patches]
    elif is_list_of(image_data, Image.Image):
        image_feature_size = []
        for image in image_data:
            width, height = image.size
        self.img_start_token = img_start_token
        self.img_end_token = img_end_token
        self.img_context_token = img_context_token

    def _create_image_prompt(self, feature_size: int, num_patches: int) -> str:
        return (self.img_start_token + self.img_context_token * feature_size +
                self.img_end_token)

    def _expand_image_prompt(
        self,
        prompt: str,
        feature_sizes: List[int],
        num_patches: int,
    ) -> str:
        image_idx = sorted(
            map(int, re.findall(r"Image-(\d+): <image>\n", prompt)))

        new_prompt = prompt
        for idx, feature_size in enumerate(feature_sizes, start=1):
            image_prompt = self._create_image_prompt(feature_size, num_patches)
            if not image_idx:
                image_prompt = f"Image-{idx}: {image_prompt}"

            new_prompt = new_prompt.replace('<image>', image_prompt, 1)

        return new_prompt

    def input_processor(
        self,
        ctx: InputContext,
        llm_inputs: LLMInputs,
        *,
        max_dynamic_patch: Optional[int] = None,
    ) -> LLMInputs:
        multi_modal_data = llm_inputs.get("multi_modal_data")
        if multi_modal_data is None or "image" not in multi_modal_data:
            return llm_inputs

        model_config = ctx.model_config
        hf_config = ctx.get_hf_config()

        image_data = multi_modal_data["image"]
        num_patches = get_internvl_num_patches(hf_config)
        num_blocks_calculator = calculate_num_blocks_wrapper(
            hf_config, max_dynamic_patch)
        if isinstance(image_data, Image.Image):
            width, height = image_data.size
            num_blocks, _, _ = num_blocks_calculator(width, height)
            image_feature_size.append(num_blocks * num_patches)
    elif isinstance(image_data, torch.Tensor):
        num_images, image_feature_size, hidden_size = image_data.shape
    else:
        raise TypeError(f"Invalid image type: {type(image_data)}")
            image_feature_sizes = [num_blocks * num_patches]
        elif is_list_of(image_data, Image.Image):
            image_feature_sizes = []
            for image in image_data:
                width, height = image.size
                num_blocks, _, _ = num_blocks_calculator(width, height)
                image_feature_sizes.append(num_blocks * num_patches)
        elif isinstance(image_data, torch.Tensor):
            num_images, image_feature_size, hidden_size = image_data.shape
            image_feature_sizes = [image_feature_size]
        else:
            raise TypeError(f"Invalid image type: {type(image_data)}")

    tokenizer = cached_get_tokenizer(
        model_config.tokenizer,
        trust_remote_code=model_config.trust_remote_code)
        tokenizer = cached_get_tokenizer(
            model_config.tokenizer,
            trust_remote_code=model_config.trust_remote_code)

    prompt = llm_inputs.get("prompt")
    prompt_token_ids = llm_inputs["prompt_token_ids"]
    if prompt is None:
        prompt = tokenizer.decode(prompt_token_ids)
        prompt = llm_inputs.get("prompt")
        prompt_token_ids = llm_inputs["prompt_token_ids"]
        if prompt is None:
            prompt = tokenizer.decode(prompt_token_ids)

    new_prompt = prompt
    image_idx = sorted(map(int, re.findall(r"Image-(\d+): <image>\n", prompt)))
    for idx, feature_size in enumerate(image_feature_size, start=1):
        image_prompt = IMG_START + IMG_CONTEXT * feature_size + IMG_END
        if not image_idx:
            image_prompt = f"Image-{idx}: {image_prompt}"
        new_prompt = new_prompt.replace('<image>', image_prompt, 1)
    new_prompt_token_ids = tokenizer.encode(new_prompt)
        new_prompt = self._expand_image_prompt(prompt, image_feature_sizes,
                                               num_patches)
        new_prompt_token_ids = tokenizer.encode(new_prompt)

    return LLMInputs(prompt=prompt,
                     prompt_token_ids=new_prompt_token_ids,
                     multi_modal_data=multi_modal_data)
        return LLMInputs(prompt=prompt,
                         prompt_token_ids=new_prompt_token_ids,
                         multi_modal_data=multi_modal_data)

    def input_mapper(
        self,
        ctx: InputContext,
        data: object,
        *,
        max_dynamic_patch: Optional[int] = None,
    ):
        hf_config = ctx.get_hf_config()

        image_pixel_values_mapper = image_to_pixel_values_wrapper(
            hf_config, max_dynamic_patch)
        if isinstance(data, Image.Image):
            data = image_pixel_values_mapper(data)
            # Add an N dimension for number of images per prompt (currently 1).
            data = data.unsqueeze(0)
        elif is_list_of(data, Image.Image):
            # we can't stack here because images may have different num_patches
            data = [image_pixel_values_mapper(img) for img in data]
        model_config = ctx.model_config
        tokenizer = cached_get_tokenizer(
            model_config.tokenizer,
            trust_remote_code=model_config.trust_remote_code)
        image_token_id = tokenizer.encode(self.img_context_token,
                                          add_special_tokens=False,
                                          return_tensors="pt")[0]

        return MultiModalInputs({
            "pixel_values": data,
            "image_token_id": image_token_id
        })
    def dummy_data(
        self,
        ctx: InputContext,
        seq_len: int,
        mm_counts: Mapping[str, int],
        *,
        max_dynamic_patch: Optional[int] = None,
    ):
        num_images = mm_counts["image"]

        hf_config = ctx.get_hf_config()

        image_feature_size = get_max_internvl_image_tokens(
            ctx, max_dynamic_patch=max_dynamic_patch)
        model_config = ctx.model_config
        tokenizer = cached_get_tokenizer(
            model_config.tokenizer,
            trust_remote_code=model_config.trust_remote_code)

        seq_data = dummy_seq_data_for_clip(
            hf_config.vision_config,
            seq_len,
            num_images,
            image_token_id=tokenizer.encode(self.img_context_token,
                                            add_special_tokens=False)[0],
            image_feature_size_override=image_feature_size,
        )

        max_image_width, max_image_height = get_max_internvl_image_size(
            ctx, max_dynamic_patch=max_dynamic_patch)

        mm_data = dummy_image_for_clip(
            hf_config.vision_config,
            num_images,
            image_width_override=max_image_width,
            image_height_override=max_image_height,
        )

        return seq_data, mm_data


def input_mapper_for_internvl(ctx: InputContext,
                              data: object,
                              *,
                              max_dynamic_patch: Optional[int] = None):
    hf_config = ctx.get_hf_config()

    image_pixel_values_mapper = image_to_pixel_values_wrapper(
        hf_config, max_dynamic_patch)
    if isinstance(data, Image.Image):
        data = image_pixel_values_mapper(data)
        # Add an N dimension for number of images per prompt (currently 1).
        data = data.unsqueeze(0)
    elif is_list_of(data, Image.Image):
        # we can't stack here because the images may have different num_patches
        data = [image_pixel_values_mapper(img) for img in data]
    model_config = ctx.model_config
    tokenizer = cached_get_tokenizer(
        model_config.tokenizer,
        trust_remote_code=model_config.trust_remote_code)
    image_token_id = tokenizer.encode(IMG_CONTEXT,
                                      add_special_tokens=False,
                                      return_tensors="pt")[0]

    return MultiModalInputs({
        "pixel_values": data,
        "image_token_id": image_token_id
    })
input_pipeline = InternVLInputPipeline(IMG_START, IMG_END, IMG_CONTEXT)


def dummy_data_for_internvl(ctx: InputContext,
                            seq_len: int,
                            mm_counts: Mapping[str, int],
                            *,
                            max_dynamic_patch: Optional[int] = None):
    num_images = mm_counts["image"]

    hf_config = ctx.get_hf_config()

    image_feature_size = get_max_internvl_image_tokens(
        ctx, max_dynamic_patch=max_dynamic_patch)
    model_config = ctx.model_config
    tokenizer = cached_get_tokenizer(
        model_config.tokenizer,
        trust_remote_code=model_config.trust_remote_code)

    seq_data = dummy_seq_data_for_clip(
        hf_config.vision_config,
        seq_len,
        num_images,
        image_token_id=tokenizer.encode(IMG_CONTEXT,
                                        add_special_tokens=False)[0],
        image_feature_size_override=image_feature_size,
    )

    max_image_width, max_image_height = get_max_internvl_image_size(
        ctx, max_dynamic_patch=max_dynamic_patch)

    mm_data = dummy_image_for_clip(
        hf_config.vision_config,
        num_images,
        image_width_override=max_image_width,
        image_height_override=max_image_height,
    )

    return seq_data, mm_data


@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_internvl)
@MULTIMODAL_REGISTRY.register_image_input_mapper(input_pipeline.input_mapper)
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_internvl_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_internvl)
@INPUT_REGISTRY.register_input_processor(input_processor_for_internvl)
@INPUT_REGISTRY.register_dummy_data(input_pipeline.dummy_data)
@INPUT_REGISTRY.register_input_processor(input_pipeline.input_processor)
class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):

    def __init__(self,
@@ -388,20 +431,12 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
                + vision_feature_layer + 1
        else:
            num_hidden_layers = vision_feature_layer + 1
        self.vision_model = InternVisionModel(
            config.vision_config, num_hidden_layers_override=num_hidden_layers)
        self.vision_model = self._init_vision_model(config, num_hidden_layers)

        self.language_model = init_vllm_registered_model(
            config.text_config, cache_config, quant_config)

        vit_hidden_size = config.vision_config.hidden_size
        llm_hidden_size = config.text_config.hidden_size

        self.mlp1 = nn.Sequential(
            nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio)**2),
            nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio)**2,
                      llm_hidden_size), nn.GELU(),
            nn.Linear(llm_hidden_size, llm_hidden_size))
        self.mlp1 = self._init_mlp1(config)

        self.img_context_token_id = None
        self.make_empty_intermediate_tensors = (

@@ -414,6 +449,23 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):

        return Sampler()

    def _init_vision_model(self, config: PretrainedConfig,
                           num_hidden_layers: int):
        return InternVisionModel(config.vision_config,
                                 num_hidden_layers_override=num_hidden_layers)

    def _init_mlp1(self, config: PretrainedConfig) -> nn.Sequential:
        vit_hidden_size = config.vision_config.hidden_size
        llm_hidden_size = config.text_config.hidden_size

        return nn.Sequential(
            nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio)**2),
            nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio)**2,
                      llm_hidden_size),
            nn.GELU(),
            nn.Linear(llm_hidden_size, llm_hidden_size),
        )

    def pixel_shuffle(self, x, scale_factor=0.5):
        n, w, h, c = x.size()
        # N, W, H, C --> N, W, H * scale, C // scale
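Note (not part of the diff): for intuition about what InternVLInputPipeline.input_processor does to a prompt, here is a standalone re-run of the "<image>" expansion with made-up placeholder tokens and tiny feature sizes; the real tokens and sizes are model-specific.

import re

# Hypothetical placeholder tokens and feature sizes, chosen for readability.
IMG_START, IMG_END, IMG_CONTEXT = "<img>", "</img>", "<IMG_CONTEXT>"
prompt = "Image-1: <image>\nImage-2: <image>\nCompare the two images."
feature_sizes = [4, 8]   # real values come from num_blocks * num_patches

image_idx = sorted(map(int, re.findall(r"Image-(\d+): <image>\n", prompt)))
new_prompt = prompt
for idx, feature_size in enumerate(feature_sizes, start=1):
    image_prompt = IMG_START + IMG_CONTEXT * feature_size + IMG_END
    if not image_idx:   # prompt had no explicit "Image-N:" tags
        image_prompt = f"Image-{idx}: {image_prompt}"
    new_prompt = new_prompt.replace('<image>', image_prompt, 1)

print(new_prompt)
# Each "<image>" is replaced in order by
# IMG_START + IMG_CONTEXT * feature_size + IMG_END for its image.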
vllm/model_executor/models/nvlm_d.py (new file, 64 lines)
@@ -0,0 +1,64 @@
# adapted from https://huggingface.co/nvidia/NVLM-D-72B/blob/main/modeling_nvlm_d.py
# --------------------------------------------------------
# NVLM-D
# Copyright (c) 2024 NVIDIA
# Licensed under Apache 2.0 License [see LICENSE for details]
# --------------------------------------------------------
import torch.nn as nn
from transformers import PretrainedConfig

from vllm.inputs import INPUT_REGISTRY
from vllm.multimodal import MULTIMODAL_REGISTRY

from .intern_vit import InternVisionModel
from .internvl import (InternVLChatModel, InternVLInputPipeline,
                       get_max_internvl_image_tokens)

IMG_START = '<|vision_start|>'
IMG_END = '<|vision_end|>'
IMG_CONTEXT = '<|vision_pad|>'


class NVLMInputPipeline(InternVLInputPipeline):

    def _create_image_prompt(self, feature_size: int, num_patches: int) -> str:
        tile_pos_identifiers = ([f"<tile_{i}>"
                                 for i in range(1, num_patches)] +
                                ["<tile_global_thumbnail>"])
        context_size = feature_size // num_patches

        return '<Image>' + ''.join(
            tile_pos_identifier + self.img_context_token * context_size
            for tile_pos_identifier in tile_pos_identifiers) + '</Image>'


input_pipeline = NVLMInputPipeline(IMG_START, IMG_END, IMG_CONTEXT)


@MULTIMODAL_REGISTRY.register_image_input_mapper(input_pipeline.input_mapper)
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_internvl_image_tokens)
@INPUT_REGISTRY.register_dummy_data(input_pipeline.dummy_data)
@INPUT_REGISTRY.register_input_processor(input_pipeline.input_processor)
class NVLM_D_Model(InternVLChatModel):

    def _init_mlp1(self, config: PretrainedConfig) -> nn.Sequential:
        vit_hidden_size = config.vision_config.hidden_size
        llm_intermediate_size = config.text_config.intermediate_size
        llm_hidden_size = config.text_config.hidden_size

        return nn.Sequential(
            nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio)**2),
            nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio)**2,
                      llm_intermediate_size,
                      bias=False),
            nn.GELU(),
            nn.Linear(llm_intermediate_size, llm_hidden_size, bias=False),
        )

    def _init_vision_model(self, config: PretrainedConfig,
                           num_hidden_layers: int):
        # We added additional dummy heads to the original num of heads to make
        # the number of heads divisible by 8.
        return InternVisionModel(config.vision_config,
                                 num_hidden_layers_override=num_hidden_layers,
                                 num_dummy_heads=7)
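Note (not part of the diff): a standalone re-run of the tile-tagged prompt construction above with small, made-up numbers, purely to show the resulting string shape (real values depend on the image tiling):

# Standalone illustration of NVLM's tile-tagged image prompt (toy numbers).
img_context_token = '<|vision_pad|>'
feature_size, num_patches = 12, 3   # hypothetical values for illustration

tile_pos_identifiers = ([f"<tile_{i}>" for i in range(1, num_patches)] +
                        ["<tile_global_thumbnail>"])
context_size = feature_size // num_patches   # 4 context tokens per tile tag

prompt = '<Image>' + ''.join(
    tag + img_context_token * context_size
    for tag in tile_pos_identifiers) + '</Image>'
print(prompt)
# '<Image><tile_1>' + 4 pad tokens + '<tile_2>' + 4 pad tokens +
# '<tile_global_thumbnail>' + 4 pad tokens + '</Image>'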
@@ -16,6 +16,7 @@ from .interfaces_base import is_embedding_model, is_text_generation_model

logger = init_logger(__name__)

# yapf: disable
_TEXT_GENERATION_MODELS = {
    # [Decoder-only]
    "AquilaModel": ("llama", "LlamaForCausalLM"),

@@ -68,8 +69,6 @@ _TEXT_GENERATION_MODELS = {
    "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
    "Qwen2VLForConditionalGeneration":
    ("qwen2_vl", "Qwen2VLForConditionalGeneration"),
    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
    "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),

@@ -88,32 +87,25 @@ _EMBEDDING_MODELS = {
}

_MULTIMODAL_MODELS = {
    "Blip2ForConditionalGeneration":
    ("blip2", "Blip2ForConditionalGeneration"),
    "ChameleonForConditionalGeneration":
    ("chameleon", "ChameleonForConditionalGeneration"),
    # [Decoder-only]
    "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
    "ChameleonForConditionalGeneration": ("chameleon", "ChameleonForConditionalGeneration"),  # noqa: E501
    "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
    "InternVLChatModel": ("internvl", "InternVLChatModel"),
    "LlavaForConditionalGeneration": ("llava",
                                      "LlavaForConditionalGeneration"),
    "LlavaNextForConditionalGeneration": ("llava_next",
                                          "LlavaNextForConditionalGeneration"),
    "LlavaNextVideoForConditionalGeneration":
    ("llava_next_video", "LlavaNextVideoForConditionalGeneration"),
    "LlavaOnevisionForConditionalGeneration":
    ("llava_onevision", "LlavaOnevisionForConditionalGeneration"),
    "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
    "LlavaNextVideoForConditionalGeneration": ("llava_next_video", "LlavaNextVideoForConditionalGeneration"),  # noqa: E501
    "LlavaOnevisionForConditionalGeneration": ("llava_onevision", "LlavaOnevisionForConditionalGeneration"),  # noqa: E501
    "MiniCPMV": ("minicpmv", "MiniCPMV"),
    "PaliGemmaForConditionalGeneration": ("paligemma",
                                          "PaliGemmaForConditionalGeneration"),
    "NVLM_D": ("nvlm_d", "NVLM_D_Model"),
    "PaliGemmaForConditionalGeneration": ("paligemma", "PaliGemmaForConditionalGeneration"),  # noqa: E501
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "PixtralForConditionalGeneration": ("pixtral",
                                        "PixtralForConditionalGeneration"),
    "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"),  # noqa: E501
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
    "Qwen2VLForConditionalGeneration": ("qwen2_vl",
                                        "Qwen2VLForConditionalGeneration"),
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
    "UltravoxModel": ("ultravox", "UltravoxModel"),
    "MllamaForConditionalGeneration": ("mllama",
                                       "MllamaForConditionalGeneration"),
    # [Encoder-decoder]
    "MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"),  # noqa: E501
}

_SPECULATIVE_DECODING_MODELS = {

@@ -121,6 +113,7 @@ _SPECULATIVE_DECODING_MODELS = {
    "MedusaModel": ("medusa", "Medusa"),
    "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
}
# yapf: enable

_MODELS = {
    **_TEXT_GENERATION_MODELS,
@@ -22,9 +22,9 @@ from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig,
                                             InternVLChatConfig, JAISConfig,
                                             MedusaConfig, MllamaConfig,
                                             MLPSpeculatorConfig, MPTConfig,
                                             NemotronConfig, Qwen2VLConfig,
                                             RWConfig, SolarConfig,
                                             UltravoxConfig)
                                             NemotronConfig, NVLM_D_Config,
                                             Qwen2VLConfig, RWConfig,
                                             SolarConfig, UltravoxConfig)
# yapf: enable
from vllm.transformers_utils.utils import check_gguf_file

@@ -54,6 +54,7 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
    "exaone": ExaoneConfig,
    "internvl_chat": InternVLChatConfig,
    "nemotron": NemotronConfig,
    "NVLM_D": NVLM_D_Config,
    "solar": SolarConfig,
    "ultravox": UltravoxConfig,
    "qwen2_vl": Qwen2VLConfig,

@@ -13,6 +13,7 @@ from vllm.transformers_utils.configs.mllama import MllamaConfig
from vllm.transformers_utils.configs.mlp_speculator import MLPSpeculatorConfig
from vllm.transformers_utils.configs.mpt import MPTConfig
from vllm.transformers_utils.configs.nemotron import NemotronConfig
from vllm.transformers_utils.configs.nvlm_d import NVLM_D_Config
from vllm.transformers_utils.configs.qwen2vl import (Qwen2VLConfig,
                                                     Qwen2VLVisionConfig)
from vllm.transformers_utils.configs.solar import SolarConfig

@@ -31,6 +32,7 @@ __all__ = [
    "MllamaConfig",
    "MLPSpeculatorConfig",
    "NemotronConfig",
    "NVLM_D_Config",
    "SolarConfig",
    "UltravoxConfig",
    "Qwen2VLConfig",
vllm/transformers_utils/configs/nvlm_d.py (new file, 12 lines)
@@ -0,0 +1,12 @@
# Adapted from
# https://huggingface.co/nvidia/NVLM-D-72B/blob/main/configuration_nvlm_d.py
# --------------------------------------------------------
# NVLM-D
# Copyright (c) 2024 NVIDIA
# Licensed under Apache 2.0 License [see LICENSE for details]
# --------------------------------------------------------
from .internvl import InternVLChatConfig


class NVLM_D_Config(InternVLChatConfig):
    model_type = 'NVLM_D'