[MODEL ADDITION] Ovis2 Model Addition (#15826)

Signed-off-by: Marco <121761685+mlinmg@users.noreply.github.com>
Signed-off-by: Isotr0py <2037008807@qq.com>
Signed-off-by: isotr0py <2037008807@qq.com>
Co-authored-by: Isotr0py <2037008807@qq.com>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>

This commit is contained in:
parent be633fba0f
commit 54072f315f
@@ -1014,6 +1014,13 @@ See [this page](#generative-models) for more information on how to use generative models.
   *
   * ✅︎
   * ✅︎
+- * `Ovis2ForConditionalGeneration`<sup>^</sup>
+  * Ovis2
+  * T + I<sup>+</sup>
+  * `AIDC-AI/Ovis2-1B`, `AIDC-AI/Ovis2-2B`, etc.
+  *
+  *
+  * ✅︎
 - * `PaliGemmaForConditionalGeneration`
   * PaliGemma, PaliGemma 2
   * T + I<sup>E</sup>
@@ -725,6 +725,36 @@ def run_nvlm_d(questions: list[str], modality: str) -> ModelRequestData:
     )


+# Ovis2
+def run_ovis2(questions: list[str], modality: str) -> ModelRequestData:
+    assert modality == "image"
+
+    model_name = "AIDC-AI/Ovis2-1B"
+    tokenizer = "Isotr0py/Ovis2-tokenizer"
+
+    engine_args = EngineArgs(
+        model=model_name,
+        tokenizer=tokenizer,
+        max_model_len=4096,
+        max_num_seqs=2,
+        trust_remote_code=True,
+        dtype="half",
+        hf_overrides={"architectures": ["Ovis2ForConditionalGeneration"]},
+        limit_mm_per_prompt={"image": 1},
+    )
+
+    placeholder = "<image>\n"
+    prompts = [("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
+                f"<|im_start|>user\n{placeholder}"
+                f"{question}<|im_end|>\n"
+                "<|im_start|>assistant\n") for question in questions]
+
+    return ModelRequestData(
+        engine_args=engine_args,
+        prompts=prompts,
+    )
+
+
 # PaliGemma
 def run_paligemma(questions: list[str], modality: str) -> ModelRequestData:
     assert modality == "image"
@@ -1041,6 +1071,7 @@ model_example_map = {
     "llama4": run_llama4,
     "molmo": run_molmo,
     "NVLM_D": run_nvlm_d,
+    "ovis2": run_ovis2,
     "paligemma": run_paligemma,
     "paligemma2": run_paligemma2,
     "phi3_v": run_phi3v,
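For readers who want to try the new entry end-to-end, here is a minimal usage sketch (not part of the diff) of driving the same settings through vLLM's `LLM` API. The engine arguments and prompt template are copied from `run_ovis2` above; the image path and question are placeholders you would replace.

from vllm import LLM, SamplingParams
from PIL import Image

# Settings mirror run_ovis2() above; adjust the image path for your setup.
llm = LLM(
    model="AIDC-AI/Ovis2-1B",
    tokenizer="Isotr0py/Ovis2-tokenizer",
    trust_remote_code=True,
    dtype="half",
    max_model_len=4096,
    hf_overrides={"architectures": ["Ovis2ForConditionalGeneration"]},
    limit_mm_per_prompt={"image": 1},
)

question = "What is shown in this image?"  # placeholder question
prompt = ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
          f"<|im_start|>user\n<image>\n{question}<|im_end|>\n"
          "<|im_start|>assistant\n")

outputs = llm.generate(
    {"prompt": prompt,
     "multi_modal_data": {"image": Image.open("example.jpg")}},  # placeholder image
    SamplingParams(max_tokens=64),
)
print(outputs[0].outputs[0].text)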
@@ -436,6 +436,36 @@ def load_nvlm_d(question: str, image_urls: list[str]) -> ModelRequestData:
     )


+# Ovis2
+def load_ovis2(question: str, image_urls: list[str]) -> ModelRequestData:
+    model_name = "AIDC-AI/Ovis2-1B"
+    tokenizer = "Isotr0py/Ovis2-tokenizer"
+
+    engine_args = EngineArgs(
+        model=model_name,
+        tokenizer=tokenizer,
+        max_model_len=8192,
+        max_num_seqs=2,
+        trust_remote_code=True,
+        dtype="half",
+        limit_mm_per_prompt={"image": len(image_urls)},
+        hf_overrides={"architectures": ["Ovis2ForConditionalGeneration"]},
+    )
+
+    placeholder = '\n'.join(
+        [f'Image {i+1}: <image>' for i in range(len(image_urls))]) + '\n'
+    prompt = ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
+              f"<|im_start|>user\n{placeholder}"
+              f"{question}<|im_end|>\n"
+              "<|im_start|>assistant\n")
+
+    return ModelRequestData(
+        engine_args=engine_args,
+        prompt=prompt,
+        image_data=[fetch_image(url) for url in image_urls],
+    )
+
+
 def load_pixtral_hf(question: str, image_urls: list[str]) -> ModelRequestData:
     model_name = "mistral-community/pixtral-12b"
@@ -685,6 +715,7 @@ model_example_map = {
     "mistral3": load_mistral3,
     "mllama": load_mllama,
     "NVLM_D": load_nvlm_d,
+    "ovis2": load_ovis2,
     "phi3_v": load_phi3v,
     "phi4_mm": load_phi4mm,
     "pixtral_hf": load_pixtral_hf,
@@ -467,6 +467,18 @@ VLM_TEST_SETTINGS = {
         max_num_seqs=2,
         patch_hf_runner=model_utils.molmo_patch_hf_runner,
     ),
+    "ovis2": VLMTestInfo(
+        models=["AIDC-AI/Ovis2-1B"],
+        test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
+        prompt_formatter=lambda img_prompt: f"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n",  # noqa: E501
+        img_idx_to_prompt=lambda idx: "<image>\n",  # noqa: E501
+        max_model_len=4096,
+        max_num_seqs=2,
+        dtype="half",
+        # use sdpa mode for the hf runner since ovis2 doesn't work with flash_attn
+        hf_model_kwargs={"llm_attn_implementation": "sdpa"},
+        patch_hf_runner=model_utils.ovis2_patch_hf_runner,
+    ),
     "phi3v": VLMTestInfo(
         models=["microsoft/Phi-3.5-vision-instruct"],
         test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
@@ -67,7 +67,7 @@ def run_test(
         "disable_mm_preprocessor_cache": True,
     }
     if model_info.tokenizer:
-        vllm_runner_kwargs_["tokenizer"] = model_info.tokenizer
+        vllm_runner_kwargs_["tokenizer_name"] = model_info.tokenizer
     if model_info.tokenizer_mode:
         vllm_runner_kwargs_["tokenizer_mode"] = model_info.tokenizer_mode
     if model_info.hf_overrides:
@@ -676,3 +676,33 @@ def molmo_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
     hf_model.model.generate = types.MethodType(_generate, hf_model.model)

     return hf_model
+
+
+def ovis2_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
+    """Patches and returns an instance of the HfRunner to use for Ovis2."""
+    hf_model.model.visual_tokenizer.to(hf_model.dtype)
+    hf_model.model.vte.to(hf_model.dtype)
+    hf_model.model.llm.to(hf_model.dtype)
+
+    hf_model.model.get_output_embeddings = lambda: \
+        hf_model.model.llm.get_output_embeddings()
+
+    def processor(*args, text="", images=None, **kwargs):
+        text_tokenizer = hf_model.model.get_text_tokenizer()
+        images = [images] if isinstance(images, Image) else images
+
+        text = text.split("<|im_start|>user\n")[1].split("<|im_end|>\n")[0]
+
+        prompt, input_ids, pixel_values = hf_model.model.preprocess_inputs(
+            text_or_conversations=text, images=images)
+        attention_mask = torch.ne(input_ids, text_tokenizer.pad_token_id)
+
+        inputs = {
+            "inputs": input_ids.unsqueeze(0),
+            "pixel_values": pixel_values.unsqueeze(0),
+            "attention_mask": attention_mask.unsqueeze(0),
+        }
+        return BatchFeature(data=inputs, tensor_type="pt")
+
+    hf_model.processor = processor
+    return hf_model
@@ -274,6 +274,7 @@ def _test_processing_correctness_mistral(
         "allenai/Molmo-7B-D-0924",
         "allenai/Molmo-7B-O-0924",
         "nvidia/NVLM-D-72B",
+        "AIDC-AI/Ovis2-1B",
         "google/paligemma-3b-mix-224",
         "google/paligemma2-3b-ft-docci-448",
         "microsoft/Phi-4-multimodal-instruct",
@@ -348,6 +348,10 @@ _MULTIMODAL_EXAMPLE_MODELS = {
                              max_transformers_version="4.48",
                              transformers_version_reason="Use of deprecated imports which have been removed.",  # noqa: E501
                              extras={"phi3.5": "microsoft/Phi-3.5-vision-instruct"}),  # noqa: E501
+    "Ovis2ForConditionalGeneration": _HfExamplesInfo("AIDC-AI/Ovis2-1B",
+                             tokenizer="Isotr0py/Ovis2-tokenizer",
+                             trust_remote_code=True,
+                             hf_overrides={"architectures": ["Ovis2ForConditionalGeneration"]}),  # noqa: E501
     "Phi4MMForCausalLM": _HfExamplesInfo("microsoft/Phi-4-multimodal-instruct",
                              trust_remote_code=True),
     "PixtralForConditionalGeneration": _HfExamplesInfo("mistralai/Pixtral-12B-2409",  # noqa: E501
@@ -496,9 +496,10 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
         if model_type.startswith("llava"):
             return self._cached_token_str(self._tokenizer,
                                           hf_config.image_token_index)

         if model_type in ("aya_vision", "chameleon", "deepseek_vl_v2",
-                          "internvl_chat", "skywork_chat", "NVLM_D",
-                          "h2ovl_chat", "idefics3", "smolvlm"):
+                          "internvl_chat", "ovis2", "skywork_chat",
+                          "NVLM_D", "h2ovl_chat", "idefics3", "smolvlm"):
             return "<image>"
         if model_type in ("mllama", "llama4"):
             return "<|image|>"
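The placeholder lookup above is what lets vLLM's OpenAI-compatible chat path insert `<image>` automatically for `model_type == "ovis2"` when a multimodal chat request arrives. Below is a hedged client-side sketch (not part of the diff): it assumes a vLLM server was started with the same model, tokenizer, trust-remote-code and hf_overrides settings as the offline examples, and the image URL is a placeholder.

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
resp = client.chat.completions.create(
    model="AIDC-AI/Ovis2-1B",
    messages=[{
        "role": "user",
        "content": [
            # vLLM expands this image part into the "<image>" placeholder above
            {"type": "image_url", "image_url": {"url": "https://example.com/cat.jpg"}},
            {"type": "text", "text": "Describe the image."},
        ],
    }],
)
print(resp.choices[0].message.content)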
vllm/model_executor/models/aimv2.py (new file, 322 lines)
@@ -0,0 +1,322 @@
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# A modified implementation of the AIMv2 Transformer
|
||||
# inserted here also the image tokenizer used by Ovis2
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch import nn, softmax
|
||||
from torch.nn import functional as F
|
||||
from torch.nn.functional import gumbel_softmax, pad
|
||||
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import ReplicatedLinear
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.transformers_utils.configs.ovis2 import (AIMv2Config,
|
||||
Aimv2VisualTokenizerConfig)
|
||||
|
||||
IMAGE_INDICATOR_IDS = [-301, -302, -303, -304,
|
||||
-305] # kept for vocab prefixed tokens
|
||||
|
||||
|
||||
def st_argmax(y_soft: torch.Tensor, dim: int): # straight-through softmax
|
||||
index = y_soft.max(dim, keepdim=True)[1]
|
||||
y_hard = torch.zeros_like(
|
||||
y_soft, memory_format=torch.legacy_contiguous_format).scatter_(
|
||||
dim, index, 1.0)
|
||||
ret = y_hard - y_soft.detach() + y_soft
|
||||
return ret
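A quick illustration of the straight-through trick implemented by `st_argmax` above (illustrative, not part of the diff; it reuses the function defined in this listing): the forward value is a one-hot argmax, while gradients flow through the soft input unchanged.

import torch

y_soft = torch.tensor([[0.1, 0.7, 0.2]], requires_grad=True)
y_hard = st_argmax(y_soft, dim=-1)   # forward value: one-hot [[0., 1., 0.]]
y_hard.sum().backward()              # backward: gradient passes through y_soft
print(y_hard, y_soft.grad)           # grad is all ones here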
|
||||
|
||||
|
||||
class Aimv2VisualTokenizer(torch.nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
config: Aimv2VisualTokenizerConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
**kwargs):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.backbone = AIMv2Model(
|
||||
config=config.backbone_config, # noqa
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.visual_tokenizer")
|
||||
# reserved tokens for IMAGE_INDICATORS
|
||||
head_dim = config.vocab_size - len(IMAGE_INDICATOR_IDS)
|
||||
self.head = torch.nn.Sequential(
|
||||
ReplicatedLinear(
|
||||
config.backbone_config.hidden_size * config.hidden_stride *
|
||||
config.hidden_stride,
|
||||
head_dim,
|
||||
bias=False,
|
||||
), torch.nn.LayerNorm(head_dim))
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return self.backbone.dtype
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return self.backbone.device
|
||||
|
||||
def tokenize(self, logits):
|
||||
if self.config.tokenize_function == 'softmax':
|
||||
tokens = softmax(logits, dim=-1)
|
||||
elif self.config.tokenize_function == 'gumbel_argmax':
|
||||
tokens = gumbel_softmax(logits, tau=self.config.tau, hard=True)
|
||||
elif self.config.tokenize_function == 'st_argmax':
|
||||
tokens = st_argmax(logits, dim=-1)
|
||||
else:
|
||||
raise ValueError(
|
||||
'Invalid `tokenize_function`, expected softmax or gumbel_argmax '
f'or st_argmax, but got {self.config.tokenize_function}')
|
||||
return tokens
|
||||
|
||||
def encode(self, pixel_values):
|
||||
features = self.backbone(pixel_values)
|
||||
if self.config.drop_cls_token:
|
||||
features = features[:, 1:, :]
|
||||
|
||||
# merge number of `hidden_stride * hidden_stride` hidden states together
|
||||
# to reduce token sequence length
|
||||
# e.g., for hidden_stride=2, this leads to a token length reduction:
|
||||
# 1024 -> 256 for aimv2
|
||||
if self.config.hidden_stride > 1:
|
||||
# this `d` may differ from the `d` above
n, L, d = features.shape
|
||||
sqrt_l = int(L**0.5)
|
||||
assert sqrt_l**2 == L, (
|
||||
"The token sequence length should be a perfect square.")
|
||||
features = features.reshape(n, sqrt_l, sqrt_l, d)
|
||||
pl = (self.config.hidden_stride -
|
||||
(sqrt_l %
|
||||
self.config.hidden_stride)) % self.config.hidden_stride
|
||||
features = pad(features, (0, 0, 0, pl, 0, pl), "constant", 0)
|
||||
sqrt_l += pl
|
||||
features = features.reshape(n, sqrt_l // self.config.hidden_stride,
|
||||
self.config.hidden_stride,
|
||||
sqrt_l // self.config.hidden_stride,
|
||||
self.config.hidden_stride, d)
|
||||
# [n, sqrt_l/hs, sqrt_l/hs, hs, hs, d]
|
||||
features = features.permute(0, 1, 3, 2, 4, 5)
|
||||
# [n, sqrt_l/hs, sqrt_l/hs, hs*hs*d]
|
||||
features = features.flatten(3)
|
||||
# [n, sqrt_l/hs*sqrt_l/hs, hs*hs*d]
|
||||
features = features.reshape(
|
||||
n, -1,
|
||||
self.config.hidden_stride * self.config.hidden_stride * d)
|
||||
|
||||
return features
|
||||
|
||||
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
|
||||
"""[BatchSize, ImageShape] -> [BatchSize, Token, VocabSize]"""
|
||||
features = self.encode(pixel_values)
|
||||
logits, _ = self.head[0](
features)  # we split the Sequential here to avoid throwing an error
logits = self.head[1](logits)
|
||||
tokens = self.tokenize(logits)
|
||||
# tokens' shape is [BatchSize, #Token, VocabSize-5], so padding with
|
||||
# [BatchSize, #Token, 5], after which, tokens' shape should become
|
||||
# [BatchSize, #Token, VocabSize]
|
||||
batch_size, token_len, _ = tokens.shape
|
||||
padding_tensor = torch.zeros(size=(batch_size, token_len,
|
||||
len(IMAGE_INDICATOR_IDS)),
|
||||
dtype=tokens.dtype,
|
||||
device=tokens.device,
|
||||
layout=tokens.layout,
|
||||
requires_grad=False)
|
||||
tokens = torch.cat((tokens, padding_tensor), dim=2)
|
||||
return tokens
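To make the token merging in `encode` above concrete, here is a small, self-contained sketch (not from the diff) of the same reshape/permute/flatten sequence for `hidden_stride=2`, showing the 1024 -> 256 reduction mentioned in the comment (the padding step is a no-op here because 32 is divisible by the stride; the feature width 8 is arbitrary).

import torch

n, L, d, hs = 1, 1024, 8, 2          # 1024 patch tokens of width 8, stride 2
x = torch.randn(n, L, d)
s = int(L ** 0.5)                    # 32x32 grid of patches
x = x.reshape(n, s, s, d)
x = x.reshape(n, s // hs, hs, s // hs, hs, d)
x = x.permute(0, 1, 3, 2, 4, 5).flatten(3)   # group each hs x hs neighborhood
x = x.reshape(n, -1, hs * hs * d)
print(x.shape)                       # torch.Size([1, 256, 32]): 4x fewer tokens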
|
||||
|
||||
|
||||
class AIMv2SwiGLUFFN(nn.Module):
|
||||
|
||||
def __init__(self, config: AIMv2Config, quant_config: QuantizationConfig,
|
||||
prefix: str):
|
||||
super().__init__()
|
||||
hidden_features = config.intermediate_size
|
||||
in_features = config.hidden_size
|
||||
bias = config.use_bias
|
||||
|
||||
# TODO(Isotr0py): investigate if we can add TP to visual tokenizer
|
||||
self.fc1 = ReplicatedLinear(in_features,
|
||||
hidden_features,
|
||||
bias=bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.fc1")
|
||||
self.fc2 = ReplicatedLinear(hidden_features,
|
||||
in_features,
|
||||
bias=bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.fc2")
|
||||
self.fc3 = ReplicatedLinear(in_features,
|
||||
hidden_features,
|
||||
bias=bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.fc3")
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x_parallel, _ = self.fc1(x)
|
||||
gate, _ = self.fc3(x)
|
||||
x_parallel = F.silu(x_parallel) * gate
|
||||
out, _ = self.fc2(x_parallel)
|
||||
return out
|
||||
|
||||
|
||||
class AIMv2PatchEmbed(nn.Module):
|
||||
|
||||
def __init__(self, config: AIMv2Config):
|
||||
super().__init__()
|
||||
self.proj = nn.Conv2d(
|
||||
config.num_channels,
|
||||
config.hidden_size,
|
||||
kernel_size=(config.patch_size, config.patch_size),
|
||||
stride=(config.patch_size, config.patch_size),
|
||||
)
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.proj(x).flatten(2).transpose(1, 2)
|
||||
x = self.norm.forward_native(x)
|
||||
return x
|
||||
|
||||
|
||||
class AIMv2ViTPreprocessor(nn.Module):
|
||||
|
||||
def __init__(self, config: AIMv2Config):
|
||||
super().__init__()
|
||||
num_patches = (config.image_size // config.patch_size)**2
|
||||
|
||||
self.patchifier = AIMv2PatchEmbed(config)
|
||||
self.pos_embed = nn.Parameter(
|
||||
torch.zeros((1, num_patches, config.hidden_size)))
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
tokens = self.patchifier(x)
|
||||
_, N, _ = tokens.shape
|
||||
pos_embed = self.pos_embed.to(tokens.device)
|
||||
tokens = tokens + pos_embed[:, :N]
|
||||
return tokens
|
||||
|
||||
|
||||
class AIMv2Attention(nn.Module):
|
||||
|
||||
def __init__(self, config: AIMv2Config, quant_config: QuantizationConfig,
|
||||
prefix: str):
|
||||
super().__init__()
|
||||
dim = config.hidden_size
|
||||
|
||||
# TODO(Isotr0py): investigate if we can add TP to visual tokenizer
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.qkv = ReplicatedLinear(dim, dim * 3, bias=config.qkv_bias)
|
||||
# self.qkv = QKVParallelLinear(
|
||||
# hidden_size=dim,
|
||||
# head_size=dim // config.num_attention_heads,
|
||||
# total_num_heads=config.num_attention_heads,
|
||||
# bias=config.qkv_bias,
|
||||
# quant_config=quant_config,
|
||||
# prefix=f"{prefix}.qkv")
|
||||
self.proj = ReplicatedLinear(dim, dim, bias=config.use_bias)
|
||||
# self.proj = RowParallelLinear(input_size=dim,
|
||||
# output_size=dim,
|
||||
# bias = config.use_bias,
|
||||
# quant_config=quant_config,
|
||||
# prefix=f"{prefix}.proj")
|
||||
|
||||
def forward(  # TODO: consider supporting multiple attention implementations
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
B, N, C = x.shape
|
||||
qkv, _ = self.qkv(x)
|
||||
|
||||
qkv = qkv.reshape(B, N, 3, self.num_heads,
|
||||
C // self.num_heads).permute(2, 0, 3, 1, 4)
|
||||
|
||||
q, k, v = qkv.unbind(0)
|
||||
|
||||
x = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
|
||||
x = x.transpose(1, 2).contiguous().reshape(B, N, C)
|
||||
x, _ = self.proj(x)
|
||||
return x
|
||||
|
||||
|
||||
class AIMv2Block(nn.Module):
|
||||
|
||||
def __init__(self, config: AIMv2Config, quant_config: QuantizationConfig,
|
||||
prefix: str):
|
||||
super().__init__()
|
||||
self.attn = AIMv2Attention(config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn")
|
||||
self.norm_1 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.mlp = AIMv2SwiGLUFFN(config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.mlp")
|
||||
self.norm_2 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
def forward(self,
|
||||
x: torch.Tensor,
|
||||
mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
x = x + self.attn(self.norm_1.forward_native(x), mask)
|
||||
x = x + self.mlp(self.norm_2.forward_native(x))
|
||||
return x
|
||||
|
||||
|
||||
class AIMv2Transformer(nn.Module):
|
||||
|
||||
def __init__(self, config: AIMv2Config, quant_config: QuantizationConfig,
|
||||
prefix: str):
|
||||
super().__init__()
|
||||
|
||||
self.blocks = nn.ModuleList([
|
||||
AIMv2Block(config, quant_config, prefix=f"{prefix}.blocks.{i}")
|
||||
for i in range(config.num_hidden_layers)
|
||||
])
|
||||
self.post_trunk_norm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
tokens: torch.Tensor,
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
# the reference implementation takes the -1 hidden state as the
# reference embeddings, similar to a CLIP skip
for block in self.blocks:
|
||||
tokens = block(tokens, mask)
|
||||
# no final norm in the original implementation
# tokens = self.post_trunk_norm(tokens)
|
||||
return tokens
|
||||
|
||||
|
||||
class AIMv2Model(torch.nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
config: AIMv2Config,
|
||||
quant_config: QuantizationConfig,
|
||||
prefix: str = ""):
|
||||
super().__init__()
|
||||
self.preprocessor = AIMv2ViTPreprocessor(config)
|
||||
self.trunk = AIMv2Transformer(config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.trunk")
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return self.trunk.blocks[0].attn.qkv.weight.dtype
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return self.trunk.blocks[0].attn.qkv.device
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pixel_values: torch.Tensor,
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
|
||||
x = self.preprocessor(pixel_values)
|
||||
x = self.trunk(x, mask)
|
||||
|
||||
return x
|
||||
vllm/model_executor/models/ovis2.py (new file, 331 lines)
@@ -0,0 +1,331 @@
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# adapted from https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/ovis/modeling_ovis.py
|
||||
# Copyright 2023 The vLLM team.
|
||||
# Copyright 2023 HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" PyTorch Ovis2 model."""
|
||||
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
|
||||
TypedDict, Union)
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
from transformers import BatchFeature
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.model_executor.models.aimv2 import Aimv2VisualTokenizer
|
||||
from vllm.model_executor.models.utils import (AutoWeightsLoader, flatten_bn,
|
||||
init_vllm_registered_model,
|
||||
maybe_prefix)
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
|
||||
MultiModalKwargs)
|
||||
from vllm.multimodal.parse import ImageSize, MultiModalDataItems
|
||||
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
BaseProcessingInfo, PromptReplacement)
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.transformers_utils.configs.ovis2 import OvisConfig
|
||||
from vllm.transformers_utils.processors.ovis2 import OvisProcessor
|
||||
|
||||
from .interfaces import MultiModalEmbeddings, SupportsMultiModal
|
||||
from .utils import merge_multimodal_embeddings
|
||||
|
||||
# The following token IDs cannot be found in the HF config.
IMAGE_TOKEN = "<image>"
|
||||
IMAGE_ATOM_TOKEN_ID = 151666
|
||||
IMAGE_PAD_TOKEN_ID = 151672
|
||||
NUMBER_OF_TOKEN_TO_RESERVE_FOR_SEGMENT = 256
|
||||
|
||||
|
||||
class Ovis2ImagePatchInputs(TypedDict):
|
||||
type: Literal["image_patches"]
|
||||
flat_data: torch.Tensor
|
||||
"""
|
||||
Shape:
|
||||
`(batch_size * num_patches, patch_size_x * patch_size_y * num_channels)`
|
||||
"""
|
||||
|
||||
patches_per_image: List[int]
|
||||
"""
|
||||
List of number of total patches for each image in the batch.
|
||||
This is used to restore the first two dimensions of `flat_data`.
|
||||
"""
|
||||
|
||||
|
||||
class VisualEmbedding(torch.nn.Embedding):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def forward(self, visual_tokens: Tensor) -> Tensor:
|
||||
if visual_tokens.dtype in [
|
||||
torch.int8, torch.int16, torch.int32, torch.int64, torch.long
|
||||
]:
|
||||
return super().forward(visual_tokens)
|
||||
return torch.matmul(visual_tokens, self.weight)
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return self.weight.device
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return self.weight.dtype
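A tiny illustration (not part of the diff; it reuses the `VisualEmbedding` class defined above, and the vocabulary/hidden sizes are made up) of why the forward pass special-cases float inputs: the visual tokenizer emits probability vectors over the visual vocabulary, so the embedding becomes a probability-weighted mixture of embedding rows rather than an index lookup.

import torch

vocab, hidden = 16, 4
vte = VisualEmbedding(vocab, hidden)

ids = torch.tensor([3, 7])                           # hard ids -> row lookup
soft = torch.softmax(torch.randn(2, vocab), dim=-1)  # soft tokens -> matmul

assert vte(ids).shape == (2, hidden)
assert vte(soft).shape == (2, hidden)                # matmul(soft, weight)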
|
||||
|
||||
|
||||
class Ovis2ProcessingInfo(BaseProcessingInfo):
|
||||
|
||||
def get_hf_config(self):
|
||||
return self.ctx.get_hf_config(OvisConfig)
|
||||
|
||||
def get_hf_processor(self, **kwargs):
|
||||
return self.ctx.get_hf_processor(OvisProcessor)
|
||||
|
||||
def get_image_processor(self) -> OvisProcessor:
|
||||
return self.get_hf_processor().image_processor # type: ignore
|
||||
|
||||
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
|
||||
return { # 32k is model token limit at the moment
|
||||
"image":
|
||||
self.get_hf_config().multimodal_max_length //
|
||||
((9 + 1) * NUMBER_OF_TOKEN_TO_RESERVE_FOR_SEGMENT)
|
||||
}
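As a worked example of the limit above (illustrative; the exact value of `multimodal_max_length` comes from the checkpoint's HF config, 32768 is only an assumed value): each image may use up to 9 crops plus the full image, and every one of those reserves 256 tokens.

multimodal_max_length = 32768                    # assumed config value
max_images = multimodal_max_length // ((9 + 1) * 256)
print(max_images)                                # 12 images per prompt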
|
||||
|
||||
def get_image_size_with_most_features(self) -> ImageSize:
|
||||
image_processor = self.get_image_processor()
|
||||
return ImageSize(width=image_processor.size['shortest_edge'] * 9 * 2,
|
||||
height=image_processor.size['shortest_edge'] * 9 * 2)
|
||||
|
||||
|
||||
class Ovis2DummyInputsBuilder(BaseDummyInputsBuilder[Ovis2ProcessingInfo]):
|
||||
|
||||
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
|
||||
num_images = mm_counts.get("image", 0)
|
||||
return IMAGE_TOKEN * num_images
|
||||
|
||||
def get_dummy_mm_data(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
) -> MultiModalDataDict:
|
||||
num_images = mm_counts.get("image", 0)
|
||||
|
||||
target_width, target_height = \
|
||||
self.info.get_image_size_with_most_features()
|
||||
|
||||
mm_data = {
|
||||
"image":
|
||||
self._get_dummy_images(width=target_width,
|
||||
height=target_height,
|
||||
num_images=num_images),
|
||||
}
|
||||
return mm_data
|
||||
|
||||
|
||||
class Ovis2MultiModalProcessor(BaseMultiModalProcessor[Ovis2ProcessingInfo]):
|
||||
|
||||
def _call_hf_processor(
|
||||
self,
|
||||
prompt: str,
|
||||
mm_data: Mapping[str, object],
|
||||
mm_kwargs: Mapping[str, object],
|
||||
) -> BatchFeature:
|
||||
if not mm_data:
|
||||
# Avoid warning from HF logger for text-only input
prompt_ids = self.info.get_tokenizer().encode(prompt)
return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")
|
||||
|
||||
processed_outputs = super()._call_hf_processor(
|
||||
prompt=prompt,
|
||||
mm_data=mm_data,
|
||||
mm_kwargs=mm_kwargs,
|
||||
)
|
||||
|
||||
return processed_outputs
|
||||
|
||||
def _apply_hf_processor_tokens_only(
|
||||
self,
|
||||
prompt_tokens: list[int],
|
||||
) -> list[int]:
|
||||
|
||||
return prompt_tokens
|
||||
|
||||
def _get_mm_fields_config(
|
||||
self,
|
||||
hf_inputs: BatchFeature,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
) -> Mapping[str, MultiModalFieldConfig]:
|
||||
return dict(pixel_values=MultiModalFieldConfig.batched("image"),
|
||||
grids=MultiModalFieldConfig.batched("image"))
|
||||
|
||||
def _get_prompt_updates(
|
||||
self,
|
||||
mm_items: MultiModalDataItems,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
out_mm_kwargs: MultiModalKwargs,
|
||||
) -> list[PromptReplacement]:
|
||||
|
||||
def get_replacement_ovis(item_idx):
|
||||
grid = out_mm_kwargs["grids"][item_idx]
|
||||
|
||||
hf_processor = self.info.get_hf_processor()
|
||||
return hf_processor.construct_image_placeholders(grid)
|
||||
|
||||
return [
|
||||
PromptReplacement(
|
||||
modality="image",
|
||||
target=IMAGE_TOKEN,
|
||||
replacement=get_replacement_ovis,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_processor(Ovis2MultiModalProcessor,
|
||||
info=Ovis2ProcessingInfo,
|
||||
dummy_inputs=Ovis2DummyInputsBuilder)
|
||||
class Ovis2ForConditionalGeneration(nn.Module, SupportsMultiModal):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
|
||||
self.config: OvisConfig = config
|
||||
self.llm = init_vllm_registered_model(
|
||||
vllm_config=vllm_config.with_hf_config(config.get_text_config()),
|
||||
prefix=maybe_prefix(prefix, "llm"),
|
||||
)
|
||||
|
||||
self.visual_tokenizer = Aimv2VisualTokenizer(
|
||||
config=config.visual_tokenizer_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.visual_tokenizer",
|
||||
image_processor_name_or_path=config.visual_tokenizer_config.
|
||||
backbone_config.name_or_path,
|
||||
)
|
||||
|
||||
self.vte = VisualEmbedding(
|
||||
self.config.visual_tokenizer_config.vocab_size,
|
||||
self.config.hidden_size)
|
||||
|
||||
# TODO(Isotr0py): PP support
|
||||
# self.make_empty_intermediate_tensors = (
|
||||
# self.language_model.make_empty_intermediate_tensors)
|
||||
|
||||
def _parse_and_validate_image_input(
|
||||
self, **kwargs: object) -> Optional[Ovis2ImagePatchInputs]:
|
||||
pixel_values = kwargs.pop("pixel_values", None)
|
||||
if pixel_values is None:
|
||||
return None
|
||||
|
||||
if pixel_values is not None:
|
||||
if not isinstance(pixel_values, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of pixel values. "
|
||||
f"Got type: {type(pixel_values)}")
|
||||
|
||||
return Ovis2ImagePatchInputs(
|
||||
type="image_patches",
|
||||
flat_data=flatten_bn(flatten_bn(pixel_values), concat=True),
|
||||
patches_per_image=[
|
||||
x.shape[0] for x in flatten_bn(pixel_values)
|
||||
],
|
||||
)
|
||||
|
||||
raise AssertionError("This line should be unreachable.")
|
||||
|
||||
def _process_image_input(
|
||||
self, image_input: Ovis2ImagePatchInputs) -> MultiModalEmbeddings:
|
||||
image_patches_flat = image_input["flat_data"]
|
||||
patches_per_image = image_input["patches_per_image"]
|
||||
|
||||
target_dtype = self.visual_tokenizer.dtype
|
||||
visual_tokens = self.visual_tokenizer(
|
||||
image_patches_flat.to(target_dtype))
|
||||
visual_embeds = self.vte(visual_tokens)  # numerically identical to the reference implementation
|
||||
|
||||
return tuple(
|
||||
x.flatten(0, 1)
|
||||
for x in visual_embeds.split(patches_per_image, dim=0))
|
||||
|
||||
def get_multimodal_embeddings(
|
||||
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
|
||||
image_input = self._parse_and_validate_image_input(**kwargs)
|
||||
if image_input is None:
|
||||
return None
|
||||
|
||||
image_features = self._process_image_input(image_input)
|
||||
|
||||
return image_features
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
|
||||
) -> torch.Tensor:
|
||||
inputs_embeds = self.llm.get_input_embeddings(input_ids)
|
||||
if multimodal_embeddings is not None:
|
||||
inputs_embeds = merge_multimodal_embeddings(
|
||||
input_ids, inputs_embeds, multimodal_embeddings,
|
||||
[IMAGE_ATOM_TOKEN_ID, IMAGE_PAD_TOKEN_ID])
|
||||
return inputs_embeds
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
**kwargs: object,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
if intermediate_tensors is not None:
|
||||
inputs_embeds = None
|
||||
|
||||
# NOTE: In v1, inputs_embeds is always generated at model runner, this
|
||||
# condition is for v0 compatibility.
|
||||
elif inputs_embeds is None:
|
||||
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
|
||||
inputs_embeds = self.get_input_embeddings(input_ids,
|
||||
vision_embeddings)
|
||||
input_ids = None
|
||||
|
||||
# up to this point, inputs_embeds is numerically identical to the
# original HF Transformers implementation
|
||||
hidden_states = self.llm(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[torch.Tensor]:
|
||||
logits = self.llm.logits_processor(self.llm.lm_head, hidden_states,
|
||||
sampling_metadata)
|
||||
return logits
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
loader = AutoWeightsLoader(self)
|
||||
return loader.load_weights(weights)
|
||||
|
||||
def get_language_model(self) -> torch.nn.Module:
|
||||
return self.llm
|
||||
@@ -195,6 +195,7 @@ _MULTIMODAL_MODELS = {
     "Mistral3ForConditionalGeneration": ("mistral3", "Mistral3ForConditionalGeneration"),  # noqa: E501
     "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
     "NVLM_D": ("nvlm_d", "NVLM_D_Model"),
+    "Ovis2ForConditionalGeneration": ("ovis2", "Ovis2ForConditionalGeneration"),
     "PaliGemmaForConditionalGeneration": ("paligemma", "PaliGemmaForConditionalGeneration"),  # noqa: E501
     "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
     "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"),  # noqa: E501
@@ -38,9 +38,9 @@ from vllm.transformers_utils.configs import (ChatGLMConfig, Cohere2Config,
                                              MiniMaxVL01Config, MllamaConfig,
                                              MLPSpeculatorConfig, MPTConfig,
                                              NemotronConfig, NVLM_D_Config,
-                                             RWConfig, SkyworkR1VChatConfig,
-                                             SolarConfig, Telechat2Config,
-                                             UltravoxConfig)
+                                             OvisConfig, RWConfig,
+                                             SkyworkR1VChatConfig, SolarConfig,
+                                             Telechat2Config, UltravoxConfig)
 # yapf: enable
 from vllm.transformers_utils.utils import check_gguf_file
 from vllm.utils import resolve_obj_by_qualname
@@ -79,6 +79,7 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
     "minimax_vl_01": MiniMaxVL01Config,
     "nemotron": NemotronConfig,
     "NVLM_D": NVLM_D_Config,
+    "ovis": OvisConfig,
     "solar": SolarConfig,
     "skywork_chat": SkyworkR1VChatConfig,
     "telechat": Telechat2Config,
@@ -23,6 +23,7 @@ from vllm.transformers_utils.configs.moonvit import MoonViTConfig
 from vllm.transformers_utils.configs.mpt import MPTConfig
 from vllm.transformers_utils.configs.nemotron import NemotronConfig
 from vllm.transformers_utils.configs.nvlm_d import NVLM_D_Config
+from vllm.transformers_utils.configs.ovis2 import OvisConfig
 from vllm.transformers_utils.configs.skyworkr1v import SkyworkR1VChatConfig
 from vllm.transformers_utils.configs.solar import SolarConfig
 from vllm.transformers_utils.configs.telechat2 import Telechat2Config
@@ -49,6 +50,7 @@ __all__ = [
     "KimiVLConfig",
     "NemotronConfig",
     "NVLM_D_Config",
+    "OvisConfig",
     "SkyworkR1VChatConfig",
     "SolarConfig",
     "Telechat2Config",
vllm/transformers_utils/configs/ovis2.py (new file, 170 lines)
@@ -0,0 +1,170 @@
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# yapf: disable
|
||||
# ruff: noqa: E501
|
||||
# copied from https://huggingface.co/AIDC-AI/Ovis2-1B/blob/main/configuration_aimv2.py
|
||||
# and https://huggingface.co/AIDC-AI/Ovis2-1B/blob/main/configuration_ovis.py
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from transformers import AutoConfig, PretrainedConfig
|
||||
|
||||
|
||||
class AIMv2Config(PretrainedConfig):
|
||||
"""This is the configuration class to store the configuration of an [`AIMv2Model`].
|
||||
|
||||
Instantiating a configuration with the defaults will yield a similar configuration
|
||||
to that of the [apple/aimv2-large-patch14-224](https://huggingface.co/apple/aimv2-large-patch14-224).
|
||||
|
||||
Args:
|
||||
hidden_size: Dimension of the hidden representations.
|
||||
intermediate_size: Dimension of the SwiGLU representations.
|
||||
num_hidden_layers: Number of hidden layers in the Transformer.
|
||||
num_attention_heads: Number of attention heads for each attention layer
|
||||
in the Transformer.
|
||||
num_channels: Number of input channels.
|
||||
image_size: Image size.
|
||||
patch_size: Patch size.
|
||||
rms_norm_eps: Epsilon value used for the RMS normalization layer.
|
||||
attention_dropout: Dropout ratio for attention probabilities.
|
||||
projection_dropout: Dropout ratio for the projection layer after the attention.
|
||||
qkv_bias: Whether to add a bias to the queries, keys and values.
|
||||
use_bias: Whether to add a bias in the feed-forward and projection layers.
|
||||
kwargs: Keyword arguments for the [`PretrainedConfig`].
|
||||
"""
|
||||
|
||||
model_type: str = "aimv2"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int = 1024,
|
||||
intermediate_size: int = 2816,
|
||||
num_hidden_layers: int = 24,
|
||||
num_attention_heads: int = 8,
|
||||
num_channels: int = 3,
|
||||
image_size: int = 224,
|
||||
patch_size: int = 14,
|
||||
rms_norm_eps: float = 1e-5,
|
||||
attention_dropout: float = 0.0,
|
||||
projection_dropout: float = 0.0,
|
||||
qkv_bias: bool = False,
|
||||
use_bias: bool = False,
|
||||
**kwargs: Any,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.num_channels = num_channels
|
||||
self.patch_size = patch_size
|
||||
self.image_size = image_size
|
||||
self.attention_dropout = attention_dropout
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
|
||||
self.projection_dropout = projection_dropout
|
||||
self.qkv_bias = qkv_bias
|
||||
self.use_bias = use_bias
|
||||
|
||||
|
||||
IGNORE_ID = -100
|
||||
IMAGE_TOKEN_ID = -200
|
||||
IMAGE_TOKEN = "<image>"
|
||||
IMAGE_ATOM_ID = -300
|
||||
IMAGE_INDICATOR_IDS = [-301, -302, -303, -304, -305]
|
||||
|
||||
AutoConfig.register("aimv2", AIMv2Config)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------
|
||||
# Visual Tokenizer Configuration
|
||||
# ----------------------------------------------------------------------
|
||||
class BaseVisualTokenizerConfig(PretrainedConfig):
|
||||
|
||||
def __init__(self,
|
||||
vocab_size=16384,
|
||||
tokenize_function="softmax",
|
||||
tau=1.0,
|
||||
depths=None,
|
||||
drop_cls_token=False,
|
||||
backbone_config: Optional[Union[PretrainedConfig,
|
||||
dict]] = None,
|
||||
hidden_stride: int = 1,
|
||||
**kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.vocab_size = vocab_size
|
||||
self.tokenize_function = tokenize_function
|
||||
self.tau = tau
|
||||
if isinstance(depths, str):
|
||||
depths = [int(x) for x in depths.split('|')]
|
||||
self.depths = depths
|
||||
self.backbone_kwargs = dict[str, Any]()
|
||||
self.drop_cls_token = drop_cls_token
|
||||
if backbone_config is not None:
|
||||
assert isinstance(backbone_config, (PretrainedConfig, dict)), \
|
||||
f"expect `backbone_config` to be instance of PretrainedConfig or dict, but got {type(backbone_config)} type"
|
||||
if not isinstance(backbone_config, PretrainedConfig):
|
||||
model_type = backbone_config['model_type']
|
||||
backbone_config.pop('model_type')
|
||||
backbone_config = AutoConfig.for_model(model_type,
|
||||
**backbone_config)
|
||||
self.backbone_config = backbone_config
|
||||
self.hidden_stride = hidden_stride
|
||||
|
||||
|
||||
class Aimv2VisualTokenizerConfig(BaseVisualTokenizerConfig):
|
||||
model_type = "aimv2_visual_tokenizer"
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
if self.drop_cls_token:
|
||||
self.drop_cls_token = False
|
||||
if self.depths:
|
||||
assert len(self.depths) == 1
|
||||
self.backbone_kwargs['num_hidden_layers'] = self.depths[0]
|
||||
|
||||
|
||||
AutoConfig.register("aimv2_visual_tokenizer", Aimv2VisualTokenizerConfig)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------
|
||||
# Ovis Configuration
|
||||
# ----------------------------------------------------------------------
|
||||
class OvisConfig(PretrainedConfig):
|
||||
model_type = "ovis"
|
||||
|
||||
def __init__(self,
|
||||
llm_config: Optional[Union[PretrainedConfig, dict]] = None,
|
||||
visual_tokenizer_config: Optional[Union[PretrainedConfig,
|
||||
dict]] = None,
|
||||
multimodal_max_length=8192,
|
||||
hidden_size=None,
|
||||
conversation_formatter_class=None,
|
||||
llm_attn_implementation=None,
|
||||
disable_tie_weight=False,
|
||||
**kwargs):
|
||||
super().__init__(**kwargs)
|
||||
if llm_config is not None:
|
||||
assert isinstance(llm_config, (PretrainedConfig, dict)), \
|
||||
f"expect `llm_config` to be instance of PretrainedConfig or dict, but got {type(llm_config)} type"
|
||||
if not isinstance(llm_config, PretrainedConfig):
|
||||
model_type = llm_config['model_type']
|
||||
llm_config.pop('model_type')
|
||||
llm_config = AutoConfig.for_model(model_type, **llm_config)
|
||||
|
||||
# map llm_config to text_config
|
||||
self.text_config = llm_config
|
||||
if visual_tokenizer_config is not None:
|
||||
assert isinstance(visual_tokenizer_config, (PretrainedConfig, dict)), \
|
||||
f"expect `visual_tokenizer_config` to be instance of PretrainedConfig or dict, but got {type(visual_tokenizer_config)} type"
|
||||
if not isinstance(visual_tokenizer_config, PretrainedConfig):
|
||||
model_type = visual_tokenizer_config['model_type']
|
||||
visual_tokenizer_config.pop('model_type')
|
||||
visual_tokenizer_config = AutoConfig.for_model(
|
||||
model_type, **visual_tokenizer_config)
|
||||
|
||||
self.visual_tokenizer_config = visual_tokenizer_config
|
||||
self.multimodal_max_length = multimodal_max_length
|
||||
self.hidden_size = hidden_size
|
||||
self.conversation_formatter_class = conversation_formatter_class
|
||||
self.llm_attn_implementation = llm_attn_implementation
|
||||
self.disable_tie_weight = disable_tie_weight
|
||||
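To show how the dict-to-config conversion above behaves, here is a hedged sketch (not part of the diff): constructing an `OvisConfig` directly from plain dicts, where `model_type` keys select the concrete config classes via `AutoConfig.for_model`. The numeric values are made up; `"qwen2"` ships with transformers, while `"aimv2"` and `"aimv2_visual_tokenizer"` are registered earlier in this file.

# Illustrative only: sizes below are placeholders, not real checkpoint values.
cfg = OvisConfig(
    llm_config={"model_type": "qwen2", "hidden_size": 896},
    visual_tokenizer_config={"model_type": "aimv2_visual_tokenizer",
                             "vocab_size": 65536,
                             "backbone_config": {"model_type": "aimv2"}},
    multimodal_max_length=32768,
    hidden_size=896,
)
print(type(cfg.text_config).__name__)               # Qwen2Config
print(type(cfg.visual_tokenizer_config).__name__)   # Aimv2VisualTokenizerConfig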
@@ -2,5 +2,6 @@

 from vllm.transformers_utils.processors.deepseek_vl2 import (
     DeepseekVLV2Processor)
+from vllm.transformers_utils.processors.ovis2 import OvisProcessor

-__all__ = ["DeepseekVLV2Processor"]
+__all__ = ["DeepseekVLV2Processor", "OvisProcessor"]
vllm/transformers_utils/processors/ovis2.py (new file, 397 lines)
@@ -0,0 +1,397 @@
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# yapf: disable
|
||||
# ruff: noqa: E501
|
||||
# coding=utf-8
|
||||
# adapted from https://github.com/AIDC-AI/Ovis/blob/35ab51a1a1e3542fa6db260a1084cefbc8f164bb/ovis/vllm/processing_ovis.py
|
||||
# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||
# and OPT implementations in this library. It has been modified from its
|
||||
# original forms to accommodate minor architectural differences compared
|
||||
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import List, Union
|
||||
|
||||
import PIL
|
||||
import torch
|
||||
from transformers import AutoProcessor, BatchFeature
|
||||
from transformers.image_utils import ImageInput
|
||||
from transformers.processing_utils import (ProcessingKwargs, ProcessorMixin,
|
||||
Unpack)
|
||||
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
|
||||
|
||||
__all__ = ['OvisProcessor']
|
||||
IGNORE_ID = -100
|
||||
|
||||
class OvisProcessorKwargs(ProcessingKwargs, total=False): # type: ignore[call-arg]
|
||||
_defaults = {
|
||||
"text_kwargs": {
|
||||
"padding": False,
|
||||
},
|
||||
"images_kwargs": {
|
||||
'max_partition':9,
|
||||
'covering_threshold':0.9,
|
||||
'convert_to_rgb':True,
|
||||
'return_tensors':'pt'},
|
||||
}
|
||||
|
||||
|
||||
|
||||
class OvisProcessor(ProcessorMixin):
|
||||
r"""
|
||||
Constructs an Ovis processor which wraps an Ovis image processor and a Qwen2 tokenizer into a single processor.
|
||||
[`OvisProcessor`] offers all the functionalities of [`Qwen2VLImageProcessor`] and [`Qwen2TokenizerFast`]. See the
|
||||
[`~OvisProcessor.__call__`] and [`~OvisProcessor.decode`] for more information.
|
||||
Args:
|
||||
image_processor ([`Qwen2VLImageProcessor`], *optional*):
|
||||
The image processor is a required input.
|
||||
tokenizer ([`Qwen2TokenizerFast`], *optional*):
|
||||
The tokenizer is a required input.
|
||||
chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
|
||||
in a chat into a tokenizable string.
|
||||
"""
|
||||
|
||||
attributes = ["image_processor", "tokenizer"]
|
||||
valid_kwargs = ["chat_template"]
|
||||
|
||||
image_processor_class = "AutoImageProcessor"
|
||||
tokenizer_class = "Qwen2Tokenizer"
|
||||
|
||||
def __init__(self, image_processor=None, tokenizer=None, chat_template=None, **kwargs):
|
||||
self.image_token = "<|image_pad|>" if not hasattr(tokenizer, "image_token") else tokenizer.image_token
|
||||
self.video_token = "<|video_pad|>" if not hasattr(tokenizer, "video_token") else tokenizer.video_token
|
||||
super().__init__(image_processor, tokenizer, chat_template=chat_template)
|
||||
|
||||
self.extra_special_tokens = {
|
||||
"image_token": "<image>",
|
||||
"image_atom": "<image_atom>",
|
||||
"image_start": "<img>",
|
||||
"image_prefix": "<pre>",
|
||||
"image_col_sep": "<col>",
|
||||
"image_row_sep": "<row>",
|
||||
"image_end": "</img>",
|
||||
'image_pad': '<image_pad>',
|
||||
}
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
images: ImageInput = None,
|
||||
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
|
||||
**kwargs: Unpack[OvisProcessorKwargs],
|
||||
) -> BatchFeature:
|
||||
"""
|
||||
Main method to prepare one or several sequence(s) and image(s) for the model. This method forwards the `text`
and `kwargs` arguments to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` is not `None` to encode
the text. To prepare the vision inputs, this method forwards the `vision_infos` and `kwargs` arguments to
Qwen2VLImageProcessor's [`~Qwen2VLImageProcessor.__call__`] if `vision_infos` is not `None`.
|
||||
Args:
|
||||
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
|
||||
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
|
||||
tensor. Both channels-first and channels-last formats are supported.
|
||||
text (`str`, `List[str]`, `List[List[str]]`):
|
||||
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
|
||||
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
|
||||
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
|
||||
videos (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`):
|
||||
The image or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch
|
||||
tensor, or a nested list of 3D frames. Both channels-first and channels-last formats are supported.
|
||||
return_tensors (`str` or [`~utils.TensorType`], *optional*):
|
||||
If set, will return tensors of a particular framework. Acceptable values are:
|
||||
- `'tf'`: Return TensorFlow `tf.constant` objects.
|
||||
- `'pt'`: Return PyTorch `torch.Tensor` objects.
|
||||
- `'np'`: Return NumPy `np.ndarray` objects.
|
||||
- `'jax'`: Return JAX `jnp.ndarray` objects.
|
||||
Returns:
|
||||
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
|
||||
- **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
|
||||
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
|
||||
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
|
||||
`None`).
|
||||
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
|
||||
- **pixel_values_videos** -- Pixel values of videos to be fed to a model. Returned when `videos` is not `None`.
|
||||
- **image_grid_thw** -- List of image 3D grid in LLM. Returned when `images` is not `None`.
|
||||
- **video_grid_thw** -- List of video 3D grid in LLM. Returned when `videos` is not `None`.
|
||||
- **second_per_grid_ts** -- List of video seconds per time grid. Returned when `videos` is not `None`.
|
||||
"""
|
||||
output_kwargs = self._merge_kwargs(
|
||||
OvisProcessorKwargs,
|
||||
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Process all images first
|
||||
image_features = {}
|
||||
if images is not None:
|
||||
processed_images = []
|
||||
image_placeholders_list = []
|
||||
grids = []
|
||||
|
||||
# Process each image
|
||||
for image in images if isinstance(images, list) else [images]:
|
||||
pixel_values, image_placeholders, grid = self.preprocess_image(
|
||||
image=image, **output_kwargs["images_kwargs"]
|
||||
)
|
||||
processed_images.append(pixel_values)
|
||||
image_placeholders_list.append(image_placeholders)
|
||||
grids.append(grid)
|
||||
|
||||
# assign all processed images
|
||||
if processed_images:
|
||||
image_features["image_placeholders"] = image_placeholders_list
|
||||
|
||||
# Process text input
|
||||
if text is not None:
|
||||
|
||||
if not isinstance(text, list):
|
||||
text = [text]
|
||||
|
||||
tokenized_batched_text = self.tokenizer.batch_encode_plus(
|
||||
text,
|
||||
**output_kwargs["text_kwargs"]
|
||||
)
|
||||
image_token_id = self.get_token_value("image_token")
|
||||
replaced_ids_list = []
|
||||
replaced_attn_mask_list = []
|
||||
idx = 0
|
||||
for ids_tensor, attn_mask in zip(tokenized_batched_text['input_ids'],
|
||||
tokenized_batched_text['attention_mask']):
|
||||
if image_token_id in ids_tensor and "image_placeholders" in image_features:
|
||||
if idx < len(image_features["image_placeholders"]):
|
||||
# Convert to plain lists for ease of use
|
||||
ids_list = ids_tensor.tolist()
|
||||
attn_list = attn_mask.tolist()
|
||||
|
||||
new_ids = []
|
||||
new_attn = []
|
||||
|
||||
# replace placeholders
|
||||
for i, token_id in enumerate(ids_list):
|
||||
if token_id == image_token_id:
|
||||
placeholder_ids = image_features["image_placeholders"][idx]
|
||||
new_ids.extend(placeholder_ids)
|
||||
new_attn.extend([1] * len(placeholder_ids))
|
||||
idx += 1
|
||||
else:
|
||||
new_ids.append(token_id)
|
||||
new_attn.append(attn_list[i])
|
||||
|
||||
# Converts back to tensors
|
||||
ids_tensor = torch.tensor(new_ids, dtype=torch.long)
|
||||
attn_mask = torch.tensor(new_attn, dtype=torch.long)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
'Mismatch between the number of images provided and the number of image placeholders present in the text')
|
||||
|
||||
replaced_ids_list.append(ids_tensor)
|
||||
replaced_attn_mask_list.append(attn_mask)
|
||||
|
||||
if replaced_ids_list:
|
||||
replaced_and_tokenized_ids = torch.stack(replaced_ids_list)
|
||||
replaced_and_tokenized_attn_mask = torch.stack(replaced_attn_mask_list)
|
||||
else:
|
||||
replaced_and_tokenized_ids = torch.tensor([], dtype=torch.long)
|
||||
replaced_and_tokenized_attn_mask = torch.tensor([], dtype=torch.long)
|
||||
|
||||
# Create the output with text features
|
||||
output = BatchFeature(
|
||||
data={
|
||||
"input_ids": replaced_and_tokenized_ids,
|
||||
"attention_mask": replaced_and_tokenized_attn_mask,
|
||||
}
|
||||
)
|
||||
|
||||
# Add image features if present
|
||||
if image_features:
|
||||
output["pixel_values"] = processed_images
|
||||
output['grids'] = grids
|
||||
|
||||
return output
|
||||
|
||||
|
||||
# If only images were provided
|
||||
return BatchFeature(data=image_features)
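A hedged usage sketch of the `__call__` path above (not part of the diff): `processor` is assumed to be an already-constructed `OvisProcessor` instance and `img` a `PIL.Image`; the exact shapes depend on the checkpoint's image-processor crop size.

# Purely illustrative; `processor` and `img` are assumed to exist.
out = processor(images=[img],
                text=["<image>\nWhat is in the picture?"],
                return_tensors="pt")
print(out["input_ids"].shape)        # (1, seq_len) with <image> expanded in place
print(out["pixel_values"][0].shape)  # (num_crops, 3, side, side)
print(out["grids"])                  # e.g. [(1, 1)] for a small square image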
|
||||
|
||||
|
||||
def get_image_size(self):
|
||||
height = self.image_processor.crop_size["height"]
|
||||
width = self.image_processor.crop_size["width"]
|
||||
return height, width
|
||||
|
||||
def get_token_value(self, tok):
|
||||
return self.tokenizer.get_vocab()[self.extra_special_tokens[tok]]
|
||||
|
||||
def construct_image_placeholders(self, grid):
|
||||
|
||||
image_placeholders = [self.get_token_value('image_start'),
|
||||
self.get_token_value('image_atom'),
|
||||
self.get_token_value('image_prefix')]
|
||||
if grid[0] * grid[1] > 1:
|
||||
for r in range(grid[0]):
|
||||
for c in range(grid[1]):
|
||||
image_placeholders.append(self.get_token_value('image_atom'))
|
||||
if c < grid[1] - 1:
|
||||
image_placeholders.append(self.get_token_value('image_col_sep'))
|
||||
if r < grid[0] - 1:
|
||||
image_placeholders.append(self.get_token_value('image_row_sep'))
|
||||
image_placeholders.append(self.get_token_value('image_end'))
|
||||
# return image_placeholders
|
||||
|
||||
image_atom_token_id = self.get_token_value('image_atom')
|
||||
# Extract the padding token ID from tokenizer
|
||||
image_padding_token_id = self.get_token_value('image_pad')
|
||||
|
||||
# Create a new list with padding tokens inserted
|
||||
padded_placeholder_tokens = []
|
||||
for token in image_placeholders:
|
||||
padded_placeholder_tokens.append(token)
|
||||
if token == image_atom_token_id:
|
||||
# Add 255 padding tokens after each image atom token
|
||||
padded_placeholder_tokens.extend([image_padding_token_id] * 255)
|
||||
return padded_placeholder_tokens
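A worked length check for the placeholder construction above (illustrative; assumes an `OvisProcessor` instance named `processor`): for a (2, 2) grid there are 5 atom tokens (1 overview plus 4 tiles), each followed by 255 `<image_pad>` tokens, plus 6 structural markers (`<img>`, `<pre>`, two `<col>`, one `<row>`, `</img>`).

ph = processor.construct_image_placeholders((2, 2))
assert len(ph) == 5 * 256 + 6        # = 1286 placeholder token ids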
|
||||
|
||||
def preprocess_image(self, image: PIL.Image.Image, max_partition, covering_threshold, convert_to_rgb, return_tensors):
|
||||
def _preprocess(img: PIL.Image.Image, side):
|
||||
# first resize and preprocess
|
||||
w, h = img.size
|
||||
if w == h:
|
||||
new_width = new_height = side
|
||||
elif w > h:
|
||||
new_width = side
|
||||
new_height = int(h / w * new_width)
|
||||
else:
|
||||
new_height = side
|
||||
new_width = int(w / h * new_height)
|
||||
new_size = dict(height=new_height, width=new_width)
|
||||
pixel_values = self.image_processor.preprocess(img, size=new_size, return_tensors=return_tensors)['pixel_values']
|
||||
|
||||
# then pad to square
|
||||
square_values = torch.zeros([1, 3, side, side], dtype=pixel_values.dtype, device=pixel_values.device)
|
||||
new_height, new_width = pixel_values.shape[2:]
|
||||
if new_height == new_width:
|
||||
square_values[:, :, :, :] = pixel_values
|
||||
elif new_height > new_width:
|
||||
from_index = (side - new_width) // 2
|
||||
square_values[:, :, :, from_index:from_index + new_width] = pixel_values
|
||||
else:
|
||||
from_index = (side - new_height) // 2
|
||||
square_values[:, :, from_index:from_index + new_height, :] = pixel_values
|
||||
|
||||
return square_values
|
||||
|
||||
def _partition(img, grid) -> list[tuple[int, int, int, int]]:
|
||||
w, h = img.size
|
||||
row_height = h // grid[0]
|
||||
col_width = w // grid[1]
|
||||
|
||||
partition = []
|
||||
for row in range(grid[0]):
|
||||
for col in range(grid[1]):
|
||||
left = col * col_width
|
||||
upper = row * row_height
|
||||
right = w if col == grid[1] - 1 else (col + 1) * col_width
|
||||
lower = h if row == grid[0] - 1 else (row + 1) * row_height
|
||||
partition.append((left, upper, right, lower))
|
||||
|
||||
return partition
|
||||
|
||||
def _covering_area(left, upper, right, lower, side):
|
||||
w = right - left
|
||||
h = lower - upper
|
||||
w, h = max(w, h), min(w, h)
|
||||
if w > side:
|
||||
h = h / w * side
|
||||
w = side
|
||||
return w * h
|
||||
|
||||
def _get_best_grid(img, side):
|
||||
img_area = img.size[0] * img.size[1]
|
||||
|
||||
candidate_grids = []
|
||||
for i in range(1, max_partition + 1):
|
||||
for j in range(1, max_partition + 1):
|
||||
if i * j <= max_partition:
|
||||
candidate_grids.append((i, j))
|
||||
|
||||
all_grids = []
|
||||
good_grids = []
|
||||
for grid in candidate_grids:
|
||||
partition = _partition(img, grid)
|
||||
covering_ratio = sum([_covering_area(*p, side) for p in partition]) / img_area
|
||||
assert covering_ratio <= 1.0
|
||||
all_grids.append((grid, covering_ratio))
|
||||
if covering_ratio > covering_threshold:
|
||||
good_grids.append((grid, covering_ratio))
|
||||
|
||||
if len(good_grids) > 0:
|
||||
# pick the good partition with minimum #sub_images and break the tie using covering_ratio
|
||||
return sorted(good_grids, key=lambda x: (x[0][0] * x[0][1], -x[1]))[0][0]
|
||||
else:
|
||||
# pick the partition with maximum covering_ratio and break the tie using #sub_images
|
||||
return sorted(all_grids, key=lambda x: (-x[1], x[0][0] * x[0][1]))[0][0]
|
||||
|
||||
if convert_to_rgb and image.mode != 'RGB':
|
||||
image = image.convert('RGB')
|
||||
|
||||
|
||||
sides = self.get_image_size()
|
||||
if sides[0] != sides[1]:
|
||||
raise ValueError('get_image_size() returns non-square size')
|
||||
side = sides[0]
|
||||
grid = _get_best_grid(image, side)
|
||||
partition = _partition(image, grid)
|
||||
crops = [image.crop(p) for p in partition]
|
||||
if len(crops) > 1:
|
||||
crops.insert(0, image)
|
||||
pixel_values = torch.cat([_preprocess(crop, side) for crop in crops], dim=0)
|
||||
image_placeholders = self.construct_image_placeholders(grid)
|
||||
return pixel_values, image_placeholders, grid
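To illustrate the grid selection above, a hedged sketch (not part of the diff; `processor` is an assumed `OvisProcessor` instance and a 336x336 crop size is assumed): for a wide 1008x336 image, the smallest grid whose tiles cover more than 90% of the area is one row of three square tiles, so the best grid is (1, 3) and the full image is prepended as an overview crop.

from PIL import Image

img = Image.new("RGB", (1008, 336))   # synthetic wide image
pixel_values, placeholders, grid = processor.preprocess_image(
    image=img, max_partition=9, covering_threshold=0.9,
    convert_to_rgb=True, return_tensors="pt")
print(grid)                 # (1, 3): three 336x336 tiles cover the image exactly
print(pixel_values.shape)   # (4, 3, 336, 336): overview image + 3 tiles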
|
||||
|
||||
def batch_decode(self, *args, **kwargs):
|
||||
"""
|
||||
This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
|
||||
refer to the docstring of this method for more information.
|
||||
"""
|
||||
return self.tokenizer.batch_decode(*args, **kwargs)
|
||||
|
||||
def decode(self, *args, **kwargs):
|
||||
"""
|
||||
This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
|
||||
the docstring of this method for more information.
|
||||
"""
|
||||
return self.tokenizer.decode(*args, **kwargs)
|
||||
|
||||
def post_process_image_text_to_text(self, generated_outputs):
|
||||
"""
|
||||
Post-process the output of the model to decode the text.
|
||||
Args:
|
||||
generated_outputs (`torch.Tensor` or `np.ndarray`):
|
||||
The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
|
||||
or `(sequence_length,)`.
|
||||
Returns:
|
||||
`List[str]`: The decoded text.
|
||||
"""
|
||||
return self.tokenizer.batch_decode(
|
||||
generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
||||
)
|
||||
|
||||
@property
|
||||
def model_input_names(self):
|
||||
tokenizer_input_names = self.tokenizer.model_input_names
|
||||
image_processor_input_names = self.image_processor.model_input_names
|
||||
names_from_processor = list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
|
||||
return names_from_processor + ["second_per_grid_ts"]
|
||||
|
||||
|
||||
AutoProcessor.register("OvisProcessor", OvisProcessor)
|
||||