[MODEL ADDITION] Ovis2 Model Addition (#15826)

Signed-off-by: Marco <121761685+mlinmg@users.noreply.github.com>
Signed-off-by: Isotr0py <2037008807@qq.com>
Signed-off-by: isotr0py <2037008807@qq.com>
Co-authored-by: Isotr0py <2037008807@qq.com>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Marco 2025-04-30 09:33:29 +02:00 committed by GitHub
parent be633fba0f
commit 54072f315f
17 changed files with 1349 additions and 7 deletions

View File

@@ -1014,6 +1014,13 @@ See [this page](#generative-models) for more information on how to use generativ
*
* ✅︎
* ✅︎
- * `Ovis2ForConditionalGeneration`<sup>^</sup>
* Ovis2
* T + I<sup>+</sup>
* `AIDC-AI/Ovis2-1B`, `AIDC-AI/Ovis2-2B`, etc.
*
*
* ✅︎
- * `PaliGemmaForConditionalGeneration`
* PaliGemma, PaliGemma 2
* T + I<sup>E</sup>

View File

@@ -725,6 +725,36 @@ def run_nvlm_d(questions: list[str], modality: str) -> ModelRequestData:
)
# Ovis2
def run_ovis2(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image"
model_name = "AIDC-AI/Ovis2-1B"
tokenizer = "Isotr0py/Ovis2-tokenizer"
engine_args = EngineArgs(
model=model_name,
tokenizer=tokenizer,
max_model_len=4096,
max_num_seqs=2,
trust_remote_code=True,
dtype="half",
hf_overrides={"architectures": ["Ovis2ForConditionalGeneration"]},
limit_mm_per_prompt={"image": 1},
)
placeholder = "<image>\n"
prompts = [("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
f"<|im_start|>user\n{placeholder}"
f"{question}<|im_end|>\n"
"<|im_start|>assistant\n") for question in questions]
return ModelRequestData(
engine_args=engine_args,
prompts=prompts,
)
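Editorial aside: a minimal offline-inference sketch (not part of this commit) that reuses the engine arguments and chat-style prompt format from `run_ovis2` above; the image path is a placeholder.

```python
from PIL import Image

from vllm import LLM, SamplingParams

# Same engine arguments as run_ovis2 above; hf_overrides maps the checkpoint
# to the Ovis2ForConditionalGeneration implementation added in this PR.
llm = LLM(
    model="AIDC-AI/Ovis2-1B",
    tokenizer="Isotr0py/Ovis2-tokenizer",
    max_model_len=4096,
    max_num_seqs=2,
    trust_remote_code=True,
    dtype="half",
    hf_overrides={"architectures": ["Ovis2ForConditionalGeneration"]},
    limit_mm_per_prompt={"image": 1},
)

prompt = ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
          "<|im_start|>user\n<image>\nWhat is in this image?<|im_end|>\n"
          "<|im_start|>assistant\n")
image = Image.open("example.jpg")  # placeholder path

outputs = llm.generate(
    {"prompt": prompt, "multi_modal_data": {"image": image}},
    SamplingParams(max_tokens=64),
)
print(outputs[0].outputs[0].text)
```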
# PaliGemma
def run_paligemma(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image"
@@ -1041,6 +1071,7 @@ model_example_map = {
"llama4": run_llama4,
"molmo": run_molmo,
"NVLM_D": run_nvlm_d,
"ovis2": run_ovis2,
"paligemma": run_paligemma,
"paligemma2": run_paligemma2,
"phi3_v": run_phi3v,

View File

@@ -436,6 +436,36 @@ def load_nvlm_d(question: str, image_urls: list[str]) -> ModelRequestData:
)
# Ovis2
def load_ovis2(question: str, image_urls: list[str]) -> ModelRequestData:
model_name = "AIDC-AI/Ovis2-1B"
tokenizer = "Isotr0py/Ovis2-tokenizer"
engine_args = EngineArgs(
model=model_name,
tokenizer=tokenizer,
max_model_len=8192,
max_num_seqs=2,
trust_remote_code=True,
dtype="half",
limit_mm_per_prompt={"image": len(image_urls)},
hf_overrides={"architectures": ["Ovis2ForConditionalGeneration"]},
)
placeholder = '\n'.join(
[f'Image {i+1}: <image>' for i in range(len(image_urls))]) + '\n'
prompt = ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
f"<|im_start|>user\n{placeholder}"
f"{question}<|im_end|>\n"
"<|im_start|>assistant\n")
return ModelRequestData(
engine_args=engine_args,
prompt=prompt,
image_data=[fetch_image(url) for url in image_urls],
)
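Editorial aside: the multi-image placeholder string built above expands as follows (illustrative snippet, not part of the diff).

```python
image_urls = ["url_1", "url_2", "url_3"]  # placeholder URLs
placeholder = '\n'.join(
    f'Image {i+1}: <image>' for i in range(len(image_urls))) + '\n'
print(placeholder)
# Image 1: <image>
# Image 2: <image>
# Image 3: <image>
```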
def load_pixtral_hf(question: str, image_urls: list[str]) -> ModelRequestData:
model_name = "mistral-community/pixtral-12b"
@@ -685,6 +715,7 @@ model_example_map = {
"mistral3": load_mistral3,
"mllama": load_mllama,
"NVLM_D": load_nvlm_d,
"ovis2": load_ovis2,
"phi3_v": load_phi3v,
"phi4_mm": load_phi4mm,
"pixtral_hf": load_pixtral_hf,

View File

@@ -467,6 +467,18 @@ VLM_TEST_SETTINGS = {
max_num_seqs=2,
patch_hf_runner=model_utils.molmo_patch_hf_runner,
),
"ovis2": VLMTestInfo(
models=["AIDC-AI/Ovis2-1B"],
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
prompt_formatter=lambda img_prompt: f"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501
img_idx_to_prompt=lambda idx: "<image>\n", # noqa: E501
max_model_len=4096,
max_num_seqs=2,
dtype="half",
# use sdpa attention for the hf runner since ovis2 doesn't work with flash_attn
hf_model_kwargs={"llm_attn_implementation": "sdpa"},
patch_hf_runner=model_utils.ovis2_patch_hf_runner,
),
"phi3v": VLMTestInfo(
models=["microsoft/Phi-3.5-vision-instruct"],
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),

View File

@@ -67,7 +67,7 @@ def run_test(
"disable_mm_preprocessor_cache": True,
}
if model_info.tokenizer:
vllm_runner_kwargs_["tokenizer"] = model_info.tokenizer
vllm_runner_kwargs_["tokenizer_name"] = model_info.tokenizer
if model_info.tokenizer_mode:
vllm_runner_kwargs_["tokenizer_mode"] = model_info.tokenizer_mode
if model_info.hf_overrides:

View File

@@ -676,3 +676,33 @@ def molmo_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
hf_model.model.generate = types.MethodType(_generate, hf_model.model)
return hf_model
def ovis2_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
"""Patches and returns an instance of the HfRunner to use for Ovis2."""
hf_model.model.visual_tokenizer.to(hf_model.dtype)
hf_model.model.vte.to(hf_model.dtype)
hf_model.model.llm.to(hf_model.dtype)
hf_model.model.get_output_embeddings = lambda: \
hf_model.model.llm.get_output_embeddings()
def processor(*args, text="", images=None, **kwargs):
text_tokenizer = hf_model.model.get_text_tokenizer()
images = [images] if isinstance(images, Image) else images
text = text.split("<|im_start|>user\n")[1].split("<|im_end|>\n")[0]
prompt, input_ids, pixel_values = hf_model.model.preprocess_inputs(
text_or_conversations=text, images=images)
attention_mask = torch.ne(input_ids, text_tokenizer.pad_token_id)
inputs = {
"inputs": input_ids.unsqueeze(0),
"pixel_values": pixel_values.unsqueeze(0),
"attention_mask": attention_mask.unsqueeze(0),
}
return BatchFeature(data=inputs, tensor_type="pt")
hf_model.processor = processor
return hf_model

View File

@@ -274,6 +274,7 @@ def _test_processing_correctness_mistral(
"allenai/Molmo-7B-D-0924",
"allenai/Molmo-7B-O-0924",
"nvidia/NVLM-D-72B",
"AIDC-AI/Ovis2-1B",
"google/paligemma-3b-mix-224",
"google/paligemma2-3b-ft-docci-448",
"microsoft/Phi-4-multimodal-instruct",

View File

@@ -348,6 +348,10 @@ _MULTIMODAL_EXAMPLE_MODELS = {
max_transformers_version="4.48",
transformers_version_reason="Use of deprecated imports which have been removed.", # noqa: E501
extras={"phi3.5": "microsoft/Phi-3.5-vision-instruct"}), # noqa: E501
"Ovis2ForConditionalGeneration": _HfExamplesInfo("AIDC-AI/Ovis2-1B",
tokenizer="Isotr0py/Ovis2-tokenizer",
trust_remote_code=True,
hf_overrides={"architectures": ["Ovis2ForConditionalGeneration"]}), # noqa: E501
"Phi4MMForCausalLM": _HfExamplesInfo("microsoft/Phi-4-multimodal-instruct",
trust_remote_code=True),
"PixtralForConditionalGeneration": _HfExamplesInfo("mistralai/Pixtral-12B-2409", # noqa: E501

View File

@@ -496,9 +496,10 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
if model_type.startswith("llava"):
return self._cached_token_str(self._tokenizer,
hf_config.image_token_index)
if model_type in ("aya_vision", "chameleon", "deepseek_vl_v2",
"internvl_chat", "skywork_chat", "NVLM_D",
"h2ovl_chat", "idefics3", "smolvlm"):
"internvl_chat", "ovis2", "skywork_chat",
"NVLM_D", "h2ovl_chat", "idefics3", "smolvlm"):
return "<image>"
if model_type in ("mllama", "llama4"):
return "<|image|>"

View File

@@ -0,0 +1,322 @@
# SPDX-License-Identifier: Apache-2.0
# A modified implementation of the AIMv2 Transformer,
# which also includes the visual tokenizer used by Ovis2
from typing import Optional
import torch
from torch import nn, softmax
from torch.nn import functional as F
from torch.nn.functional import gumbel_softmax, pad
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.transformers_utils.configs.ovis2 import (AIMv2Config,
Aimv2VisualTokenizerConfig)
IMAGE_INDICATOR_IDS = [-301, -302, -303, -304, -305]  # kept to reserve vocab entries for the image indicator tokens
def st_argmax(y_soft: torch.Tensor, dim: int): # straight-through softmax
index = y_soft.max(dim, keepdim=True)[1]
y_hard = torch.zeros_like(
y_soft, memory_format=torch.legacy_contiguous_format).scatter_(
dim, index, 1.0)
ret = y_hard - y_soft.detach() + y_soft
return ret
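Editorial aside (not part of the diff): a tiny standalone sketch of the straight-through behavior of `st_argmax` — the forward value is a hard one-hot vector, while gradients flow through the soft distribution:

```python
import torch


def st_argmax(y_soft: torch.Tensor, dim: int) -> torch.Tensor:
    # same straight-through trick as in the code above
    index = y_soft.max(dim, keepdim=True)[1]
    y_hard = torch.zeros_like(y_soft).scatter_(dim, index, 1.0)
    return y_hard - y_soft.detach() + y_soft


logits = torch.tensor([[0.2, 1.5, 0.3]], requires_grad=True)
probs = torch.softmax(logits, dim=-1)
tokens = st_argmax(probs, dim=-1)       # forward value is (numerically) one-hot
loss = (tokens * torch.tensor([1.0, 2.0, 3.0])).sum()
loss.backward()                         # gradient flows through the soft `probs`
print(tokens, logits.grad)
```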
class Aimv2VisualTokenizer(torch.nn.Module):
def __init__(self,
config: Aimv2VisualTokenizerConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
**kwargs):
super().__init__()
self.config = config
self.backbone = AIMv2Model(
config=config.backbone_config, # noqa
quant_config=quant_config,
prefix=f"{prefix}.visual_tokenizer")
# reserved tokens for IMAGE_INDICATORS
head_dim = config.vocab_size - len(IMAGE_INDICATOR_IDS)
self.head = torch.nn.Sequential(
ReplicatedLinear(
config.backbone_config.hidden_size * config.hidden_stride *
config.hidden_stride,
head_dim,
bias=False,
), torch.nn.LayerNorm(head_dim))
@property
def dtype(self):
return self.backbone.dtype
@property
def device(self):
return self.backbone.device
def tokenize(self, logits):
if self.config.tokenize_function == 'softmax':
tokens = softmax(logits, dim=-1)
elif self.config.tokenize_function == 'gumbel_argmax':
tokens = gumbel_softmax(logits, tau=self.config.tau, hard=True)
elif self.config.tokenize_function == 'st_argmax':
tokens = st_argmax(logits, dim=-1)
else:
raise ValueError(
'Invalid `tokenize_function`, expected softmax or gumbel_argmax '
f'or st_argmax, but got {self.config.tokenize_function}')
return tokens
def encode(self, pixel_values):
features = self.backbone(pixel_values)
if self.config.drop_cls_token:
features = features[:, 1:, :]
# merge `hidden_stride * hidden_stride` adjacent hidden states together
# to reduce the token sequence length
# e.g., for hidden_stride=2, this leads to a token length reduction:
# 1024 -> 256 for aimv2
if self.config.hidden_stride > 1:
# this `d` may be different from the `d` above
n, L, d = features.shape
sqrt_l = int(L**0.5)
assert sqrt_l**2 == L, (
"The token sequence length should be a perfect square.")
features = features.reshape(n, sqrt_l, sqrt_l, d)
pl = (self.config.hidden_stride -
(sqrt_l %
self.config.hidden_stride)) % self.config.hidden_stride
features = pad(features, (0, 0, 0, pl, 0, pl), "constant", 0)
sqrt_l += pl
features = features.reshape(n, sqrt_l // self.config.hidden_stride,
self.config.hidden_stride,
sqrt_l // self.config.hidden_stride,
self.config.hidden_stride, d)
# [n, sqrt_l/hs, sqrt_l/hs, hs, hs, d]
features = features.permute(0, 1, 3, 2, 4, 5)
# [n, sqrt_l/hs, sqrt_l/hs, hs*hs*d]
features = features.flatten(3)
# [n, sqrt_l/hs*sqrt_l/hs, hs*hs*d]
features = features.reshape(
n, -1,
self.config.hidden_stride * self.config.hidden_stride * d)
return features
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
"""[BatchSize, ImageShape] -> [BatchSize, Token, VocabSize]"""
features = self.encode(pixel_values)
logits, _ = self.head[0](
features)  # split the Sequential apart here since ReplicatedLinear returns an (output, bias) tuple
logits = self.head[1](logits)
tokens = self.tokenize(logits)
# tokens' shape is [BatchSize, #Token, VocabSize-5], so padding with
# [BatchSize, #Token, 5], after which, tokens' shape should become
# [BatchSize, #Token, VocabSize]
batch_size, token_len, _ = tokens.shape
padding_tensor = torch.zeros(size=(batch_size, token_len,
len(IMAGE_INDICATOR_IDS)),
dtype=tokens.dtype,
device=tokens.device,
layout=tokens.layout,
requires_grad=False)
tokens = torch.cat((tokens, padding_tensor), dim=2)
return tokens
class AIMv2SwiGLUFFN(nn.Module):
def __init__(self, config: AIMv2Config, quant_config: QuantizationConfig,
prefix: str):
super().__init__()
hidden_features = config.intermediate_size
in_features = config.hidden_size
bias = config.use_bias
# TODO(Isotr0py): investigate if we can add TP to visual tokenizer
self.fc1 = ReplicatedLinear(in_features,
hidden_features,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.fc1")
self.fc2 = ReplicatedLinear(hidden_features,
in_features,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.fc2")
self.fc3 = ReplicatedLinear(in_features,
hidden_features,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.fc3")
def forward(self, x: torch.Tensor) -> torch.Tensor:
x_parallel, _ = self.fc1(x)
gate, _ = self.fc3(x)
x_parallel = F.silu(x_parallel) * gate
out, _ = self.fc2(x_parallel)
return out
class AIMv2PatchEmbed(nn.Module):
def __init__(self, config: AIMv2Config):
super().__init__()
self.proj = nn.Conv2d(
config.num_channels,
config.hidden_size,
kernel_size=(config.patch_size, config.patch_size),
stride=(config.patch_size, config.patch_size),
)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.proj(x).flatten(2).transpose(1, 2)
x = self.norm.forward_native(x)
return x
class AIMv2ViTPreprocessor(nn.Module):
def __init__(self, config: AIMv2Config):
super().__init__()
num_patches = (config.image_size // config.patch_size)**2
self.patchifier = AIMv2PatchEmbed(config)
self.pos_embed = nn.Parameter(
torch.zeros((1, num_patches, config.hidden_size)))
def forward(self, x: torch.Tensor) -> torch.Tensor:
tokens = self.patchifier(x)
_, N, _ = tokens.shape
pos_embed = self.pos_embed.to(tokens.device)
tokens = tokens + pos_embed[:, :N]
return tokens
class AIMv2Attention(nn.Module):
def __init__(self, config: AIMv2Config, quant_config: QuantizationConfig,
prefix: str):
super().__init__()
dim = config.hidden_size
# TODO(Isotr0py): investigate if we can add TP to visual tokenizer
self.num_heads = config.num_attention_heads
self.qkv = ReplicatedLinear(dim, dim * 3, bias=config.qkv_bias)
# self.qkv = QKVParallelLinear(
# hidden_size=dim,
# head_size=dim // config.num_attention_heads,
# total_num_heads=config.num_attention_heads,
# bias=config.qkv_bias,
# quant_config=quant_config,
# prefix=f"{prefix}.qkv")
self.proj = ReplicatedLinear(dim, dim, bias=config.use_bias)
# self.proj = RowParallelLinear(input_size=dim,
# output_size=dim,
# bias = config.use_bias,
# quant_config=quant_config,
# prefix=f"{prefix}.proj")
def forward(  # TODO: support multiple attention implementations
self,
x: torch.Tensor,
mask: Optional[torch.Tensor] = None) -> torch.Tensor:
B, N, C = x.shape
qkv, _ = self.qkv(x)
qkv = qkv.reshape(B, N, 3, self.num_heads,
C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0)
x = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
x = x.transpose(1, 2).contiguous().reshape(B, N, C)
x, _ = self.proj(x)
return x
class AIMv2Block(nn.Module):
def __init__(self, config: AIMv2Config, quant_config: QuantizationConfig,
prefix: str):
super().__init__()
self.attn = AIMv2Attention(config,
quant_config=quant_config,
prefix=f"{prefix}.attn")
self.norm_1 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.mlp = AIMv2SwiGLUFFN(config,
quant_config=quant_config,
prefix=f"{prefix}.mlp")
self.norm_2 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(self,
x: torch.Tensor,
mask: Optional[torch.Tensor] = None) -> torch.Tensor:
x = x + self.attn(self.norm_1.forward_native(x), mask)
x = x + self.mlp(self.norm_2.forward_native(x))
return x
class AIMv2Transformer(nn.Module):
def __init__(self, config: AIMv2Config, quant_config: QuantizationConfig,
prefix: str):
super().__init__()
self.blocks = nn.ModuleList([
AIMv2Block(config, quant_config, prefix=f"{prefix}.blocks.{i}")
for i in range(config.num_hidden_layers)
])
self.post_trunk_norm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
def forward(
self,
tokens: torch.Tensor,
mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# the reference implementation uses the last (-1) hidden states as the reference embeddings, like a CLIP skip
for block in self.blocks:
tokens = block(tokens, mask)
# no final norm in the original implementation
# tokens = self.post_trunk_norm(tokens)
return tokens
class AIMv2Model(torch.nn.Module):
def __init__(self,
config: AIMv2Config,
quant_config: QuantizationConfig,
prefix: str = ""):
super().__init__()
self.preprocessor = AIMv2ViTPreprocessor(config)
self.trunk = AIMv2Transformer(config,
quant_config=quant_config,
prefix=f"{prefix}.trunk")
@property
def dtype(self):
return self.trunk.blocks[0].attn.qkv.weight.dtype
@property
def device(self):
return self.trunk.blocks[0].attn.qkv.device
def forward(
self,
pixel_values: torch.Tensor,
mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
x = self.preprocessor(pixel_values)
x = self.trunk(x, mask)
return x
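Editorial aside: a standalone sanity check of the `hidden_stride` merging in `Aimv2VisualTokenizer.encode`, assuming 1024 patch tokens and `hidden_stride=2` as in the comment above:

```python
import torch
from torch.nn.functional import pad

n, L, d, hs = 1, 1024, 64, 2      # batch, tokens, hidden dim, hidden_stride
features = torch.randn(n, L, d)

sqrt_l = int(L ** 0.5)            # 32
assert sqrt_l ** 2 == L
features = features.reshape(n, sqrt_l, sqrt_l, d)
pl = (hs - (sqrt_l % hs)) % hs    # 0 here; pad only if sqrt_l is not a multiple of hs
features = pad(features, (0, 0, 0, pl, 0, pl), "constant", 0)
sqrt_l += pl
features = features.reshape(n, sqrt_l // hs, hs, sqrt_l // hs, hs, d)
features = features.permute(0, 1, 3, 2, 4, 5)   # [n, 16, 16, 2, 2, d]
features = features.flatten(3)                  # [n, 16, 16, 4 * d]
features = features.reshape(n, -1, hs * hs * d)

print(features.shape)  # torch.Size([1, 256, 256]) -> 1024 tokens merged into 256
```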

View File

@@ -0,0 +1,331 @@
# SPDX-License-Identifier: Apache-2.0
# adapted from https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/ovis/modeling_ovis.py
# Copyright 2023 The vLLM team.
# Copyright 2023 HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Ovis2 model."""
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
TypedDict, Union)
import torch
import torch.nn as nn
from torch import Tensor
from transformers import BatchFeature
from vllm.config import VllmConfig
from vllm.model_executor.models.aimv2 import Aimv2VisualTokenizer
from vllm.model_executor.models.utils import (AutoWeightsLoader, flatten_bn,
init_vllm_registered_model,
maybe_prefix)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargs)
from vllm.multimodal.parse import ImageSize, MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.ovis2 import OvisConfig
from vllm.transformers_utils.processors.ovis2 import OvisProcessor
from .interfaces import MultiModalEmbeddings, SupportsMultiModal
from .utils import merge_multimodal_embeddings
# The following token IDs cannot be found in the HF config, so they are hardcoded here.
IMAGE_TOKEN = "<image>"
IMAGE_ATOM_TOKEN_ID = 151666
IMAGE_PAD_TOKEN_ID = 151672
NUMBER_OF_TOKEN_TO_RESERVE_FOR_SEGMENT = 256
class Ovis2ImagePatchInputs(TypedDict):
type: Literal["image_patches"]
flat_data: torch.Tensor
"""
Shape:
`(batch_size * num_patches, patch_size_x * patch_size_y * num_channels)`
"""
patches_per_image: List[int]
"""
List of number of total patches for each image in the batch.
This is used to restore the first two dimensions of `flat_data`.
"""
class VisualEmbedding(torch.nn.Embedding):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def forward(self, visual_tokens: Tensor) -> Tensor:
if visual_tokens.dtype in [
torch.int8, torch.int16, torch.int32, torch.int64, torch.long
]:
return super().forward(visual_tokens)
return torch.matmul(visual_tokens, self.weight)
@property
def device(self):
return self.weight.device
@property
def dtype(self):
return self.weight.dtype
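Editorial aside: a small standalone sketch of the two paths in `VisualEmbedding.forward` — integer token IDs use the normal embedding lookup, while the (soft) token distributions produced by the visual tokenizer are mixed via a matmul with the embedding weight:

```python
import torch

vocab_size, hidden_size = 8, 4
emb = torch.nn.Embedding(vocab_size, hidden_size)

# integer path: ordinary lookup
ids = torch.tensor([1, 3])
lookup = emb(ids)

# soft path: a one-hot "distribution" selects the same rows, and a soft
# distribution yields a weighted mixture of embedding rows
one_hot = torch.nn.functional.one_hot(ids, vocab_size).to(emb.weight.dtype)
mixed = torch.matmul(one_hot, emb.weight)
print(torch.allclose(lookup, mixed))  # True
```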
class Ovis2ProcessingInfo(BaseProcessingInfo):
def get_hf_config(self):
return self.ctx.get_hf_config(OvisConfig)
def get_hf_processor(self, **kwargs):
return self.ctx.get_hf_processor(OvisProcessor)
def get_image_processor(self) -> OvisProcessor:
return self.get_hf_processor().image_processor # type: ignore
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {  # 32k is the model's multimodal token limit at the moment
"image":
self.get_hf_config().multimodal_max_length //
((9 + 1) * NUMBER_OF_TOKEN_TO_RESERVE_FOR_SEGMENT)
}
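Worked example (editorial note, not part of the diff): with the default `max_partition` of 9 and 256 reserved tokens per image segment, the limit above works out as follows, assuming `multimodal_max_length` is 32768 as the comment suggests:

```python
multimodal_max_length = 32768          # assumed, per the "32k" comment above
tokens_per_image = (9 + 1) * 256       # up to 9 partition crops + 1 overview, 256 tokens each
print(multimodal_max_length // tokens_per_image)  # 12 -> at most 12 images per prompt
```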
def get_image_size_with_most_features(self) -> ImageSize:
image_processor = self.get_image_processor()
return ImageSize(width=image_processor.size['shortest_edge'] * 9 * 2,
height=image_processor.size['shortest_edge'] * 9 * 2)
class Ovis2DummyInputsBuilder(BaseDummyInputsBuilder[Ovis2ProcessingInfo]):
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
num_images = mm_counts.get("image", 0)
return IMAGE_TOKEN * num_images
def get_dummy_mm_data(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> MultiModalDataDict:
num_images = mm_counts.get("image", 0)
target_width, target_height = \
self.info.get_image_size_with_most_features()
mm_data = {
"image":
self._get_dummy_images(width=target_width,
height=target_height,
num_images=num_images),
}
return mm_data
class Ovis2MultiModalProcessor(BaseMultiModalProcessor[Ovis2ProcessingInfo]):
def _call_hf_processor(
self,
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
) -> BatchFeature:
if not mm_data:
# Avoid warning from HF logger for text-only input
prompt_ids = self.info.get_tokenizer().encode(prompt)
# prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")
processed_outputs = super()._call_hf_processor(
prompt=prompt,
mm_data=mm_data,
mm_kwargs=mm_kwargs,
)
return processed_outputs
def _apply_hf_processor_tokens_only(
self,
prompt_tokens: list[int],
) -> list[int]:
return prompt_tokens
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return dict(pixel_values=MultiModalFieldConfig.batched("image"),
grids=MultiModalFieldConfig.batched("image"))
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
def get_replacement_ovis(item_idx):
grid = out_mm_kwargs["grids"][item_idx]
hf_processor = self.info.get_hf_processor()
return hf_processor.construct_image_placeholders(grid)
return [
PromptReplacement(
modality="image",
target=IMAGE_TOKEN,
replacement=get_replacement_ovis,
),
]
@MULTIMODAL_REGISTRY.register_processor(Ovis2MultiModalProcessor,
info=Ovis2ProcessingInfo,
dummy_inputs=Ovis2DummyInputsBuilder)
class Ovis2ForConditionalGeneration(nn.Module, SupportsMultiModal):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.config: OvisConfig = config
self.llm = init_vllm_registered_model(
vllm_config=vllm_config.with_hf_config(config.get_text_config()),
prefix=maybe_prefix(prefix, "llm"),
)
self.visual_tokenizer = Aimv2VisualTokenizer(
config=config.visual_tokenizer_config,
quant_config=quant_config,
prefix=f"{prefix}.visual_tokenizer",
image_processor_name_or_path=config.visual_tokenizer_config.
backbone_config.name_or_path,
)
self.vte = VisualEmbedding(
self.config.visual_tokenizer_config.vocab_size,
self.config.hidden_size)
# TODO(Isotr0py): PP support
# self.make_empty_intermediate_tensors = (
# self.language_model.make_empty_intermediate_tensors)
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[Ovis2ImagePatchInputs]:
pixel_values = kwargs.pop("pixel_values", None)
if pixel_values is None:
return None
if pixel_values is not None:
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
return Ovis2ImagePatchInputs(
type="image_patches",
flat_data=flatten_bn(flatten_bn(pixel_values), concat=True),
patches_per_image=[
x.shape[0] for x in flatten_bn(pixel_values)
],
)
raise AssertionError("This line should be unreachable.")
def _process_image_input(
self, image_input: Ovis2ImagePatchInputs) -> MultiModalEmbeddings:
image_patches_flat = image_input["flat_data"]
patches_per_image = image_input["patches_per_image"]
target_dtype = self.visual_tokenizer.dtype
visual_tokens = self.visual_tokenizer(
image_patches_flat.to(target_dtype))
visual_embeds = self.vte(visual_tokens)  # numerically identical to the HF implementation
return tuple(
x.flatten(0, 1)
for x in visual_embeds.split(patches_per_image, dim=0))
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return None
image_features = self._process_image_input(image_input)
return image_features
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.llm.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
[IMAGE_ATOM_TOKEN_ID, IMAGE_PAD_TOKEN_ID])
return inputs_embeds
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object,
) -> Union[torch.Tensor, IntermediateTensors]:
if intermediate_tensors is not None:
inputs_embeds = None
# NOTE: In v1, inputs_embeds is always generated at model runner, this
# condition is for v0 compatibility.
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
input_ids = None
# up to this point, inputs_embeds is numerically identical to the
# original HF Transformers implementation
hidden_states = self.llm(
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.llm.logits_processor(self.llm.lm_head, hidden_states,
sampling_metadata)
return logits
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)
def get_language_model(self) -> torch.nn.Module:
return self.llm

View File

@@ -195,6 +195,7 @@ _MULTIMODAL_MODELS = {
"Mistral3ForConditionalGeneration": ("mistral3", "Mistral3ForConditionalGeneration"), # noqa: E501
"MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
"NVLM_D": ("nvlm_d", "NVLM_D_Model"),
"Ovis2ForConditionalGeneration": ("ovis2", "Ovis2ForConditionalGeneration"),
"PaliGemmaForConditionalGeneration": ("paligemma", "PaliGemmaForConditionalGeneration"), # noqa: E501
"Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
"PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"), # noqa: E501

View File

@@ -38,9 +38,9 @@ from vllm.transformers_utils.configs import (ChatGLMConfig, Cohere2Config,
MiniMaxVL01Config, MllamaConfig,
MLPSpeculatorConfig, MPTConfig,
NemotronConfig, NVLM_D_Config,
RWConfig, SkyworkR1VChatConfig,
SolarConfig, Telechat2Config,
UltravoxConfig)
OvisConfig, RWConfig,
SkyworkR1VChatConfig, SolarConfig,
Telechat2Config, UltravoxConfig)
# yapf: enable
from vllm.transformers_utils.utils import check_gguf_file
from vllm.utils import resolve_obj_by_qualname
@@ -79,6 +79,7 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
"minimax_vl_01": MiniMaxVL01Config,
"nemotron": NemotronConfig,
"NVLM_D": NVLM_D_Config,
"ovis": OvisConfig,
"solar": SolarConfig,
"skywork_chat": SkyworkR1VChatConfig,
"telechat": Telechat2Config,

View File

@@ -23,6 +23,7 @@ from vllm.transformers_utils.configs.moonvit import MoonViTConfig
from vllm.transformers_utils.configs.mpt import MPTConfig
from vllm.transformers_utils.configs.nemotron import NemotronConfig
from vllm.transformers_utils.configs.nvlm_d import NVLM_D_Config
from vllm.transformers_utils.configs.ovis2 import OvisConfig
from vllm.transformers_utils.configs.skyworkr1v import SkyworkR1VChatConfig
from vllm.transformers_utils.configs.solar import SolarConfig
from vllm.transformers_utils.configs.telechat2 import Telechat2Config
@@ -49,6 +50,7 @@ __all__ = [
"KimiVLConfig",
"NemotronConfig",
"NVLM_D_Config",
"OvisConfig",
"SkyworkR1VChatConfig",
"SolarConfig",
"Telechat2Config",

View File

@@ -0,0 +1,170 @@
# SPDX-License-Identifier: Apache-2.0
# yapf: disable
# ruff: noqa: E501
# copied from https://huggingface.co/AIDC-AI/Ovis2-1B/blob/main/configuration_aimv2.py
# and https://huggingface.co/AIDC-AI/Ovis2-1B/blob/main/configuration_ovis.py
from typing import Any, Optional, Union
from transformers import AutoConfig, PretrainedConfig
class AIMv2Config(PretrainedConfig):
"""This is the configuration class to store the configuration of an [`AIMv2Model`].
Instantiating a configuration with the defaults will yield a similar configuration
to that of the [apple/aimv2-large-patch14-224](https://huggingface.co/apple/aimv2-large-patch14-224).
Args:
hidden_size: Dimension of the hidden representations.
intermediate_size: Dimension of the SwiGLU representations.
num_hidden_layers: Number of hidden layers in the Transformer.
num_attention_heads: Number of attention heads for each attention layer
in the Transformer.
num_channels: Number of input channels.
image_size: Image size.
patch_size: Patch size.
rms_norm_eps: Epsilon value used for the RMS normalization layer.
attention_dropout: Dropout ratio for attention probabilities.
projection_dropout: Dropout ratio for the projection layer after the attention.
qkv_bias: Whether to add a bias to the queries, keys and values.
use_bias: Whether to add a bias in the feed-forward and projection layers.
kwargs: Keyword arguments for the [`PretrainedConfig`].
"""
model_type: str = "aimv2"
def __init__(
self,
hidden_size: int = 1024,
intermediate_size: int = 2816,
num_hidden_layers: int = 24,
num_attention_heads: int = 8,
num_channels: int = 3,
image_size: int = 224,
patch_size: int = 14,
rms_norm_eps: float = 1e-5,
attention_dropout: float = 0.0,
projection_dropout: float = 0.0,
qkv_bias: bool = False,
use_bias: bool = False,
**kwargs: Any,
):
super().__init__(**kwargs)
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.num_channels = num_channels
self.patch_size = patch_size
self.image_size = image_size
self.attention_dropout = attention_dropout
self.rms_norm_eps = rms_norm_eps
self.projection_dropout = projection_dropout
self.qkv_bias = qkv_bias
self.use_bias = use_bias
IGNORE_ID = -100
IMAGE_TOKEN_ID = -200
IMAGE_TOKEN = "<image>"
IMAGE_ATOM_ID = -300
IMAGE_INDICATOR_IDS = [-301, -302, -303, -304, -305]
AutoConfig.register("aimv2", AIMv2Config)
# ----------------------------------------------------------------------
# Visual Tokenizer Configuration
# ----------------------------------------------------------------------
class BaseVisualTokenizerConfig(PretrainedConfig):
def __init__(self,
vocab_size=16384,
tokenize_function="softmax",
tau=1.0,
depths=None,
drop_cls_token=False,
backbone_config: Optional[Union[PretrainedConfig,
dict]] = None,
hidden_stride: int = 1,
**kwargs):
super().__init__(**kwargs)
self.vocab_size = vocab_size
self.tokenize_function = tokenize_function
self.tau = tau
if isinstance(depths, str):
depths = [int(x) for x in depths.split('|')]
self.depths = depths
self.backbone_kwargs = dict[str, Any]()
self.drop_cls_token = drop_cls_token
if backbone_config is not None:
assert isinstance(backbone_config, (PretrainedConfig, dict)), \
f"expect `backbone_config` to be instance of PretrainedConfig or dict, but got {type(backbone_config)} type"
if not isinstance(backbone_config, PretrainedConfig):
model_type = backbone_config['model_type']
backbone_config.pop('model_type')
backbone_config = AutoConfig.for_model(model_type,
**backbone_config)
self.backbone_config = backbone_config
self.hidden_stride = hidden_stride
class Aimv2VisualTokenizerConfig(BaseVisualTokenizerConfig):
model_type = "aimv2_visual_tokenizer"
def __init__(self, **kwargs):
super().__init__(**kwargs)
if self.drop_cls_token:
self.drop_cls_token = False
if self.depths:
assert len(self.depths) == 1
self.backbone_kwargs['num_hidden_layers'] = self.depths[0]
AutoConfig.register("aimv2_visual_tokenizer", Aimv2VisualTokenizerConfig)
# ----------------------------------------------------------------------
# Ovis Configuration
# ----------------------------------------------------------------------
class OvisConfig(PretrainedConfig):
model_type = "ovis"
def __init__(self,
llm_config: Optional[Union[PretrainedConfig, dict]] = None,
visual_tokenizer_config: Optional[Union[PretrainedConfig,
dict]] = None,
multimodal_max_length=8192,
hidden_size=None,
conversation_formatter_class=None,
llm_attn_implementation=None,
disable_tie_weight=False,
**kwargs):
super().__init__(**kwargs)
if llm_config is not None:
assert isinstance(llm_config, (PretrainedConfig, dict)), \
f"expect `llm_config` to be instance of PretrainedConfig or dict, but got {type(llm_config)} type"
if not isinstance(llm_config, PretrainedConfig):
model_type = llm_config['model_type']
llm_config.pop('model_type')
llm_config = AutoConfig.for_model(model_type, **llm_config)
# map llm_config to text_config
self.text_config = llm_config
if visual_tokenizer_config is not None:
assert isinstance(visual_tokenizer_config, (PretrainedConfig, dict)), \
f"expect `visual_tokenizer_config` to be instance of PretrainedConfig or dict, but got {type(visual_tokenizer_config)} type"
if not isinstance(visual_tokenizer_config, PretrainedConfig):
model_type = visual_tokenizer_config['model_type']
visual_tokenizer_config.pop('model_type')
visual_tokenizer_config = AutoConfig.for_model(
model_type, **visual_tokenizer_config)
self.visual_tokenizer_config = visual_tokenizer_config
self.multimodal_max_length = multimodal_max_length
self.hidden_size = hidden_size
self.conversation_formatter_class = conversation_formatter_class
self.llm_attn_implementation = llm_attn_implementation
self.disable_tie_weight = disable_tie_weight
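Editorial aside: a minimal sketch (illustrative values, not taken from a real checkpoint) of how `OvisConfig` resolves its nested dict configs via `AutoConfig.for_model`:

```python
from vllm.transformers_utils.configs.ovis2 import OvisConfig

config = OvisConfig(
    llm_config={"model_type": "qwen2", "hidden_size": 896},
    visual_tokenizer_config={
        "model_type": "aimv2_visual_tokenizer",
        "vocab_size": 65536,        # illustrative value
        "hidden_stride": 2,
        "backbone_config": {"model_type": "aimv2", "hidden_size": 1024},
    },
    multimodal_max_length=32768,
    hidden_size=896,
)
print(type(config.text_config).__name__)              # Qwen2Config
print(type(config.visual_tokenizer_config).__name__)  # Aimv2VisualTokenizerConfig
```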

View File

@@ -2,5 +2,6 @@
from vllm.transformers_utils.processors.deepseek_vl2 import (
DeepseekVLV2Processor)
from vllm.transformers_utils.processors.ovis2 import OvisProcessor
__all__ = ["DeepseekVLV2Processor"]
__all__ = ["DeepseekVLV2Processor", "OvisProcessor"]

View File

@@ -0,0 +1,397 @@
# SPDX-License-Identifier: Apache-2.0
# yapf: disable
# ruff: noqa: E501
# coding=utf-8
# adapted from https://github.com/AIDC-AI/Ovis/blob/35ab51a1a1e3542fa6db260a1084cefbc8f164bb/ovis/vllm/processing_ovis.py
# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Union
import PIL
import torch
from transformers import AutoProcessor, BatchFeature
from transformers.image_utils import ImageInput
from transformers.processing_utils import (ProcessingKwargs, ProcessorMixin,
Unpack)
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
__all__ = ['OvisProcessor']
IGNORE_ID = -100
class OvisProcessorKwargs(ProcessingKwargs, total=False): # type: ignore[call-arg]
_defaults = {
"text_kwargs": {
"padding": False,
},
"images_kwargs": {
'max_partition': 9,
'covering_threshold': 0.9,
'convert_to_rgb': True,
'return_tensors': 'pt'},
}
class OvisProcessor(ProcessorMixin):
r"""
Constructs an Ovis processor which wraps an Ovis image processor and a Qwen2 tokenizer into a single processor.
[`OvisProcessor`] offers all the functionalities of [`Qwen2VLImageProcessor`] and [`Qwen2TokenizerFast`]. See the
[`~OvisProcessor.__call__`] and [`~OvisProcessor.decode`] for more information.
Args:
image_processor ([`Qwen2VLImageProcessor`], *optional*):
The image processor is a required input.
tokenizer ([`Qwen2TokenizerFast`], *optional*):
The tokenizer is a required input.
chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
in a chat into a tokenizable string.
"""
attributes = ["image_processor", "tokenizer"]
valid_kwargs = ["chat_template"]
image_processor_class = "AutoImageProcessor"
tokenizer_class = "Qwen2Tokenizer"
def __init__(self, image_processor=None, tokenizer=None, chat_template=None, **kwargs):
self.image_token = "<|image_pad|>" if not hasattr(tokenizer, "image_token") else tokenizer.image_token
self.video_token = "<|video_pad|>" if not hasattr(tokenizer, "video_token") else tokenizer.video_token
super().__init__(image_processor, tokenizer, chat_template=chat_template)
self.extra_special_tokens = {
"image_token": "<image>",
"image_atom": "<image_atom>",
"image_start": "<img>",
"image_prefix": "<pre>",
"image_col_sep": "<col>",
"image_row_sep": "<row>",
"image_end": "</img>",
'image_pad': '<image_pad>',
}
def __call__(
self,
images: ImageInput = None,
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
**kwargs: Unpack[OvisProcessorKwargs],
) -> BatchFeature:
"""
Main method to prepare one or several sequence(s) and image(s) for the model. This method forwards the `text`
and `kwargs` arguments to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` is not `None` to encode
the text. To prepare the vision inputs, this method forwards the `vision_infos` and `kwargs` arguments to
Qwen2VLImageProcessor's [`~Qwen2VLImageProcessor.__call__`] if `vision_infos` is not `None`.
Args:
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
tensor. Both channels-first and channels-last formats are supported.
text (`str`, `List[str]`, `List[List[str]]`):
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
videos (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`):
The video or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch
tensor, or a nested list of 3D frames. Both channels-first and channels-last formats are supported.
return_tensors (`str` or [`~utils.TensorType`], *optional*):
If set, will return tensors of a particular framework. Acceptable values are:
- `'tf'`: Return TensorFlow `tf.constant` objects.
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return NumPy `np.ndarray` objects.
- `'jax'`: Return JAX `jnp.ndarray` objects.
Returns:
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
- **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
`None`).
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
- **pixel_values_videos** -- Pixel values of videos to be fed to a model. Returned when `videos` is not `None`.
- **image_grid_thw** -- List of image 3D grid in LLM. Returned when `images` is not `None`.
- **video_grid_thw** -- List of video 3D grid in LLM. Returned when `videos` is not `None`.
- **second_per_grid_ts** -- List of video seconds per time grid. Returned when `videos` is not `None`.
"""
output_kwargs = self._merge_kwargs(
OvisProcessorKwargs,
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
**kwargs,
)
# Process all images first
image_features = {}
if images is not None:
processed_images = []
image_placeholders_list = []
grids = []
# Process each image
for image in images if isinstance(images, list) else [images]:
pixel_values, image_placeholders, grid = self.preprocess_image(
image=image, **output_kwargs["images_kwargs"]
)
processed_images.append(pixel_values)
image_placeholders_list.append(image_placeholders)
grids.append(grid)
# assign all processed images
if processed_images:
image_features["image_placeholders"] = image_placeholders_list
# Process text input
if text is not None:
if not isinstance(text, list):
text = [text]
tokenized_batched_text = self.tokenizer.batch_encode_plus(
text,
**output_kwargs["text_kwargs"]
)
image_token_id = self.get_token_value("image_token")
replaced_ids_list = []
replaced_attn_mask_list = []
idx = 0
for ids_tensor, attn_mask in zip(tokenized_batched_text['input_ids'],
tokenized_batched_text['attention_mask']):
if image_token_id in ids_tensor and "image_placeholders" in image_features:
if idx < len(image_features["image_placeholders"]):
# Convert to lists for easier manipulation
ids_list = ids_tensor.tolist()
attn_list = attn_mask.tolist()
new_ids = []
new_attn = []
# replace placeholders
for i, token_id in enumerate(ids_list):
if token_id == image_token_id:
placeholder_ids = image_features["image_placeholders"][idx]
new_ids.extend(placeholder_ids)
new_attn.extend([1] * len(placeholder_ids))
idx += 1
else:
new_ids.append(token_id)
new_attn.append(attn_list[i])
# Convert back to tensors
ids_tensor = torch.tensor(new_ids, dtype=torch.long)
attn_mask = torch.tensor(new_attn, dtype=torch.long)
else:
raise RuntimeError(
'Mismatch between the number of images provided and the number of image placeholders in the text')
replaced_ids_list.append(ids_tensor)
replaced_attn_mask_list.append(attn_mask)
if replaced_ids_list:
replaced_and_tokenized_ids = torch.stack(replaced_ids_list)
replaced_and_tokenized_attn_mask = torch.stack(replaced_attn_mask_list)
else:
replaced_and_tokenized_ids = torch.tensor([], dtype=torch.long)
replaced_and_tokenized_attn_mask = torch.tensor([], dtype=torch.long)
# Create the output with text features
output = BatchFeature(
data={
"input_ids": replaced_and_tokenized_ids,
"attention_mask": replaced_and_tokenized_attn_mask,
}
)
# Add image features if present
if image_features:
output["pixel_values"] = processed_images
output['grids'] = grids
return output
# If only images were provided
return BatchFeature(data=image_features)
def get_image_size(self):
height = self.image_processor.crop_size["height"]
width = self.image_processor.crop_size["width"]
return height, width
def get_token_value(self, tok):
return self.tokenizer.get_vocab()[self.extra_special_tokens[tok]]
def construct_image_placeholders(self, grid):
image_placeholders = [self.get_token_value('image_start'),
self.get_token_value('image_atom'),
self.get_token_value('image_prefix')]
if grid[0] * grid[1] > 1:
for r in range(grid[0]):
for c in range(grid[1]):
image_placeholders.append(self.get_token_value('image_atom'))
if c < grid[1] - 1:
image_placeholders.append(self.get_token_value('image_col_sep'))
if r < grid[0] - 1:
image_placeholders.append(self.get_token_value('image_row_sep'))
image_placeholders.append(self.get_token_value('image_end'))
# return image_placeholders
image_atom_token_id = self.get_token_value('image_atom')
# Extract the padding token ID from tokenizer
image_padding_token_id = self.get_token_value('image_pad')
# Create a new list with padding tokens inserted
padded_placeholder_tokens = []
for token in image_placeholders:
padded_placeholder_tokens.append(token)
if token == image_atom_token_id:
# Add 255 padding tokens after each image atom token
padded_placeholder_tokens.extend([image_padding_token_id] * 255)
return padded_placeholder_tokens
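Editorial aside: a standalone sketch (with made-up token IDs standing in for the tokenizer's special tokens) that mirrors the logic above for a 2x2 grid — 5 image atoms (1 overview + 4 crops), each followed by 255 pad tokens, plus 6 structural tokens:

```python
# made-up IDs standing in for the tokenizer's special tokens
TOK = {"image_start": 0, "image_prefix": 1, "image_col_sep": 2,
       "image_row_sep": 3, "image_end": 4, "image_atom": 5, "image_pad": 6}


def construct_image_placeholders(grid):
    placeholders = [TOK["image_start"], TOK["image_atom"], TOK["image_prefix"]]
    if grid[0] * grid[1] > 1:
        for r in range(grid[0]):
            for c in range(grid[1]):
                placeholders.append(TOK["image_atom"])
                if c < grid[1] - 1:
                    placeholders.append(TOK["image_col_sep"])
            if r < grid[0] - 1:
                placeholders.append(TOK["image_row_sep"])
    placeholders.append(TOK["image_end"])
    padded = []
    for tok in placeholders:
        padded.append(tok)
        if tok == TOK["image_atom"]:
            padded.extend([TOK["image_pad"]] * 255)
    return padded


tokens = construct_image_placeholders((2, 2))
print(len(tokens))                      # 1286 = 5 * 256 + 6
print(tokens.count(TOK["image_atom"]))  # 5 atoms: 1 overview + 4 crops
```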
def preprocess_image(self, image: PIL.Image.Image, max_partition, covering_threshold, convert_to_rgb, return_tensors):
def _preprocess(img: PIL.Image.Image, side):
# first resize and preprocess
w, h = img.size
if w == h:
new_width = new_height = side
elif w > h:
new_width = side
new_height = int(h / w * new_width)
else:
new_height = side
new_width = int(w / h * new_height)
new_size = dict(height=new_height, width=new_width)
pixel_values = self.image_processor.preprocess(img, size=new_size, return_tensors=return_tensors)['pixel_values']
# then pad to square
square_values = torch.zeros([1, 3, side, side], dtype=pixel_values.dtype, device=pixel_values.device)
new_height, new_width = pixel_values.shape[2:]
if new_height == new_width:
square_values[:, :, :, :] = pixel_values
elif new_height > new_width:
from_index = (side - new_width) // 2
square_values[:, :, :, from_index:from_index + new_width] = pixel_values
else:
from_index = (side - new_height) // 2
square_values[:, :, from_index:from_index + new_height, :] = pixel_values
return square_values
def _partition(img, grid) -> list[tuple[int, int, int, int]]:
w, h = img.size
row_height = h // grid[0]
col_width = w // grid[1]
partition = []
for row in range(grid[0]):
for col in range(grid[1]):
left = col * col_width
upper = row * row_height
right = w if col == grid[1] - 1 else (col + 1) * col_width
lower = h if row == grid[0] - 1 else (row + 1) * row_height
partition.append((left, upper, right, lower))
return partition
def _covering_area(left, upper, right, lower, side):
w = right - left
h = lower - upper
w, h = max(w, h), min(w, h)
if w > side:
h = h / w * side
w = side
return w * h
def _get_best_grid(img, side):
img_area = img.size[0] * img.size[1]
candidate_grids = []
for i in range(1, max_partition + 1):
for j in range(1, max_partition + 1):
if i * j <= max_partition:
candidate_grids.append((i, j))
all_grids = []
good_grids = []
for grid in candidate_grids:
partition = _partition(img, grid)
covering_ratio = sum([_covering_area(*p, side) for p in partition]) / img_area
assert covering_ratio <= 1.0
all_grids.append((grid, covering_ratio))
if covering_ratio > covering_threshold:
good_grids.append((grid, covering_ratio))
if len(good_grids) > 0:
# pick the good partition with minimum #sub_images and break the tie using covering_ratio
return sorted(good_grids, key=lambda x: (x[0][0] * x[0][1], -x[1]))[0][0]
else:
# pick the partition with maximum covering_ratio and break the tie using #sub_images
return sorted(all_grids, key=lambda x: (-x[1], x[0][0] * x[0][1]))[0][0]
if convert_to_rgb and image.mode != 'RGB':
image = image.convert('RGB')
sides = self.get_image_size()
if sides[0] != sides[1]:
raise ValueError('get_image_size() returns non-square size')
side = sides[0]
grid = _get_best_grid(image, side)
partition = _partition(image, grid)
crops = [image.crop(p) for p in partition]
if len(crops) > 1:
crops.insert(0, image)
pixel_values = torch.cat([_preprocess(crop, side) for crop in crops], dim=0)
image_placeholders = self.construct_image_placeholders(grid)
return pixel_values, image_placeholders, grid
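Editorial aside: what the nested `_partition` helper above produces for a hypothetical 800x600 image and a 2x2 grid (standalone re-implementation for illustration):

```python
def partition(size, grid):
    # mirrors the nested _partition helper above, taking (width, height) directly
    w, h = size
    row_height, col_width = h // grid[0], w // grid[1]
    boxes = []
    for row in range(grid[0]):
        for col in range(grid[1]):
            left, upper = col * col_width, row * row_height
            right = w if col == grid[1] - 1 else (col + 1) * col_width
            lower = h if row == grid[0] - 1 else (row + 1) * row_height
            boxes.append((left, upper, right, lower))
    return boxes


print(partition((800, 600), (2, 2)))
# [(0, 0, 400, 300), (400, 0, 800, 300), (0, 300, 400, 600), (400, 300, 800, 600)]
```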
def batch_decode(self, *args, **kwargs):
"""
This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
refer to the docstring of this method for more information.
"""
return self.tokenizer.batch_decode(*args, **kwargs)
def decode(self, *args, **kwargs):
"""
This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
the docstring of this method for more information.
"""
return self.tokenizer.decode(*args, **kwargs)
def post_process_image_text_to_text(self, generated_outputs):
"""
Post-process the output of the model to decode the text.
Args:
generated_outputs (`torch.Tensor` or `np.ndarray`):
The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
or `(sequence_length,)`.
Returns:
`List[str]`: The decoded text.
"""
return self.tokenizer.batch_decode(
generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
@property
def model_input_names(self):
tokenizer_input_names = self.tokenizer.model_input_names
image_processor_input_names = self.image_processor.model_input_names
names_from_processor = list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
return names_from_processor + ["second_per_grid_ts"]
AutoProcessor.register("OvisProcessor", OvisProcessor)