diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md
index 8c6e7b04de85..48fc24f3447a 100644
--- a/docs/source/models/supported_models.md
+++ b/docs/source/models/supported_models.md
@@ -1045,10 +1045,10 @@ Specified using `--task generate`.
*
* ✅︎
* ✅︎
-- * `Ovis2ForConditionalGeneration`^
- * Ovis2
+- * `Ovis`
+ * Ovis2, Ovis1.6
   * T + I<sup>+</sup>
- * `AIDC-AI/Ovis2-1B`, `AIDC-AI/Ovis2-2B`, etc.
+ * `AIDC-AI/Ovis2-1B`, `AIDC-AI/Ovis1.6-Llama3.2-3B`, etc.
*
*
* ✅︎
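
For context, the renamed `Ovis` architecture loads the same way as before, just without the `hf_overrides` workaround. Below is a minimal offline-inference sketch (illustrative only, not part of the patch); the engine arguments mirror the updated examples, and the prompt assumes the Qwen2-based Ovis2 chat template shown further down.

```python
# Illustrative sketch: model name and prompt template follow the examples in
# this PR; the image path is a placeholder.
from PIL import Image
from vllm import LLM, SamplingParams

llm = LLM(
    model="AIDC-AI/Ovis2-1B",  # Ovis1.6 checkpoints load the same way
    trust_remote_code=True,
    dtype="half",
    max_model_len=4096,
    limit_mm_per_prompt={"image": 1},
)
prompt = ("<|im_start|>user\n<image>\nWhat is in this image?<|im_end|>\n"
          "<|im_start|>assistant\n")
image = Image.open("example.jpg")
outputs = llm.generate({"prompt": prompt, "multi_modal_data": {"image": image}},
                       SamplingParams(max_tokens=64))
print(outputs[0].outputs[0].text)
```
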
diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py
index 5c173ab1abb9..c54f328c7a38 100644
--- a/examples/offline_inference/vision_language.py
+++ b/examples/offline_inference/vision_language.py
@@ -725,8 +725,8 @@ def run_nvlm_d(questions: list[str], modality: str) -> ModelRequestData:
)
-# Ovis2
-def run_ovis2(questions: list[str], modality: str) -> ModelRequestData:
+# Ovis
+def run_ovis(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image"
model_name = "AIDC-AI/Ovis2-1B"
@@ -737,15 +737,18 @@ def run_ovis2(questions: list[str], modality: str) -> ModelRequestData:
max_num_seqs=2,
trust_remote_code=True,
dtype="half",
- hf_overrides={"architectures": ["Ovis2ForConditionalGeneration"]},
limit_mm_per_prompt={modality: 1},
)
-    placeholder = "<image>\n"
- prompts = [("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
- f"<|im_start|>user\n{placeholder}"
- f"{question}<|im_end|>\n"
- "<|im_start|>assistant\n") for question in questions]
+ tokenizer = AutoTokenizer.from_pretrained(model_name,
+ trust_remote_code=True)
+ messages = [[{
+ 'role': 'user',
+        'content': f"<image>\n{question}"
+ }] for question in questions]
+ prompts = tokenizer.apply_chat_template(messages,
+ tokenize=False,
+ add_generation_prompt=True)
return ModelRequestData(
engine_args=engine_args,
@@ -1069,7 +1072,7 @@ model_example_map = {
"llama4": run_llama4,
"molmo": run_molmo,
"NVLM_D": run_nvlm_d,
- "ovis2": run_ovis2,
+ "ovis": run_ovis,
"paligemma": run_paligemma,
"paligemma2": run_paligemma2,
"phi3_v": run_phi3v,
diff --git a/examples/offline_inference/vision_language_multi_image.py b/examples/offline_inference/vision_language_multi_image.py
index 48d590b05b06..20a8e635e322 100644
--- a/examples/offline_inference/vision_language_multi_image.py
+++ b/examples/offline_inference/vision_language_multi_image.py
@@ -436,8 +436,8 @@ def load_nvlm_d(question: str, image_urls: list[str]) -> ModelRequestData:
)
-# Ovis2
-def load_ovis2(question: str, image_urls: list[str]) -> ModelRequestData:
+# Ovis
+def load_ovis(question: str, image_urls: list[str]) -> ModelRequestData:
model_name = "AIDC-AI/Ovis2-1B"
engine_args = EngineArgs(
@@ -447,15 +447,17 @@ def load_ovis2(question: str, image_urls: list[str]) -> ModelRequestData:
trust_remote_code=True,
dtype="half",
limit_mm_per_prompt={"image": len(image_urls)},
- hf_overrides={"architectures": ["Ovis2ForConditionalGeneration"]},
)
- placeholder = '\n'.join(
-        [f'Image {i+1}: <image>' for i in range(len(image_urls))]) + '\n'
- prompt = ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
- f"<|im_start|>user\n{placeholder}"
- f"{question}<|im_end|>\n"
- "<|im_start|>assistant\n")
+    placeholders = "\n".join(f"Image-{i}: <image>\n"
+ for i, _ in enumerate(image_urls, start=1))
+ messages = [{'role': 'user', 'content': f"{placeholders}\n{question}"}]
+
+ tokenizer = AutoTokenizer.from_pretrained(model_name,
+ trust_remote_code=True)
+ prompt = tokenizer.apply_chat_template(messages,
+ tokenize=False,
+ add_generation_prompt=True)
return ModelRequestData(
engine_args=engine_args,
@@ -713,7 +715,7 @@ model_example_map = {
"mistral3": load_mistral3,
"mllama": load_mllama,
"NVLM_D": load_nvlm_d,
- "ovis2": load_ovis2,
+ "ovis": load_ovis,
"phi3_v": load_phi3v,
"phi4_mm": load_phi4mm,
"pixtral_hf": load_pixtral_hf,
diff --git a/tests/conftest.py b/tests/conftest.py
index fa979f1093be..c5700179c228 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -355,10 +355,16 @@ class HfRunner:
**model_kwargs,
)
+        # in case some unquantized custom models are not in the same dtype
+ if (getattr(model, "quantization_method", None) is None
+ and any(p.dtype != self.dtype
+ for p in model.parameters())):
+ model = model.to(dtype=self.dtype)
+
if (getattr(model, "quantization_method", None) != "bitsandbytes"
and len({p.device
for p in model.parameters()}) < 2):
- model = model.to(self.device)
+ model = model.to(device=self.device)
self.model = model
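
The new dtype check is independent of the Ovis rename; a small sketch of its intent (hypothetical toy model, not the runner itself): an unquantized model whose custom modules were initialized in a different dtype is cast back to the runner's dtype before device placement.

```python
# Illustrative only: mimic a custom HF model with mixed parameter dtypes.
import torch

target_dtype = torch.float16
model = torch.nn.Sequential(torch.nn.Linear(4, 4),                  # fp32 by default
                            torch.nn.Linear(4, 4).to(torch.float16))
if any(p.dtype != target_dtype for p in model.parameters()):
    model = model.to(dtype=target_dtype)
assert all(p.dtype == target_dtype for p in model.parameters())
```
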
diff --git a/tests/models/multimodal/generation/test_common.py b/tests/models/multimodal/generation/test_common.py
index 6e915a9f6005..dead2edc4fa3 100644
--- a/tests/models/multimodal/generation/test_common.py
+++ b/tests/models/multimodal/generation/test_common.py
@@ -476,6 +476,31 @@ VLM_TEST_SETTINGS = {
max_num_seqs=2,
patch_hf_runner=model_utils.molmo_patch_hf_runner,
),
+ "ovis1_6-gemma2": VLMTestInfo(
+ models=["AIDC-AI/Ovis1.6-Gemma2-9B"],
+ test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
+        prompt_formatter=lambda img_prompt: f"<bos><start_of_turn>user\n{img_prompt}<end_of_turn>\n<start_of_turn>model\n", # noqa: E501
+        img_idx_to_prompt=lambda idx: "<image>\n", # noqa: E501
+ max_model_len=4096,
+ max_num_seqs=2,
+ dtype="half",
+        # use sdpa mode for hf runner since ovis didn't work with flash_attn
+ hf_model_kwargs={"llm_attn_implementation": "sdpa"},
+ patch_hf_runner=model_utils.ovis_patch_hf_runner,
+ marks=[large_gpu_mark(min_gb=32)],
+ ),
+ "ovis1_6": VLMTestInfo(
+ models=["AIDC-AI/Ovis1.6-Llama3.2-3B"],
+ test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
+ prompt_formatter=lambda img_prompt: f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nYou are a helpful and honest multimodal assistant.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{img_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", # noqa: E501
+        img_idx_to_prompt=lambda idx: "<image>\n", # noqa: E501
+ max_model_len=4096,
+ max_num_seqs=2,
+ dtype="half",
+        # use sdpa mode for hf runner since ovis didn't work with flash_attn
+ hf_model_kwargs={"llm_attn_implementation": "sdpa"},
+ patch_hf_runner=model_utils.ovis_patch_hf_runner,
+ ),
"ovis2": VLMTestInfo(
models=["AIDC-AI/Ovis2-1B"],
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
@@ -486,7 +511,7 @@ VLM_TEST_SETTINGS = {
dtype="half",
# use sdpa mode for hf runner since ovis2 didn't work with flash_attn
hf_model_kwargs={"llm_attn_implementation": "sdpa"},
- patch_hf_runner=model_utils.ovis2_patch_hf_runner,
+ patch_hf_runner=model_utils.ovis_patch_hf_runner,
),
"phi3v": VLMTestInfo(
models=["microsoft/Phi-3.5-vision-instruct"],
diff --git a/tests/models/multimodal/generation/vlm_utils/model_utils.py b/tests/models/multimodal/generation/vlm_utils/model_utils.py
index f0f4ed989241..e31408d6063f 100644
--- a/tests/models/multimodal/generation/vlm_utils/model_utils.py
+++ b/tests/models/multimodal/generation/vlm_utils/model_utils.py
@@ -678,12 +678,8 @@ def molmo_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
return hf_model
-def ovis2_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
+def ovis_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
"""Patches and returns an instance of the HfRunner to use for Ovis2."""
- hf_model.model.visual_tokenizer.to(hf_model.dtype)
- hf_model.model.vte.to(hf_model.dtype)
- hf_model.model.llm.to(hf_model.dtype)
-
hf_model.model.get_output_embeddings = lambda: \
hf_model.model.llm.get_output_embeddings()
@@ -691,7 +687,16 @@ def ovis2_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
text_tokenizer = hf_model.model.get_text_tokenizer()
images = [images] if isinstance(images, Image) else images
- text = text.split("<|im_start|>user\n")[1].split("<|im_end|>\n")[0]
+ prompt_start_and_end = {
+ "qwen2": ("<|im_start|>user\n", "<|im_end|>\n"),
+ "llama":
+ ("<|start_header_id|>user<|end_header_id|>\n\n", "<|eot_id|>"),
+ "gemma2": ("user\n", "\n"),
+ }
+ for start, end in prompt_start_and_end.values():
+ if start in text and end in text:
+ text = text.split(start)[1].split(end)[0]
+ break
prompt, input_ids, pixel_values = hf_model.model.preprocess_inputs(
text_or_conversations=text, images=images)
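
A quick check of the marker-stripping logic added above, using a Qwen2-style formatted prompt (illustrative question text):

```python
# Illustrative only: the start/end markers recover the raw user turn from a
# fully formatted prompt.
text = ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
        "<|im_start|>user\n<image>\nWhat is shown here?<|im_end|>\n"
        "<|im_start|>assistant\n")
start, end = "<|im_start|>user\n", "<|im_end|>\n"
if start in text and end in text:
    text = text.split(start)[1].split(end)[0]
assert text == "<image>\nWhat is shown here?"
```
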
diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py
index 772a2db3e48a..e6b70a4438e9 100644
--- a/tests/models/multimodal/processing/test_common.py
+++ b/tests/models/multimodal/processing/test_common.py
@@ -146,7 +146,8 @@ def _test_processing_correctness_hf(
batch_idx: int,
ignore_mm_keys: Optional[set[str]] = None,
):
- if model_config.hf_config.model_type in ("mllama", "whisper", "ultravox"):
+ if model_config.hf_config.model_type in ("mllama", "ovis", "ultravox",
+ "whisper"):
# For some multimodal models, tokenizer will always add bos_token
# at the beginning of prompt by default, causing hf_processor outputs
# incorrect token ids. So we need use `add_special_tokens=False` here
@@ -274,6 +275,8 @@ def _test_processing_correctness_mistral(
"allenai/Molmo-7B-D-0924",
"allenai/Molmo-7B-O-0924",
"nvidia/NVLM-D-72B",
+ "AIDC-AI/Ovis1.6-Gemma2-9B",
+ "AIDC-AI/Ovis1.6-Llama3.2-3B",
"AIDC-AI/Ovis2-1B",
"google/paligemma-3b-mix-224",
"google/paligemma2-3b-ft-docci-448",
diff --git a/tests/models/registry.py b/tests/models/registry.py
index a1f2edac02b9..683d15d508ec 100644
--- a/tests/models/registry.py
+++ b/tests/models/registry.py
@@ -355,9 +355,9 @@ _MULTIMODAL_EXAMPLE_MODELS = {
max_transformers_version="4.48",
transformers_version_reason="Use of deprecated imports which have been removed.", # noqa: E501
extras={"phi3.5": "microsoft/Phi-3.5-vision-instruct"}), # noqa: E501
- "Ovis2ForConditionalGeneration": _HfExamplesInfo("AIDC-AI/Ovis2-1B",
- trust_remote_code=True,
- hf_overrides={"architectures": ["Ovis2ForConditionalGeneration"]}), # noqa: E501
+ "Ovis": _HfExamplesInfo("AIDC-AI/Ovis2-1B", trust_remote_code=True,
+ extras={"1.6-llama": "AIDC-AI/Ovis1.6-Llama3.2-3B",
+ "1.6-gemma": "AIDC-AI/Ovis1.6-Gemma2-9B"}), # noqa: E501
"Phi4MMForCausalLM": _HfExamplesInfo("microsoft/Phi-4-multimodal-instruct",
trust_remote_code=True),
"PixtralForConditionalGeneration": _HfExamplesInfo("mistralai/Pixtral-12B-2409", # noqa: E501
diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py
index 38fe98572178..db43b2dd295d 100644
--- a/vllm/entrypoints/chat_utils.py
+++ b/vllm/entrypoints/chat_utils.py
@@ -512,7 +512,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
hf_config.image_token_index)
if model_type in ("aya_vision", "chameleon", "deepseek_vl_v2",
- "internvl_chat", "ovis2", "skywork_chat",
+ "internvl_chat", "ovis", "skywork_chat",
"NVLM_D", "h2ovl_chat", "idefics3", "smolvlm"):
return ""
if model_type in ("mllama", "llama4"):
diff --git a/vllm/model_executor/models/aimv2.py b/vllm/model_executor/models/aimv2.py
index 730e770dc3d6..aefd6c973755 100644
--- a/vllm/model_executor/models/aimv2.py
+++ b/vllm/model_executor/models/aimv2.py
@@ -5,129 +5,14 @@
from typing import Optional
import torch
-from torch import nn, softmax
+import torch.nn as nn
from torch.nn import functional as F
-from torch.nn.functional import gumbel_softmax, pad
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
-from vllm.transformers_utils.configs.ovis2 import (AIMv2Config,
- Aimv2VisualTokenizerConfig)
-
-IMAGE_INDICATOR_IDS = [-301, -302, -303, -304,
- -305] # kept for vocab prefixed tokens
-
-
-def st_argmax(y_soft: torch.Tensor, dim: int): # straight-through softmax
- index = y_soft.max(dim, keepdim=True)[1]
- y_hard = torch.zeros_like(
- y_soft, memory_format=torch.legacy_contiguous_format).scatter_(
- dim, index, 1.0)
- ret = y_hard - y_soft.detach() + y_soft
- return ret
-
-
-class Aimv2VisualTokenizer(torch.nn.Module):
-
- def __init__(self,
- config: Aimv2VisualTokenizerConfig,
- quant_config: Optional[QuantizationConfig] = None,
- prefix: str = "",
- **kwargs):
- super().__init__()
- self.config = config
- self.backbone = AIMv2Model(
- config=config.backbone_config, # noqa
- quant_config=quant_config,
- prefix=f"{prefix}.visual_tokenizer")
- # reserved tokens for IMAGE_INDICATORS
- head_dim = config.vocab_size - len(IMAGE_INDICATOR_IDS)
- self.head = torch.nn.Sequential(
- ReplicatedLinear(
- config.backbone_config.hidden_size * config.hidden_stride *
- config.hidden_stride,
- head_dim,
- bias=False,
- ), torch.nn.LayerNorm(head_dim))
-
- @property
- def dtype(self):
- return self.backbone.dtype
-
- @property
- def device(self):
- return self.backbone.device
-
- def tokenize(self, logits):
- if self.config.tokenize_function == 'softmax':
- tokens = softmax(logits, dim=-1)
- elif self.config.tokenize_function == 'gumbel_argmax':
- tokens = gumbel_softmax(logits, tau=self.config.tau, hard=True)
- elif self.config.tokenize_function == 'st_argmax':
- tokens = st_argmax(logits, dim=-1)
- else:
- raise ValueError(
- 'Invalid `max_type`, expected softmax or gumbel_argmax '
- f'or st_argmax, but got {self.config.tokenize_function}')
- return tokens
-
- def encode(self, pixel_values):
- features = self.backbone(pixel_values)
- if self.config.drop_cls_token:
- features = features[:, 1:, :]
-
- # merge number of `hidden_stride * hidden_stride` hidden states together
- # to reduce token sequence length
- # e.g., for hidden_stride=2, this leads to a token length reduction:
- # 1024 -> 256 for aimv2
- if self.config.hidden_stride > 1:
- # this `d` maybe different from the above `d``
- n, L, d = features.shape
- sqrt_l = int(L**0.5)
- assert sqrt_l**2 == L, (
- "The token sequence length should be a perfect square.")
- features = features.reshape(n, sqrt_l, sqrt_l, d)
- pl = (self.config.hidden_stride -
- (sqrt_l %
- self.config.hidden_stride)) % self.config.hidden_stride
- features = pad(features, (0, 0, 0, pl, 0, pl), "constant", 0)
- sqrt_l += pl
- features = features.reshape(n, sqrt_l // self.config.hidden_stride,
- self.config.hidden_stride,
- sqrt_l // self.config.hidden_stride,
- self.config.hidden_stride, d)
- # [n, sqrt_l/hs, sqrt_l/hs, hs, hs, d]
- features = features.permute(0, 1, 3, 2, 4, 5)
- # [n, sqrt_l/hs, sqrt_l/hs, hs*hs*d]
- features = features.flatten(3)
- # [n, sqrt_l/hs*sqrt_l/hs, hs*hs*d]
- features = features.reshape(
- n, -1,
- self.config.hidden_stride * self.config.hidden_stride * d)
-
- return features
-
- def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
- """[BatchSize, ImageShape] -> [BatchSize, Token, VocabSize]"""
- features = self.encode(pixel_values)
- logits, _ = self.head[0](
- features) # we spllit the sequncial here for not throwing an error
- logits = self.head[1](logits)
- tokens = self.tokenize(logits)
- # tokens' shape is [BatchSize, #Token, VocabSize-5], so padding with
- # [BatchSize, #Token, 5], after which, tokens' shape should become
- # [BatchSize, #Token, VocabSize]
- batch_size, token_len, _ = tokens.shape
- padding_tensor = torch.zeros(size=(batch_size, token_len,
- len(IMAGE_INDICATOR_IDS)),
- dtype=tokens.dtype,
- device=tokens.device,
- layout=tokens.layout,
- requires_grad=False)
- tokens = torch.cat((tokens, padding_tensor), dim=2)
- return tokens
+from vllm.transformers_utils.configs.ovis import AIMv2Config
class AIMv2SwiGLUFFN(nn.Module):
@@ -302,14 +187,6 @@ class AIMv2Model(torch.nn.Module):
quant_config=quant_config,
prefix=f"{prefix}.trunk")
- @property
- def dtype(self):
- return self.trunk.blocks[0].attn.qkv.weight.dtype
-
- @property
- def device(self):
- return self.trunk.blocks[0].attn.qkv.device
-
def forward(
self,
pixel_values: torch.Tensor,
diff --git a/vllm/model_executor/models/ovis2.py b/vllm/model_executor/models/ovis.py
similarity index 59%
rename from vllm/model_executor/models/ovis2.py
rename to vllm/model_executor/models/ovis.py
index 67cc86e7fc82..5204c751216f 100644
--- a/vllm/model_executor/models/ovis2.py
+++ b/vllm/model_executor/models/ovis.py
@@ -15,17 +15,23 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-""" PyTorch Ovis2 model."""
+""" PyTorch Ovis model."""
+import math
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
TypedDict, Union)
import torch
import torch.nn as nn
from torch import Tensor
-from transformers import BatchFeature
+from torch.nn.functional import gumbel_softmax, pad, softmax
+from transformers import BaseImageProcessor, BatchFeature
from vllm.config import VllmConfig
-from vllm.model_executor.models.aimv2 import Aimv2VisualTokenizer
+from vllm.model_executor.layers.linear import ReplicatedLinear
+from vllm.model_executor.layers.quantization.base_config import (
+ QuantizationConfig)
+from vllm.model_executor.models.aimv2 import AIMv2Model
+from vllm.model_executor.models.siglip import SiglipVisionModel
from vllm.model_executor.models.utils import (AutoWeightsLoader, flatten_bn,
init_vllm_registered_model,
maybe_prefix)
@@ -38,19 +44,160 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
-from vllm.transformers_utils.configs.ovis2 import OvisConfig
-from vllm.transformers_utils.processors.ovis2 import OvisProcessor
+from vllm.transformers_utils.configs.ovis import (BaseVisualTokenizerConfig,
+ OvisConfig)
+from vllm.transformers_utils.processors.ovis import OvisProcessor
from .interfaces import MultiModalEmbeddings, SupportsMultiModal
from .utils import merge_multimodal_embeddings
# Cannot find the following number from hf config.
IMAGE_TOKEN = ""
-IMAGE_PAD_TOKEN_ID = 151655
-NUMBER_OF_TOKEN_TO_RESERVE_FOR_SEGMENT = 256
+IMAGE_INDICATOR_IDS = [-301, -302, -303, -304, -305]
+
+IMAGE_PAD_TOKEN_MAP = {
+ "gemma2": "",
+ "llama": "<|reserved_special_token_0|>",
+ "qwen2": "<|image_pad|>",
+}
+IMAGE_PAD_TOKEN_ID_MAP = {
+ "gemma2": 7,
+ "llama": 128002,
+ "qwen2": 151655,
+}
-class Ovis2ImagePatchInputs(TypedDict):
+def st_argmax(y_soft: torch.Tensor, dim: int): # straight-through softmax
+ index = y_soft.argmax(dim, keepdim=True)
+ return torch.zeros_like(
+ y_soft,
+ memory_format=torch.legacy_contiguous_format,
+ ).scatter_(dim, index, 1.0)
+
+
+class VisualTokenizer(torch.nn.Module):
+
+ def __init__(
+ self,
+ config: BaseVisualTokenizerConfig,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
+ ):
+ super().__init__()
+ self.config = config
+ self.backbone = self._init_backbone(
+ config=config,
+ quant_config=quant_config,
+ prefix=f"{prefix}.backbone",
+ )
+ # reserved tokens for IMAGE_INDICATORS
+ head_dim = config.vocab_size - len(IMAGE_INDICATOR_IDS)
+ self.head = torch.nn.Sequential(
+ ReplicatedLinear(
+ config.backbone_config.hidden_size * config.hidden_stride *
+ config.hidden_stride,
+ head_dim,
+ bias=False,
+ return_bias=False,
+ ), torch.nn.LayerNorm(head_dim))
+
+ def _init_backbone(
+ self,
+ config: BaseVisualTokenizerConfig,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
+ ):
+ model_type = config.backbone_config.model_type
+ if model_type == "aimv2":
+ return AIMv2Model(
+ config=config.backbone_config,
+ quant_config=quant_config,
+ prefix=prefix,
+ )
+ elif model_type == "siglip_vision_model":
+ return SiglipVisionModel(
+ config=config.backbone_config,
+ quant_config=quant_config,
+ prefix=prefix,
+ )
+ raise ValueError(
+ f"Unsupported visual tokenizer model_type: {model_type}")
+
+ @property
+ def dtype(self):
+ return next(self.head.parameters()).dtype
+
+ @property
+ def device(self):
+ return next(self.head.parameters()).device
+
+ def tokenize(self, logits):
+ if self.config.tokenize_function == 'softmax':
+ tokens = softmax(logits, dim=-1)
+ elif self.config.tokenize_function == 'gumbel_argmax':
+ tokens = gumbel_softmax(logits, tau=self.config.tau, hard=True)
+ elif self.config.tokenize_function == 'st_argmax':
+ tokens = st_argmax(logits, dim=-1)
+ else:
+ raise ValueError(
+ 'Invalid `max_type`, expected softmax or gumbel_argmax '
+ f'or st_argmax, but got {self.config.tokenize_function}')
+ return tokens
+
+ def encode(self, pixel_values):
+ features = self.backbone(pixel_values)
+ if self.config.drop_cls_token:
+ features = features[:, 1:, :]
+
+        # merge each group of `hidden_stride * hidden_stride` hidden states
+ # to reduce token sequence length
+ # e.g., for hidden_stride=2, this leads to a token length reduction:
+ # 1024 -> 256 for aimv2
+ if self.config.hidden_stride > 1:
+            # this `d` may be different from the `d` above
+ n, L, d = features.shape
+ sqrt_l = int(L**0.5)
+ assert sqrt_l**2 == L, (
+ "The token sequence length should be a perfect square.")
+ features = features.reshape(n, sqrt_l, sqrt_l, d)
+ pl = (self.config.hidden_stride -
+ (sqrt_l %
+ self.config.hidden_stride)) % self.config.hidden_stride
+ features = pad(features, (0, 0, 0, pl, 0, pl), "constant", 0)
+ sqrt_l += pl
+ features = features.reshape(n, sqrt_l // self.config.hidden_stride,
+ self.config.hidden_stride,
+ sqrt_l // self.config.hidden_stride,
+ self.config.hidden_stride, d)
+ # [n, sqrt_l/hs, sqrt_l/hs, hs, hs, d]
+ features = features.permute(0, 1, 3, 2, 4, 5)
+ # [n, sqrt_l/hs, sqrt_l/hs, hs*hs*d]
+ features = features.flatten(3)
+ # [n, sqrt_l/hs*sqrt_l/hs, hs*hs*d]
+ features = features.reshape(
+ n, -1,
+ self.config.hidden_stride * self.config.hidden_stride * d)
+
+ return features
+
+ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
+ """[BatchSize, ImageShape] -> [BatchSize, Token, VocabSize]"""
+ features = self.encode(pixel_values)
+ logits = self.head(features)
+ tokens = self.tokenize(logits)
+ # tokens' shape is [BatchSize, #Token, VocabSize-5], so padding with
+ # [BatchSize, #Token, 5], after which, tokens' shape should become
+ # [BatchSize, #Token, VocabSize]
+ tokens = torch.nn.functional.pad(
+ tokens,
+ (0, len(IMAGE_INDICATOR_IDS)),
+ mode="constant",
+ value=0,
+ )
+ return tokens
+
+
+class OvisImagePatchInputs(TypedDict):
type: Literal["image_patches"]
flat_data: torch.Tensor
"""
@@ -92,31 +239,50 @@ class VisualEmbedding(torch.nn.Embedding):
return self.weight.dtype
-class Ovis2ProcessingInfo(BaseProcessingInfo):
+class OvisProcessingInfo(BaseProcessingInfo):
def get_hf_config(self):
return self.ctx.get_hf_config(OvisConfig)
def get_hf_processor(self, **kwargs):
- return self.ctx.get_hf_processor(OvisProcessor)
+ return self.ctx.get_hf_processor(
+ OvisProcessor,
+ image_pad_token=self.get_image_pad_token(),
+ image_segment_len=self.get_image_segment_len(),
+ )
- def get_image_processor(self) -> OvisProcessor:
+ def get_image_segment_len(self) -> int:
+ visual_tokenizer_config = self.get_hf_config().visual_tokenizer_config
+ image_size = visual_tokenizer_config.backbone_config.image_size
+ patch_size = visual_tokenizer_config.backbone_config.patch_size
+ hidden_stride = visual_tokenizer_config.hidden_stride
+ patch_grid_length = math.ceil(image_size / patch_size)
+ assert patch_grid_length % hidden_stride == 0, (
+ f"patch_grid_length {patch_grid_length} is not divisible by "
+ f"hidden_stride {hidden_stride}")
+        # minus 1 because the image atom token itself occupies one position
+ return (patch_grid_length // hidden_stride)**2 - 1
+
+ def get_image_pad_token(self) -> str:
+ hf_text_config = self.get_hf_config().get_text_config()
+ text_model_type = hf_text_config.model_type
+ return IMAGE_PAD_TOKEN_MAP.get(text_model_type)
+
+ def get_image_processor(self) -> BaseImageProcessor:
return self.get_hf_processor().image_processor # type: ignore
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
- return { # 32k is model token limit at the moment
- "image":
- self.get_hf_config().multimodal_max_length //
- ((9 + 1) * NUMBER_OF_TOKEN_TO_RESERVE_FOR_SEGMENT)
- }
+ return {"image": None}
def get_image_size_with_most_features(self) -> ImageSize:
- image_processor = self.get_image_processor()
- return ImageSize(width=image_processor.size['shortest_edge'] * 9 * 2,
- height=image_processor.size['shortest_edge'] * 9 * 2)
+ height, width = self.get_hf_processor().get_image_size()
+ hs = self.get_hf_config().visual_tokenizer_config.hidden_stride
+        # NOTE(Isotr0py): 9 is `max_partition` hardcoded in original code
+ # https://huggingface.co/AIDC-AI/Ovis2-1B/blob/main/modeling_ovis.py#L96
+ return ImageSize(width=width * hs * 9, height=height * hs * 9)
-class Ovis2DummyInputsBuilder(BaseDummyInputsBuilder[Ovis2ProcessingInfo]):
+class OvisDummyInputsBuilder(BaseDummyInputsBuilder[OvisProcessingInfo]):
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
num_images = mm_counts.get("image", 0)
@@ -141,7 +307,7 @@ class Ovis2DummyInputsBuilder(BaseDummyInputsBuilder[Ovis2ProcessingInfo]):
return mm_data
-class Ovis2MultiModalProcessor(BaseMultiModalProcessor[Ovis2ProcessingInfo]):
+class OvisMultiModalProcessor(BaseMultiModalProcessor[OvisProcessingInfo]):
def image_indicators_to_visual_tokens(
self,
@@ -165,9 +331,9 @@ class Ovis2MultiModalProcessor(BaseMultiModalProcessor[Ovis2ProcessingInfo]):
mm_kwargs: Mapping[str, object],
) -> BatchFeature:
if not mm_data:
- # # Avoid warning from HF logger for text-only input
- prompt_ids = self.info.get_tokenizer().encode(prompt)
- # prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids) nope
+ # Avoid warning from HF logger for text-only input
+ tokenizer = self.info.get_tokenizer()
+ prompt_ids = tokenizer.encode(prompt, add_special_tokens=False)
return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")
processed_outputs = super()._call_hf_processor(
@@ -226,10 +392,10 @@ class Ovis2MultiModalProcessor(BaseMultiModalProcessor[Ovis2ProcessingInfo]):
]
-@MULTIMODAL_REGISTRY.register_processor(Ovis2MultiModalProcessor,
- info=Ovis2ProcessingInfo,
- dummy_inputs=Ovis2DummyInputsBuilder)
-class Ovis2ForConditionalGeneration(nn.Module, SupportsMultiModal):
+@MULTIMODAL_REGISTRY.register_processor(OvisMultiModalProcessor,
+ info=OvisProcessingInfo,
+ dummy_inputs=OvisDummyInputsBuilder)
+class Ovis(nn.Module, SupportsMultiModal):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
@@ -242,24 +408,25 @@ class Ovis2ForConditionalGeneration(nn.Module, SupportsMultiModal):
prefix=maybe_prefix(prefix, "llm"),
)
- self.visual_tokenizer = Aimv2VisualTokenizer(
+ self.visual_tokenizer = VisualTokenizer(
config=config.visual_tokenizer_config,
quant_config=quant_config,
prefix=f"{prefix}.visual_tokenizer",
- image_processor_name_or_path=config.visual_tokenizer_config.
- backbone_config.name_or_path,
)
self.vte = VisualEmbedding(
self.config.visual_tokenizer_config.vocab_size,
self.config.hidden_size)
+ text_model_type = self.config.get_text_config().model_type
+ self.image_pad_token_id = IMAGE_PAD_TOKEN_ID_MAP[text_model_type]
+
# TODO(Isotr0py): PP support
# self.make_empty_intermediate_tensors = (
# self.language_model.make_empty_intermediate_tensors)
def _parse_and_validate_image_input(
- self, **kwargs: object) -> Optional[Ovis2ImagePatchInputs]:
+ self, **kwargs: object) -> Optional[OvisImagePatchInputs]:
pixel_values = kwargs.pop("pixel_values", None)
indicator_tokens = kwargs.pop("indicator_tokens", None)
@@ -275,7 +442,7 @@ class Ovis2ForConditionalGeneration(nn.Module, SupportsMultiModal):
raise ValueError("Incorrect type of indicator_tokens. "
f"Got type: {type(pixel_values)}")
- return Ovis2ImagePatchInputs(
+ return OvisImagePatchInputs(
type="image_patches",
flat_data=flatten_bn(flatten_bn(pixel_values), concat=True),
patches_per_image=[
@@ -288,7 +455,7 @@ class Ovis2ForConditionalGeneration(nn.Module, SupportsMultiModal):
raise AssertionError("This line should be unreachable.")
def _process_image_input(
- self, image_input: Ovis2ImagePatchInputs) -> MultiModalEmbeddings:
+ self, image_input: OvisImagePatchInputs) -> MultiModalEmbeddings:
image_patches_flat = image_input["flat_data"]
patches_per_image = image_input["patches_per_image"]
indicator_tokens = image_input["indicator_tokens"]
@@ -338,7 +505,7 @@ class Ovis2ForConditionalGeneration(nn.Module, SupportsMultiModal):
if multimodal_embeddings is not None:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
- [IMAGE_PAD_TOKEN_ID])
+ self.image_pad_token_id)
return inputs_embeds
def forward(
@@ -375,8 +542,7 @@ class Ovis2ForConditionalGeneration(nn.Module, SupportsMultiModal):
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
- logits = self.llm.logits_processor(self.llm.lm_head, hidden_states,
- sampling_metadata)
+ logits = self.llm.compute_logits(hidden_states, sampling_metadata)
return logits
def load_weights(self, weights: Iterable[Tuple[str,
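
A quick sanity check of `get_image_segment_len()` above. The backbone numbers used here (image_size=448, patch_size=14, hidden_stride=2) are illustrative assumptions rather than values read from a specific checkpoint, but they land on the processor's default `image_segment_len=255`:

```python
# Illustrative only: assumed backbone values, not read from any checkpoint.
import math

image_size, patch_size, hidden_stride = 448, 14, 2
patch_grid_length = math.ceil(image_size / patch_size)        # 32
assert patch_grid_length % hidden_stride == 0
segment_len = (patch_grid_length // hidden_stride) ** 2 - 1   # 16 * 16 - 1
print(segment_len)  # 255, matching the processor default image_segment_len
```
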
diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
index aef4566193c8..c5414e129dd1 100644
--- a/vllm/model_executor/models/registry.py
+++ b/vllm/model_executor/models/registry.py
@@ -195,7 +195,7 @@ _MULTIMODAL_MODELS = {
"Mistral3ForConditionalGeneration": ("mistral3", "Mistral3ForConditionalGeneration"), # noqa: E501
"MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
"NVLM_D": ("nvlm_d", "NVLM_D_Model"),
- "Ovis2ForConditionalGeneration": ("ovis2", "Ovis2ForConditionalGeneration"),
+ "Ovis": ("ovis", "Ovis"),
"PaliGemmaForConditionalGeneration": ("paligemma", "PaliGemmaForConditionalGeneration"), # noqa: E501
"Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
"PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"), # noqa: E501
diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py
index db3efafeef96..ed10c22c84f0 100644
--- a/vllm/transformers_utils/configs/__init__.py
+++ b/vllm/transformers_utils/configs/__init__.py
@@ -23,7 +23,7 @@ from vllm.transformers_utils.configs.moonvit import MoonViTConfig
from vllm.transformers_utils.configs.mpt import MPTConfig
from vllm.transformers_utils.configs.nemotron import NemotronConfig
from vllm.transformers_utils.configs.nvlm_d import NVLM_D_Config
-from vllm.transformers_utils.configs.ovis2 import OvisConfig
+from vllm.transformers_utils.configs.ovis import OvisConfig
from vllm.transformers_utils.configs.skyworkr1v import SkyworkR1VChatConfig
from vllm.transformers_utils.configs.solar import SolarConfig
from vllm.transformers_utils.configs.telechat2 import Telechat2Config
diff --git a/vllm/transformers_utils/configs/ovis2.py b/vllm/transformers_utils/configs/ovis.py
similarity index 93%
rename from vllm/transformers_utils/configs/ovis2.py
rename to vllm/transformers_utils/configs/ovis.py
index 437a16e778c2..0ec224214f06 100644
--- a/vllm/transformers_utils/configs/ovis2.py
+++ b/vllm/transformers_utils/configs/ovis.py
@@ -123,6 +123,19 @@ class Aimv2VisualTokenizerConfig(BaseVisualTokenizerConfig):
self.backbone_kwargs['num_hidden_layers'] = self.depths[0]
+class SiglipVisualTokenizerConfig(BaseVisualTokenizerConfig):
+ model_type = "siglip_visual_tokenizer"
+
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+ if self.drop_cls_token:
+ self.drop_cls_token = False
+ if self.depths:
+ assert len(self.depths) == 1
+ self.backbone_kwargs['num_hidden_layers'] = self.depths[0]
+
+
+AutoConfig.register("siglip_visual_tokenizer", SiglipVisualTokenizerConfig)
AutoConfig.register("aimv2_visual_tokenizer", Aimv2VisualTokenizerConfig)
diff --git a/vllm/transformers_utils/processors/__init__.py b/vllm/transformers_utils/processors/__init__.py
index 2e9cf3e4d90b..2bd9ab1f099b 100644
--- a/vllm/transformers_utils/processors/__init__.py
+++ b/vllm/transformers_utils/processors/__init__.py
@@ -2,6 +2,6 @@
from vllm.transformers_utils.processors.deepseek_vl2 import (
DeepseekVLV2Processor)
-from vllm.transformers_utils.processors.ovis2 import OvisProcessor
+from vllm.transformers_utils.processors.ovis import OvisProcessor
__all__ = ["DeepseekVLV2Processor", "OvisProcessor"]
diff --git a/vllm/transformers_utils/processors/ovis2.py b/vllm/transformers_utils/processors/ovis.py
similarity index 94%
rename from vllm/transformers_utils/processors/ovis2.py
rename to vllm/transformers_utils/processors/ovis.py
index a633256ec12c..48e786792cf5 100644
--- a/vllm/transformers_utils/processors/ovis2.py
+++ b/vllm/transformers_utils/processors/ovis.py
@@ -22,6 +22,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from functools import cached_property
from typing import List, Union
import PIL
@@ -32,7 +33,7 @@ from transformers.processing_utils import (ProcessingKwargs, ProcessorMixin,
Unpack)
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
-__all__ = [ 'OvisProcessor']
+__all__ = ['OvisProcessor']
IGNORE_ID = -100
class OvisProcessorKwargs(ProcessingKwargs, total=False): # type: ignore[call-arg]
@@ -64,18 +65,29 @@ class OvisProcessor(ProcessorMixin):
"""
attributes = ["image_processor", "tokenizer"]
- valid_kwargs = ["chat_template"]
+    valid_kwargs = ["chat_template", "image_pad_token", "image_segment_len"]
image_processor_class = "AutoImageProcessor"
- tokenizer_class = "Qwen2Tokenizer"
+ tokenizer_class = "AutoTokenizer"
- def __init__(self, image_processor=None, tokenizer=None, chat_template=None, image_pad_token=None, **kwargs):
+ def __init__(
+ self,
+ image_processor=None,
+ tokenizer=None,
+ chat_template=None,
+ image_pad_token=None,
+ image_segment_len=255,
+ **kwargs,
+ ):
self.image_token = ""
- self.image_pad_token = "<|image_pad|>" if image_pad_token is None else image_pad_token
+ self.image_pad_token = image_pad_token
+ self.image_segment_len = image_segment_len
super().__init__(image_processor, tokenizer, chat_template=chat_template)
- self.image_pad_token_id = self.tokenizer.get_vocab()[self.image_pad_token]
- self.extra_special_tokens = {
+ @cached_property
+ def extra_special_tokens(self):
+ image_pad_token_id = self.tokenizer.get_vocab()[self.image_pad_token]
+ extra_special_tokens = {
"image_token": -200,
"image_atom": -300,
"image_start": -301,
@@ -83,8 +95,9 @@ class OvisProcessor(ProcessorMixin):
"image_col_sep": -303,
"image_row_sep": -304,
"image_end": -305,
- 'image_pad': self.image_pad_token_id,
+ 'image_pad': image_pad_token_id,
}
+ return extra_special_tokens
def __call__(
self,
@@ -224,8 +237,14 @@ class OvisProcessor(ProcessorMixin):
return torch.tensor(batch_token_ids, dtype=torch.long)
def get_image_size(self):
- height = self.image_processor.crop_size["height"]
- width = self.image_processor.crop_size["width"]
+ size = self.image_processor.size
+ if 'shortest_edge' in size:
+ width = height = size['shortest_edge']
+ elif "height" in size and "width" in size:
+ width = size['width']
+ height = size['height']
+ else:
+            raise ValueError("Can't parse image size from image_processor config.")
return height, width
def get_token_value(self, tok):
@@ -259,8 +278,7 @@ class OvisProcessor(ProcessorMixin):
for token in image_placeholders:
padded_placeholder_tokens.append(image_padding_token_id)
if token == image_atom_token_id:
- # Add 255 padding tokens after each image atom token
- padded_placeholder_tokens.extend([image_padding_token_id] * 255)
+ padded_placeholder_tokens.extend([image_padding_token_id] * self.image_segment_len)
return padded_placeholder_tokens
def preprocess_image(self, image: PIL.Image.Image, max_partition, covering_threshold, convert_to_rgb, return_tensors):
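
Finally, a small sketch of the placeholder padding loop changed above: every placeholder token becomes one pad token, and each image atom token is followed by `image_segment_len` extra pads. The indicator ids come from the model file; the pad id below assumes a Qwen2-based checkpoint and the segment length is shrunk for readability.

```python
# Illustrative only: a toy image_segment_len; the real default is 255.
image_atom_token_id = -300          # internal "image_atom" indicator id
image_padding_token_id = 151655     # qwen2 <|image_pad|> id (assumed)
image_segment_len = 3
image_placeholders = [-301, -300, -303, -300, -305]

padded = []
for token in image_placeholders:
    padded.append(image_padding_token_id)
    if token == image_atom_token_id:
        padded.extend([image_padding_token_id] * image_segment_len)
print(len(padded))  # 5 + 2 * 3 = 11 pad tokens for two image atoms
```
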