[Model] support new model ovis2.5 (#23084)

Signed-off-by: myselvess <244285088@qq.com>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Isotr0py <2037008807@qq.com>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
myselvess 2025-08-19 21:12:59 +08:00 committed by GitHub
parent f856c33ce9
commit b87cb97a53
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 1787 additions and 1 deletions

View File

@ -641,6 +641,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
| `MolmoForCausalLM` | Molmo | T + I<sup>+</sup> | `allenai/Molmo-7B-D-0924`, `allenai/Molmo-7B-O-0924`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `NVLM_D_Model` | NVLM-D 1.0 | T + I<sup>+</sup> | `nvidia/NVLM-D-72B`, etc. | | ✅︎ | ✅︎ |
| `Ovis` | Ovis2, Ovis1.6 | T + I<sup>+</sup> | `AIDC-AI/Ovis2-1B`, `AIDC-AI/Ovis1.6-Llama3.2-3B`, etc. | | ✅︎ | ✅︎ |
| `Ovis2_5` | Ovis2.5 | T + I<sup>+</sup> + V | `AIDC-AI/Ovis2.5-9B`, etc. | | | ✅︎ |
| `PaliGemmaForConditionalGeneration` | PaliGemma, PaliGemma 2 | T + I<sup>E</sup> | `google/paligemma-3b-pt-224`, `google/paligemma-3b-mix-224`, `google/paligemma2-3b-ft-docci-448`, etc. | | ✅︎ | ⚠️ |
| `Phi3VForCausalLM` | Phi-3-Vision, Phi-3.5-Vision | T + I<sup>E+</sup> | `microsoft/Phi-3-vision-128k-instruct`, `microsoft/Phi-3.5-vision-instruct`, etc. | | ✅︎ | ✅︎ |
| `Phi4MMForCausalLM` | Phi-4-multimodal | T + I<sup>+</sup> / T + A<sup>+</sup> / I<sup>+</sup> + A<sup>+</sup> | `microsoft/Phi-4-multimodal-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |

View File

@ -1105,6 +1105,38 @@ def run_ovis(questions: list[str], modality: str) -> ModelRequestData:
)
# Ovis2_5
def run_ovis2_5(questions: list[str], modality: str) -> ModelRequestData:
model_name = "AIDC-AI/Ovis2.5-2B"
engine_args = EngineArgs(
model=model_name,
max_model_len=4096,
max_num_seqs=2,
trust_remote_code=True,
dtype="half",
limit_mm_per_prompt={modality: 1},
)
if modality == "image":
placeholder = "<image>"
elif modality == "video":
placeholder = "<video>"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
messages = [
[{"role": "user", "content": f"{placeholder}\n{question}"}]
for question in questions
]
prompts = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
return ModelRequestData(
engine_args=engine_args,
prompts=prompts,
)
# PaliGemma
def run_paligemma(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image"
@ -1579,6 +1611,7 @@ model_example_map = {
"nemotron_vl": run_nemotron_vl,
"NVLM_D": run_nvlm_d,
"ovis": run_ovis,
"ovis2_5": run_ovis2_5,
"paligemma": run_paligemma,
"paligemma2": run_paligemma2,
"phi3_v": run_phi3v,

View File

@ -680,6 +680,36 @@ def load_ovis(question: str, image_urls: list[str]) -> ModelRequestData:
)
# ovis2_5
def load_ovis2_5(question: str, image_urls: list[str]) -> ModelRequestData:
model_name = "AIDC-AI/Ovis2.5-2B"
engine_args = EngineArgs(
model=model_name,
max_model_len=8192,
max_num_seqs=2,
trust_remote_code=True,
dtype="half",
limit_mm_per_prompt={"image": len(image_urls)},
)
placeholders = "\n".join(
f"Image-{i}: <image>\n" for i, _ in enumerate(image_urls, start=1)
)
messages = [{"role": "user", "content": f"{placeholders}\n{question}"}]
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
return ModelRequestData(
engine_args=engine_args,
prompt=prompt,
image_data=[fetch_image(url) for url in image_urls],
)
def load_pixtral_hf(question: str, image_urls: list[str]) -> ModelRequestData:
model_name = "mistral-community/pixtral-12b"
@ -1155,6 +1185,7 @@ model_example_map = {
"mllama": load_mllama,
"NVLM_D": load_nvlm_d,
"ovis": load_ovis,
"ovis2_5": load_ovis2_5,
"phi3_v": load_phi3v,
"phi4_mm": load_phi4mm,
"phi4_multimodal": load_phi4_multimodal,

View File

@ -11,6 +11,7 @@ from pathlib import PosixPath
import pytest
from transformers import (AutoModel, AutoModelForImageTextToText,
AutoModelForTextToWaveform, AutoModelForVision2Seq)
from transformers.utils import is_flash_attn_2_available
from vllm.platforms import current_platform
from vllm.utils import identity
@ -621,6 +622,26 @@ VLM_TEST_SETTINGS = {
hf_model_kwargs={"llm_attn_implementation": "sdpa"},
patch_hf_runner=model_utils.ovis_patch_hf_runner,
),
"ovis2_5": VLMTestInfo(
models=["AIDC-AI/Ovis2.5-2B"],
test_type=(
VLMTestType.IMAGE,
VLMTestType.MULTI_IMAGE,
VLMTestType.VIDEO
),
prompt_formatter=lambda img_prompt: f"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501
img_idx_to_prompt=lambda idx: "<image>\n", # noqa: E501
video_idx_to_prompt=lambda idx: "<video>\n",
max_model_len=4096,
max_num_seqs=2,
dtype="half",
num_logprobs=10,
patch_hf_runner=model_utils.ovis2_5_patch_hf_runner,
marks=[pytest.mark.skipif(
not is_flash_attn_2_available(),
reason="HF model needs `flash_attn` installed"
)],
),
"phi3v": VLMTestInfo(
models=["microsoft/Phi-3.5-vision-instruct"],
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),

View File

@ -10,6 +10,7 @@ from typing import Optional, Union
import numpy as np
import numpy.typing as npt
import PIL.Image
import pytest
import regex as re
import torch
@ -810,6 +811,63 @@ def ovis_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
return hf_model
def ovis2_5_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
"""Patches and returns an instance of the HfRunner to use for Ovis2."""
hf_model.model.get_output_embeddings = lambda: \
hf_model.model.llm.get_output_embeddings()
def processor(*args, text="", images=None, videos=None, **kwargs):
if images is None:
images = []
else:
images = [images] if isinstance(images, Image) else images
if videos is None:
videos = []
else:
videos = [videos] if isinstance(videos, np.ndarray) else videos
videos = [[PIL.Image.fromarray(frame) for frame in vid]
for vid in videos]
prompt_start_and_end = {
"qwen2": ("<|im_start|>user\n", "<|im_end|>\n"),
"llama":
("<|start_header_id|>user<|end_header_id|>\n\n", "<|eot_id|>"),
"gemma2": ("<start_of_turn>user\n", "<end_of_turn>\n"),
}
for start, end in prompt_start_and_end.values():
if start in text and end in text:
text = text.split(start)[1].split(end)[0]
break
images_message = [{"type": "image", "image": img} for img in images]
videos_message = [{"type": "video", "video": vid} for vid in videos]
messages = [{
"role":
"user",
"content": [
*images_message,
*videos_message,
{
"type": "text",
"text": text
},
],
}]
input_ids, pixel_values, grid_thws = hf_model.model.preprocess_inputs(
messages=messages, enable_thinking=True)
inputs = {
"inputs": input_ids,
"pixel_values": pixel_values,
"grid_thws": grid_thws,
}
return BatchFeature(data=inputs, tensor_type="pt")
hf_model.processor = processor
return hf_model
def qwen2_5_omni_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
"""Patches and returns an instance of the HfRunner for Qwen2.5-Omni."""
thinker = hf_model.model.thinker

View File

@ -162,6 +162,7 @@ def _test_processing_correctness(
_ADD_SPECIAL_TOKENS_OVERRIDES = {
"mllama": False,
"ovis": False,
"ovis2_5": False,
"paligemma": False,
"ultravox": False,
"whisper": False,
@ -301,6 +302,7 @@ def _test_processing_correctness_one(
"AIDC-AI/Ovis1.6-Gemma2-9B",
"AIDC-AI/Ovis1.6-Llama3.2-3B",
"AIDC-AI/Ovis2-1B",
"AIDC-AI/Ovis2.5-2B",
"google/paligemma-3b-mix-224",
"google/paligemma2-3b-ft-docci-448",
"microsoft/Phi-3.5-vision-instruct",

View File

@ -464,6 +464,9 @@ _MULTIMODAL_EXAMPLE_MODELS = {
transformers_version_reason="HF model is not compatible", # noqa: E501
extras={"1.6-llama": "AIDC-AI/Ovis1.6-Llama3.2-3B",
"1.6-gemma": "AIDC-AI/Ovis1.6-Gemma2-9B"}), # noqa: E501
"Ovis2_5": _HfExamplesInfo("AIDC-AI/Ovis2.5-2B", trust_remote_code=True,
max_transformers_version="4.53",
transformers_version_reason="HF model is not compatible"), # noqa: E501
"PaliGemmaForConditionalGeneration": _HfExamplesInfo("google/paligemma-3b-mix-224", # noqa: E501
extras={"v2": "google/paligemma2-3b-ft-docci-448"}), # noqa: E501
"Phi3VForCausalLM": _HfExamplesInfo("microsoft/Phi-3-vision-128k-instruct",

View File

@ -0,0 +1,570 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
""" PyTorch Ovis model."""
from collections.abc import Iterable, Mapping
from functools import partial
from typing import Optional, Union
import torch
import torch.nn as nn
from transformers import BaseImageProcessor, BatchFeature, PretrainedConfig
from vllm.config import VllmConfig
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.models.ovis import (OvisImagePatchInputs,
VisualEmbedding)
from vllm.model_executor.models.siglip2navit import Siglip2NavitModel
from vllm.model_executor.models.utils import (AutoWeightsLoader, flatten_bn,
init_vllm_registered_model,
maybe_prefix)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargsItems)
from vllm.multimodal.parse import ImageSize, MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.processors.ovis2_5 import Ovis2_5Processor
from .interfaces import MultiModalEmbeddings, SupportsMultiModal
IMAGE_TOKEN = "<image>"
VIDEO_TOKEN = "<video>"
INDICATOR_IDS = [-301, -302, -303, -304]
IMAGE_PAD_TOKEN_MAP = {
"gemma2": "<unused0>",
"llama": "<|reserved_special_token_0|>",
"qwen2": "<|image_pad|>",
"qwen3": "<|image_pad|>",
}
IMAGE_PAD_TOKEN_ID_MAP = {
"gemma2": 7,
"llama": 128002,
"qwen2": 151655,
"qwen3": 151655,
}
def _ovis2_5_field_config():
return dict(pixel_values=MultiModalFieldConfig.batched("image"),
grids=MultiModalFieldConfig.batched("image"),
indicator_tokens=MultiModalFieldConfig.batched("image"),
video_pixel_values=MultiModalFieldConfig.batched("video"),
video_indicator_tokens=MultiModalFieldConfig.batched("video"),
video_grids=MultiModalFieldConfig.batched("video"))
class VisualTokenizer(torch.nn.Module):
"""
VIT
"""
def __init__(
self,
config: PretrainedConfig,
visual_vocab_size: int,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.config = config
self.vit = self._init_backbone(
config=config,
quant_config=quant_config,
prefix=f"{prefix}.vit",
)
# reserved tokens for INDICATOR_IDS
head_dim = visual_vocab_size - len(INDICATOR_IDS)
self.head = torch.nn.Sequential(
ReplicatedLinear(
self.config.hidden_size * self.config.hidden_stride**2,
head_dim,
bias=False,
return_bias=False,
), torch.nn.LayerNorm(head_dim))
def _init_backbone(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
model_type = config.model_type
if model_type == "siglip2_navit":
return Siglip2NavitModel(config=config, )
raise ValueError(
f"Unsupported visual tokenizer model_type: {model_type}")
@property
def dtype(self):
return next(self.head.parameters()).dtype
@property
def device(self):
return next(self.head.parameters()).device
def tokenize(self, logits):
tokens = torch.softmax(logits, dim=-1,
dtype=torch.float32).to(logits.dtype)
return tokens
def encode(self, pixel_values, grid_thws):
features = self.vit(pixel_values,
grid_thws,
output_hidden_states=True,
return_dict=True)
# refer to qwen2.5-vl patchmerger
seq_len, _ = features.shape
features = features.reshape(seq_len // (self.config.hidden_stride**2),
-1)
return features
def forward(self, pixel_values, grid_thws) -> torch.Tensor:
features = self.encode(pixel_values, grid_thws)
logits = self.head(features)
tokens = self.tokenize(logits)
# tokens' shape is [#Token, VocabSize-4],
# so padding with [#Token, 4], after which,
# tokens' shape should become [#Token, VocabSize];
tokens = torch.nn.functional.pad(
tokens,
(0, len(INDICATOR_IDS)),
mode="constant",
value=0,
)
return tokens
class Ovis2_5ProcessingInfo(BaseProcessingInfo):
def get_hf_config(self):
return self.ctx.get_hf_config()
def get_hf_processor(self, **kwargs):
vit_config = self.get_hf_config().vit_config
return self.ctx.get_hf_processor(
Ovis2_5Processor,
image_pad_token=self.get_image_pad_token(),
patch_size=vit_config.patch_size,
hidden_stride=vit_config.hidden_stride,
temporal_patch_size=vit_config.temporal_patch_size,
)
def get_image_pad_token(self) -> str:
hf_text_config = self.get_hf_config().get_text_config()
text_model_type = hf_text_config.model_type
return IMAGE_PAD_TOKEN_MAP.get(text_model_type)
def get_image_processor(self) -> BaseImageProcessor:
return self.get_hf_processor().image_processor # type: ignore
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None, "video": 1}
def get_image_size_with_most_features(self) -> ImageSize:
# NOTE(myselvess): max_pixels 1792 * 1792 hardcoded in original code
# TODO(myselvess): Be adjusted based on the max_pixels
return ImageSize(width=1792, height=1792)
def get_num_image_tokens(
self,
*,
image_width: int,
image_height: int,
num_frames: int = 1,
) -> tuple[ImageSize, int]:
hf_config = self.get_hf_config()
vit_config = hf_config.vit_config
patch_size = vit_config.patch_size
temporal_patch_size = vit_config.temporal_patch_size
# NOTE: Frames are padded to be divisible by `temporal_patch_size`
# https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py#L294
padded_num_frames = num_frames + (-num_frames % temporal_patch_size)
grid_t = max(padded_num_frames // temporal_patch_size, 1)
grid_h = image_height // patch_size
grid_w = image_width // patch_size
num_patches = grid_t * grid_h * grid_w
num_vision_tokens = num_patches
return num_vision_tokens
def get_max_image_tokens(self) -> int:
target_width, target_height = self.get_image_size_with_most_features()
return self.get_num_image_tokens(image_width=target_width,
image_height=target_height)
def _get_max_video_frames(self, max_tokens: int) -> int:
target_width, target_height = self.get_image_size_with_most_features()
num_frames = 0
while True:
next_num_frames = num_frames + 1
next_max_tokens = self.get_num_video_tokens(
image_width=target_width,
image_height=target_height,
num_frames=next_num_frames,
image_processor=None,
)
if next_max_tokens > max_tokens:
break
num_frames = next_num_frames
return num_frames
def get_num_frames_with_most_features(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> int:
max_images = mm_counts.get("image", 0)
max_videos = mm_counts.get("video", 0)
max_image_tokens = self.get_max_image_tokens() * max_images
max_total_frames = self._get_max_video_frames(seq_len -
max_image_tokens)
max_frames_per_video = max_total_frames // max(max_videos, 1)
return max(max_frames_per_video, 1)
def get_num_video_tokens(
self,
*,
image_width: int,
image_height: int,
num_frames: int,
image_processor: Optional[BaseImageProcessor],
) -> int:
num_video_tokens = self.get_num_image_tokens(image_width=image_width,
image_height=image_height,
num_frames=num_frames)
return num_video_tokens
def get_max_video_tokens(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> int:
target_width, target_height = self.get_image_size_with_most_features()
return self.get_num_video_tokens(
image_width=target_width,
image_height=target_height,
num_frames=self.get_num_frames_with_most_features(
seq_len, mm_counts),
image_processor=None,
)
class Ovis2_5DummyInputsBuilder(BaseDummyInputsBuilder[Ovis2_5ProcessingInfo]):
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
num_images = mm_counts.get("image", 0)
num_videos = mm_counts.get("video", 0)
return IMAGE_TOKEN * num_images + VIDEO_TOKEN * num_videos
def get_dummy_mm_data(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> MultiModalDataDict:
num_images = mm_counts.get("image", 0)
num_videos = mm_counts.get("video", 0)
target_width, target_height = \
self.info.get_image_size_with_most_features()
target_num_frames = \
self.info.get_num_frames_with_most_features(seq_len, mm_counts)
mm_data = {
"image":
self._get_dummy_images(width=target_width,
height=target_height,
num_images=num_images),
"video":
self._get_dummy_videos(
width=target_width,
height=target_height,
num_frames=target_num_frames,
num_videos=num_videos,
)
}
return mm_data
class Ovis2_5MultiModalProcessor(BaseMultiModalProcessor[Ovis2_5ProcessingInfo]
):
def visual_indicators_to_visual_tokens(
self,
visual_indicators: list[int],
) -> list[int]:
"""
Filter image indicators placeholders and convert them to corresponding
tokens in visual tokenizer.
"""
hf_config = self.info.get_hf_config()
vte_vocab_size = hf_config.visual_vocab_size
return [
vte_vocab_size - len(INDICATOR_IDS) + abs(x + 300) - 1
for x in visual_indicators if x < -300
]
def _call_hf_processor(
self,
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
) -> BatchFeature:
if not mm_data:
# Avoid warning from HF logger for text-only input
tokenizer = self.info.get_tokenizer()
prompt_ids = tokenizer.encode(prompt, add_special_tokens=False)
return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")
processed_outputs = super()._call_hf_processor(
prompt=prompt,
mm_data=mm_data,
mm_kwargs=mm_kwargs,
tok_kwargs=tok_kwargs,
)
hf_processor = self.info.get_hf_processor()
if "videos" in mm_data:
visual_indicators = [
hf_processor.construct_visual_indicators((1, 1, 1), True)
for grid in processed_outputs["video_grids"]
]
indicator_tokens = [
self.visual_indicators_to_visual_tokens(indicator)
for indicator in visual_indicators
]
processed_outputs["video_indicator_tokens"] = indicator_tokens
if "images" in mm_data:
visual_indicators = [
hf_processor.construct_visual_indicators((1, 1, 1), False)
for grid in processed_outputs["grids"]
]
indicator_tokens = [
self.visual_indicators_to_visual_tokens(indicator)
for indicator in visual_indicators
]
processed_outputs["indicator_tokens"] = indicator_tokens
return processed_outputs
def _apply_hf_processor_tokens_only(
self,
prompt_tokens: list[int],
) -> list[int]:
return prompt_tokens
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return _ovis2_5_field_config()
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargsItems,
) -> list[PromptReplacement]:
def get_replacement_ovis(item_idx, modality: str):
if modality == "image":
out_item = out_mm_kwargs["image"][item_idx]
grid = out_item["grids"].data
elif modality == "video":
out_item = out_mm_kwargs["video"][item_idx]
grid = out_item["video_grids"].data
hf_processor = self.info.get_hf_processor()
return hf_processor.construct_visual_placeholders(grid[0], )
return [
PromptReplacement(
modality=modality,
target=IMAGE_TOKEN if modality == "image" else VIDEO_TOKEN,
replacement=partial(get_replacement_ovis, modality=modality),
) for modality in ("image", "video")
]
@MULTIMODAL_REGISTRY.register_processor(Ovis2_5MultiModalProcessor,
info=Ovis2_5ProcessingInfo,
dummy_inputs=Ovis2_5DummyInputsBuilder)
class Ovis2_5(nn.Module, SupportsMultiModal):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.config: PretrainedConfig = config
self.llm = init_vllm_registered_model(
vllm_config=vllm_config.with_hf_config(config.text_config),
prefix=maybe_prefix(prefix, "llm"),
)
self.visual_tokenizer = VisualTokenizer(
config=config.vit_config,
visual_vocab_size=config.visual_vocab_size,
quant_config=quant_config,
prefix=f"{prefix}.visual_tokenizer",
)
self.vte = VisualEmbedding(config.visual_vocab_size,
config.hidden_size)
text_model_type = self.config.get_text_config().model_type
self.image_pad_token_id = IMAGE_PAD_TOKEN_ID_MAP[text_model_type]
# TODO(Isotr0py): PP support
# self.make_empty_intermediate_tensors = (
# self.language_model.make_empty_intermediate_tensors)
def _parse_and_validate_visual_input(
self, is_video,
**kwargs: object) -> Optional[OvisImagePatchInputs]:
if is_video:
pixel_values = kwargs.pop("video_pixel_values", None)
indicator_tokens = kwargs.pop("video_indicator_tokens", None)
grids = kwargs.pop("video_grids", None)
else:
pixel_values = kwargs.pop("pixel_values", None)
indicator_tokens = kwargs.pop("indicator_tokens", None)
grids = kwargs.pop("grids", None)
if pixel_values is None and indicator_tokens is None:
return None
if pixel_values is not None and indicator_tokens is not None:
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
if not isinstance(indicator_tokens, (torch.Tensor, list)):
raise ValueError("Incorrect type of indicator_tokens. "
f"Got type: {type(indicator_tokens)}")
return OvisImagePatchInputs(
type="image_patches",
flat_data=flatten_bn(flatten_bn(pixel_values), concat=True),
patches_per_image=[
x.shape[0] // (self.config.vit_config.hidden_stride**2)
for x in flatten_bn(pixel_values)
],
indicator_tokens=flatten_bn(flatten_bn(indicator_tokens),
concat=True),
grids=flatten_bn(flatten_bn(grids), concat=True),
)
raise AssertionError("This line should be unreachable.")
def _process_image_input(
self, image_input: OvisImagePatchInputs) -> MultiModalEmbeddings:
image_patches_flat = image_input["flat_data"]
patches_per_image = image_input["patches_per_image"]
indicator_tokens = image_input["indicator_tokens"]
grid_thws = image_input["grids"]
indicator_per_image = list(
map(lambda x: 2 if x > 1 else x + 2, patches_per_image))
target_dtype = self.visual_tokenizer.dtype
visual_tokens = self.visual_tokenizer(
image_patches_flat.to(target_dtype), grid_thws)
visual_embeds = self.vte(visual_tokens) # 1:1 numeric eq.
indicator_embeds = self.vte(indicator_tokens)
visual_embeds_per_image = visual_embeds.split(patches_per_image, dim=0)
indicator_embeds_per_image = indicator_embeds.split(
indicator_per_image)
vision_embeddings = []
for indicator, visual in zip(indicator_embeds_per_image,
visual_embeds_per_image):
vision_embeddings_per_image = []
visual = visual.unsqueeze(0)
for i in range(visual.shape[0]):
vision_embeddings_per_image.append(
torch.cat([indicator[i:i + 1], visual[i]], dim=0))
vision_embeddings_per_image.append(indicator[i + 1:])
vision_embeddings.append(
torch.cat(vision_embeddings_per_image, dim=0))
return tuple(vision_embeddings)
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
embeddings = []
# NOTE: _parse_and_validate_visual_input has side-effects and pops
# keys from kwargs. We process images first, then videos.
image_input = self._parse_and_validate_visual_input(False, **kwargs)
if image_input:
embeddings.extend(self._process_image_input(image_input))
video_input = self._parse_and_validate_visual_input(True, **kwargs)
if video_input:
embeddings.extend(self._process_image_input(video_input))
return tuple(embeddings) if embeddings else None
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.llm.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
tmp = torch.concat(multimodal_embeddings, dim=0)
inputs_embeds[input_ids == self.image_pad_token_id] = tmp
return inputs_embeds
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object,
) -> Union[torch.Tensor, IntermediateTensors]:
if intermediate_tensors is not None:
inputs_embeds = None
# NOTE: In v1, inputs_embeds is always generated at model runner, this
# condition is for v0 compatibility.
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
input_ids = None
# up until here we have a inputs_embeds 100% numerical identity
# between the OG HF Transformers implementation and ours
hidden_states = self.llm(
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.llm.compute_logits(hidden_states, sampling_metadata)
return logits
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)
def get_language_model(self) -> torch.nn.Module:
return self.llm

View File

@ -231,6 +231,7 @@ _MULTIMODAL_MODELS = {
"MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
"NVLM_D": ("nvlm_d", "NVLM_D_Model"),
"Ovis": ("ovis", "Ovis"),
"Ovis2_5": ("ovis2_5", "Ovis2_5"),
"PaliGemmaForConditionalGeneration": ("paligemma", "PaliGemmaForConditionalGeneration"), # noqa: E501
"Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
"Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"),

View File

@ -0,0 +1,607 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Implementation of SiglipVisionModel intended to be only used
within a vision language model."""
from typing import Optional, Union
import torch
from einops import rearrange, repeat
from torch import nn
from torch.nn import functional as F
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_outputs import BaseModelOutputWithNoAttention
from vllm.platforms import _Backend
from .vision import get_vit_attn_backend
class VisionRotaryEmbedding(nn.Module):
def __init__(self, dim: int, theta: float = 10000.0) -> None:
super().__init__()
inv_freq = 1.0 / (theta
**(torch.arange(0, dim, 2, dtype=torch.float) / dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
def forward(self, seqlen: int) -> torch.Tensor:
seq = torch.arange(seqlen,
device=self.inv_freq.device,
dtype=self.inv_freq.dtype)
freqs = torch.outer(seq, self.inv_freq)
return freqs
class Siglip2VisionEmbeddings(nn.Module):
def __init__(self, config: PretrainedConfig):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.patch_size = config.patch_size
self.image_size = config.image_size
self.num_patches = config.num_patches
self.preserve_original_pe = config.preserve_original_pe
self.hidden_stride = config.hidden_stride
# siglip2 naflex
if self.num_patches > 0:
self.patch_embedding = nn.Linear(
in_features=config.num_channels * self.patch_size *
self.patch_size,
out_features=self.embed_dim,
)
if self.preserve_original_pe:
self.position_embedding_size = int(self.num_patches**0.5)
self.position_embedding = nn.Embedding(self.num_patches,
self.embed_dim)
else:
self.patch_embedding = nn.Conv2d(
in_channels=config.num_channels,
out_channels=self.embed_dim,
kernel_size=self.patch_size,
stride=self.patch_size,
padding="valid",
)
if self.preserve_original_pe:
self.num_patches = (self.image_size // self.patch_size)**2
self.position_embedding_size = (self.image_size //
self.patch_size)
self.position_embedding = nn.Embedding(self.num_patches,
self.embed_dim)
def forward(self,
pixel_values: torch.FloatTensor,
grid_thws: Optional[torch.LongTensor] = None) -> torch.Tensor:
"""
Args:
pixel_values (`torch.FloatTensor`):
Pixel values of shape (
num_patches,
num_channels * temporal_patch_size * patch_size * patch_size
)
grid_thws: (`torch.LongTensor`):
grid shape (num_patches, 3)
"""
# Apply patch embeddings to already patchified pixel values
target_dtype = self.patch_embedding.weight.dtype
if isinstance(self.patch_embedding, nn.Linear):
patch_embeds = self.patch_embedding(
pixel_values.to(dtype=target_dtype))
elif isinstance(self.patch_embedding, nn.Conv2d):
pixel_values = pixel_values.view(
-1, self.config.num_channels * self.config.temporal_patch_size,
self.patch_size, self.patch_size)
patch_embeds = self.patch_embedding(
pixel_values.to(dtype=target_dtype))
patch_embeds = patch_embeds.reshape(-1, self.embed_dim)
if self.preserve_original_pe:
assert grid_thws is not None
pos_embed_new = torch.zeros_like(patch_embeds)
positional_embeddings = self.position_embedding.weight.reshape(
self.position_embedding_size, self.position_embedding_size,
-1).unsqueeze(0).permute(0, 3, 1, 2)
cnt = 0
for t, h, w in grid_thws:
volume = t * h * w
pe = F.interpolate(positional_embeddings,
size=(h, w),
mode='bicubic',
align_corners=False)
pe = pe.permute(0, 2, 3, 1).reshape(1, h * w, -1)
pe = pe[0].repeat(t, 1)
pe = pe.reshape(t, h // self.hidden_stride, self.hidden_stride,
w // self.hidden_stride, self.hidden_stride,
-1)
pe = pe.permute(0, 1, 3, 2, 4, 5).reshape(volume, -1)
pos_embed_new[cnt:cnt + volume] = pe
cnt += volume
patch_embeds = patch_embeds + pos_embed_new
return patch_embeds
# copy from flash_attn/layers/rotary.py
def rotate_half(x, interleaved=False):
if not interleaved:
x1, x2 = x.chunk(2, dim=-1)
return torch.cat((-x2, x1), dim=-1)
else:
x1, x2 = x[..., ::2], x[..., 1::2]
return rearrange(torch.stack((-x2, x1), dim=-1),
"... d two -> ... (d two)",
two=2)
def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
"""
x: (batch_size, seqlen, nheads, headdim)
cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
"""
ro_dim = cos.shape[-1] * 2
assert ro_dim <= x.shape[-1]
cos = repeat(
cos,
"... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
sin = repeat(
sin,
"... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
return torch.cat(
[
x[..., :ro_dim] * cos +
rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:]
],
dim=-1,
)
def apply_rotary_pos_emb(
q: torch.Tensor,
k: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
is_flash_attn_backend: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
cos = cos.chunk(2, dim=-1)[0].contiguous()
sin = sin.chunk(2, dim=-1)[0].contiguous()
if is_flash_attn_backend:
from flash_attn.layers.rotary import apply_rotary_emb
apply_rotary_emb_func = apply_rotary_emb
else:
apply_rotary_emb_func = apply_rotary_emb_torch
q_embed = apply_rotary_emb_func(q.float(), cos.float(),
sin.float()).type_as(q)
k_embed = apply_rotary_emb_func(k.float(), cos.float(),
sin.float()).type_as(k)
return q_embed, k_embed
class Siglip2Attention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_heads "
f"(got `embed_dim`: {self.embed_dim} and `num_heads`:"
f" {self.num_heads}).")
self.scale = self.head_dim**-0.5
self.dropout = config.attention_dropout
self.is_causal = False
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.use_rope = config.use_rope
# Detect attention implementation.
self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
if self.attn_backend not in {
_Backend.FLASH_ATTN, _Backend.TORCH_SDPA,
_Backend.ROCM_AITER_FA
}:
self.attn_backend = _Backend.TORCH_SDPA
self.is_flash_attn_backend = self.attn_backend in {
_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA
}
def forward(
self,
hidden_states: torch.Tensor,
cu_seqlens: torch.Tensor,
position_embeddings: Optional[tuple[torch.Tensor,
torch.Tensor]] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
"""Input shape: Batch x Time x Channel"""
seq_length, embed_dim = hidden_states.shape
queries = self.q_proj(hidden_states)
keys = self.k_proj(hidden_states)
values = self.v_proj(hidden_states)
queries = queries.view(seq_length, self.num_heads, self.head_dim)
keys = keys.view(seq_length, self.num_heads, self.head_dim)
values = values.view(seq_length, self.num_heads, self.head_dim)
if self.use_rope:
cos, sin = position_embeddings
queries, keys = apply_rotary_pos_emb(queries.unsqueeze(0),
keys.unsqueeze(0), cos, sin,
self.is_flash_attn_backend)
queries = queries.squeeze(0)
keys = keys.squeeze(0)
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
if self.is_flash_attn_backend:
if self.attn_backend == _Backend.ROCM_AITER_FA:
from aiter import flash_attn_varlen_func
else:
from flash_attn import flash_attn_varlen_func
attn_output = flash_attn_varlen_func(
queries, keys, values, cu_seqlens, cu_seqlens, max_seqlen,
max_seqlen).reshape(seq_length, -1)
elif self.attn_backend == _Backend.TORCH_SDPA:
# Execute attention entry by entry for speed & less VRAM.
batch_size = cu_seqlens.shape[0] - 1
outputs = []
cu = cu_seqlens.tolist()
for i in range(batch_size):
start_idx = cu[i]
end_idx = cu[i + 1]
# Each sequence is processed independently.
q_i = queries[start_idx:end_idx].unsqueeze(0)
k_i = keys[start_idx:end_idx].unsqueeze(0)
v_i = values[start_idx:end_idx].unsqueeze(0)
# (1, seq_len, num_heads, head_dim) ->
# (1, num_heads, seq_len, head_dim)
q_i, k_i, v_i = [x.transpose(1, 2) for x in (q_i, k_i, v_i)]
output_i = F.scaled_dot_product_attention(q_i,
k_i,
v_i,
dropout_p=0.0)
# (1, num_heads, seq_len, head_dim) -> (seq_len, embed_dim)
output_i = output_i.transpose(1, 2).reshape(-1, self.embed_dim)
outputs.append(output_i)
attn_output = torch.cat(outputs, dim=0)
attn_output = self.out_proj(attn_output)
return attn_output
class Siglip2MLP(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.activation_fn = ACT2FN[config.hidden_act]
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)
hidden_states = self.fc2(hidden_states)
return hidden_states
class Siglip2EncoderLayer(nn.Module):
def __init__(self, config: PretrainedConfig):
super().__init__()
self.embed_dim = config.hidden_size
self.layer_norm1 = nn.LayerNorm(self.embed_dim,
eps=config.layer_norm_eps)
self.self_attn = Siglip2Attention(config)
self.layer_norm2 = nn.LayerNorm(self.embed_dim,
eps=config.layer_norm_eps)
self.mlp = Siglip2MLP(config)
def forward(self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor,
position_embeddings: torch.Tensor) -> tuple[torch.FloatTensor]:
"""
Args:
hidden_states (`torch.FloatTensor`):
Input to the layer of shape `(batch, seq_len, embed_dim)`.
output_attentions (`bool`, *optional*, defaults to `False`):
Whether or not to return the attentions tensors of all
attention layers. See `attentions` under
returned tensors for more detail.
"""
residual = hidden_states
hidden_states = self.layer_norm1(hidden_states)
hidden_states = self.self_attn(hidden_states=hidden_states,
cu_seqlens=cu_seqlens,
position_embeddings=position_embeddings)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.layer_norm2(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
class Siglip2Encoder(nn.Module):
"""
Transformer encoder consisting of `config.num_hidden_layers`
self attention layers. Each layer is a [`Siglip2EncoderLayer`].
Args:
config: PretrainedConfig
"""
def __init__(self, config: PretrainedConfig):
super().__init__()
self.config = config
self.layers = nn.ModuleList([
Siglip2EncoderLayer(config)
for _ in range(config.num_hidden_layers)
])
self.gradient_checkpointing = False
self.rotary_pos_emb = VisionRotaryEmbedding(
config.hidden_size // config.num_attention_heads // 2)
self.patch_size = config.patch_size
self.hidden_stride = config.hidden_stride
self.window_size = config.window_size
self.spatial_merge_unit = config.hidden_stride * config.hidden_stride
if config.fullatt_block_indexes is None:
self.fullatt_block_indexes = None
else:
self.fullatt_block_indexes = [
int(i) for i in config.fullatt_block_indexes.split('|')
]
# copied from qwen2.5_vl
def rot_pos_emb(self, grid_thw):
pos_ids = []
for t, h, w in grid_thw:
hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
hpos_ids = hpos_ids.reshape(
h // self.hidden_stride,
self.hidden_stride,
w // self.hidden_stride,
self.hidden_stride,
)
hpos_ids = hpos_ids.permute(0, 2, 1, 3)
hpos_ids = hpos_ids.flatten()
wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
wpos_ids = wpos_ids.reshape(
h // self.hidden_stride,
self.hidden_stride,
w // self.hidden_stride,
self.hidden_stride,
)
wpos_ids = wpos_ids.permute(0, 2, 1, 3)
wpos_ids = wpos_ids.flatten()
pos_ids.append(
torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
pos_ids = torch.cat(pos_ids, dim=0)
max_grid_size = grid_thw[:, 1:].max()
rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
return rotary_pos_emb
def get_window_index(self, grid_thw):
window_index: list = []
cu_window_seqlens: list = [0]
window_index_id = 0
# patch (after merge) number in each window
vit_merger_window_size = (self.window_size // self.hidden_stride //
self.patch_size)
for grid_t, grid_h, grid_w in grid_thw:
llm_grid_h, llm_grid_w = (
grid_h // self.hidden_stride, # number of patch after merge
grid_w // self.hidden_stride,
)
index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(
grid_t, llm_grid_h, llm_grid_w)
pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100)
index_padded = index_padded.reshape(
grid_t,
num_windows_h,
vit_merger_window_size,
num_windows_w,
vit_merger_window_size,
)
index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
grid_t,
num_windows_h * num_windows_w,
vit_merger_window_size,
vit_merger_window_size,
)
seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
index_padded = index_padded.reshape(-1)
index_new = index_padded[index_padded != -100]
window_index.append(index_new + window_index_id)
cu_seqlens_tmp = seqlens.cumsum(
0) * self.spatial_merge_unit + cu_window_seqlens[-1]
cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
window_index = torch.cat(window_index, dim=0)
return window_index, cu_window_seqlens
# Ignore copy
def forward(
self,
inputs_embeds,
grid_thws: torch.Tensor,
output_hidden_states: bool = False,
) -> tuple[torch.Tensor, Optional[tuple[torch.Tensor, ...]]]:
r"""
Args:
inputs_embeds (`torch.FloatTensor` of shape
`(batch_size, sequence_length, hidden_size)`):
Optionally, instead of passing `input_ids` you can choose to
directly pass an embedded representation. This is useful if
you want more control over how to convert `input_ids` indices
into associated vectors than the model's internal embedding
lookup matrix.
grid_thws (`torch.LongTensor`):
grid shape (num_patches, 3)
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See
`hidden_states` under returned tensors for more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of
a plain tuple.
"""
rotary_pos_emb = self.rot_pos_emb(grid_thws)
window_index, cu_window_seqlens = self.get_window_index(grid_thws)
cu_window_seqlens = torch.tensor(
cu_window_seqlens,
device=inputs_embeds.device,
dtype=grid_thws.dtype if torch.jit.is_tracing() else torch.int32,
)
cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
seq_len, _ = inputs_embeds.size()
inputs_embeds = inputs_embeds.reshape(
seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
inputs_embeds = inputs_embeds[window_index, :, :]
inputs_embeds = inputs_embeds.reshape(seq_len, -1)
rotary_pos_emb = rotary_pos_emb.reshape(
seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
rotary_pos_emb = rotary_pos_emb[window_index, :, :]
rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
position_embeddings = (emb.cos(), emb.sin())
cu_seqlens = torch.repeat_interleave(
grid_thws[:, 1] * grid_thws[:, 2], grid_thws[:, 0]
).cumsum(
dim=0,
# Select dtype based on the following factors:
# - FA2 requires that cu_seqlens_q must have dtype int32
# - torch.onnx.export requires that cu_seqlens_q must have
# same dtype as grid_thw
# See https://github.com/huggingface/transformers/pull/34852
# for more information
dtype=grid_thws.dtype if torch.jit.is_tracing() else torch.int32,
)
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
reverse_indices = torch.argsort(window_index)
encoder_states = () if output_hidden_states else None
hidden_states = inputs_embeds
for index, block in enumerate(self.layers):
if (not self.fullatt_block_indexes
or index in self.fullatt_block_indexes):
cu_seqlens_tmp = cu_seqlens
else:
cu_seqlens_tmp = cu_window_seqlens
hidden_states = block(hidden_states, cu_seqlens_tmp,
position_embeddings)
if output_hidden_states:
hidden_states_ = hidden_states.reshape(
seq_len // self.spatial_merge_unit,
self.spatial_merge_unit, -1)
encoder_states += (hidden_states_[reverse_indices, :].reshape(
seq_len, -1), )
# tokens = self.post_trunk_norm(tokens)
hidden_states = hidden_states.reshape(
seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
hidden_states = hidden_states[reverse_indices, :].reshape(seq_len, -1)
return hidden_states, encoder_states
class Siglip2VisionTransformer(nn.Module):
def __init__(self, config: PretrainedConfig):
super().__init__()
self.config = config
embed_dim = config.hidden_size
self.embeddings = Siglip2VisionEmbeddings(config)
self.encoder = Siglip2Encoder(config)
self.post_layernorm = nn.LayerNorm(embed_dim,
eps=config.layer_norm_eps)
self._use_flash_attention_2 = \
(config._attn_implementation == "flash_attention_2")
def forward(
self,
pixel_values: torch.FloatTensor,
grid_thws: torch.LongTensor,
output_hidden_states: Optional[bool] = True,
return_dict: Optional[bool] = True,
) -> Union[
tuple[torch.Tensor],
tuple[torch.Tensor, tuple[torch.Tensor, ...]],
BaseModelOutputWithNoAttention,
]:
r"""
spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`):
Tensor containing the spatial dimensions (height, width)
of the input images.
"""
hidden_states = self.embeddings(pixel_values, grid_thws)
last_hidden_state, hidden_states = self.encoder(
hidden_states, grid_thws, output_hidden_states)
last_hidden_state = self.post_layernorm(last_hidden_state)
if not return_dict:
output = (last_hidden_state, )
output += (hidden_states, ) if output_hidden_states else ()
return output
return last_hidden_state
class Siglip2NavitModel(torch.nn.Module):
def __init__(self, config: PretrainedConfig):
super().__init__()
self.vision_model = Siglip2VisionTransformer(config)
def forward(
self,
pixel_values: torch.FloatTensor,
grid_thws: torch.LongTensor,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[
tuple[torch.Tensor],
tuple[torch.Tensor, tuple[torch.Tensor, ...]],
BaseModelOutputWithNoAttention,
]:
if output_hidden_states is None:
output_hidden_states = self.config.output_hidden_states
if return_dict is None:
return_dict = self.config.use_return_dict
return self.vision_model(
pixel_values=pixel_values,
grid_thws=grid_thws,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)

View File

@ -11,5 +11,6 @@ reasons:
from vllm.transformers_utils.processors.deepseek_vl2 import (
DeepseekVLV2Processor)
from vllm.transformers_utils.processors.ovis import OvisProcessor
from vllm.transformers_utils.processors.ovis2_5 import Ovis2_5Processor
__all__ = ["DeepseekVLV2Processor", "OvisProcessor"]
__all__ = ["DeepseekVLV2Processor", "OvisProcessor", "Ovis2_5Processor"]

View File

@ -0,0 +1,458 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from functools import cached_property
from typing import Optional, Union
import numpy as np
import PIL
import torch
from transformers import AutoProcessor, BatchFeature
from transformers.image_utils import ImageInput
from transformers.processing_utils import (ProcessingKwargs, ProcessorMixin,
Unpack)
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
__all__ = ['Ovis2_5Processor']
IMAGE_TOKEN = "<image>"
VIDEO_TOKEN = "<video>"
MIN_PIXELS = 448 * 448
MAX_PIXELS = 1792 * 1792
class Ovis2_5ProcessorKwargs(ProcessingKwargs,
total=False): # type: ignore[call-arg]
_defaults = {
"text_kwargs": {
"padding": False,
},
"images_kwargs": {
'convert_to_rgb': True,
'min_pixels': MIN_PIXELS,
'max_pixels': MAX_PIXELS,
},
"videos_kwargs": {
'convert_to_rgb': True,
'min_pixels': MIN_PIXELS,
'max_pixels': MAX_PIXELS,
}
}
class Ovis2_5Processor(ProcessorMixin):
r"""
Constructs a Ovis processor which wraps a Ovis image processor
and a Qwen2 tokenizer into a single processor.
[`OvisProcessor`] offers all the functionalities of
[`Qwen2VLImageProcessor`] and [`Qwen2TokenizerFast`].
See the [`~OvisProcessor.__call__`] and [`~OvisProcessor.decode`]
for more information.
Args:
image_processor ([`Qwen2VLImageProcessor`], *optional*):
The image processor is a required input.
tokenizer ([`Qwen2TokenizerFast`], *optional*):
The tokenizer is a required input.
chat_template (`str`, *optional*): A Jinja template which will
be used to convert lists of messages in a chat into
a tokenizable string.
"""
attributes = ["image_processor", "tokenizer"]
valid_kwargs = ["chat_template", "image_pad_token"]
image_processor_class = "AutoImageProcessor"
tokenizer_class = "AutoTokenizer"
def __init__(
self,
image_processor=None,
tokenizer=None,
chat_template=None,
image_pad_token=None,
patch_size=16,
hidden_stride=2,
temporal_patch_size=1,
**kwargs,
):
self.image_token = IMAGE_TOKEN
self.video_token = VIDEO_TOKEN
self.image_pad_token = "<|image_pad|>"
self.patch_size = patch_size
self.hidden_stride = hidden_stride
self.temporal_patch_size = temporal_patch_size
super().__init__(image_processor,
tokenizer,
chat_template=chat_template)
@cached_property
def extra_special_tokens(self):
image_pad_token_id = self.tokenizer.get_vocab()[self.image_pad_token]
extra_special_tokens = {
"image_token": -200,
"video_token": -201,
"visual_atom": -300,
"image_start": -301,
"image_end": -302,
"video_start": -303,
"video_end": -304,
'image_pad': image_pad_token_id,
}
return extra_special_tokens
def __call__(
self,
images: ImageInput = None,
videos: Union[np.ndarray, list[ImageInput]] = None,
text: Union[TextInput, PreTokenizedInput, list[TextInput],
list[PreTokenizedInput]] = None,
**kwargs: Unpack[Ovis2_5ProcessorKwargs],
) -> BatchFeature:
"""
Main method to prepare for the model one or several sequences(s)
and image(s). This method forwards the `text`and `kwargs` arguments
to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text`
is not `None` to encode the text. To prepare the vision inputs,
this method forwards the `vision_infos` and `kwrags` arguments to
Qwen2VLImageProcessor's [`~Qwen2VLImageProcessor.__call__`]
if `vision_infos` is not `None`.
Args:
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`,
`list[PIL.Image.Image]`, `list[np.ndarray]`,
`list[torch.Tensor]`):
The image or batch of images to be prepared.
Each image can be a PIL image, NumPy array or PyTorch
tensor. Both channels-first and channels-last formats
are supported.
text (`str`, `list[str]`, `list[list[str]]`):
The sequence or batch of sequences to be encoded.
Each sequence can be a string or a list of strings
(pretokenized string). If the sequences are provided as
list of strings (pretokenized), you must set
`is_split_into_words=True` (to lift the ambiguity with
a batch of sequences).
videos (`np.ndarray`, `torch.Tensor`, `list[np.ndarray]`,
`list[torch.Tensor]`):
The image or batch of videos to be prepared. Each video
can be a 4D NumPy array or PyTorch tensor, or a nested
list of 3D frames. Both channels-first and channels-last
formats are supported.
return_tensors (`str` or [`~utils.TensorType`], *optional*):
If set, will return tensors of a particular framework.
Acceptable values are:
- `'tf'`: Return TensorFlow `tf.constant` objects.
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return NumPy `np.ndarray` objects.
- `'jax'`: Return JAX `jnp.ndarray` objects.
Returns:
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
- **input_ids** -- list of token ids to be fed to a model.
Returned when `text` is not `None`.
- **attention_mask** -- list of indices specifying which tokens
should be attended to by the model (when
`return_attention_mask=True` or if *"attention_mask"*
is in `self.model_input_names` and if `text` is not `None`).
- **pixel_values** -- Pixel values to be fed to a model.
Returned when `images` is not `None`.
- **pixel_values_videos** -- Pixel values of videos to be fed to
a model. Returned when `videos` is not `None`.
- **image_grid_thw** -- list of image 3D grid in LLM. Returned
when `images` is not `None`.
- **video_grid_thw** -- list of video 3D grid in LLM. Returned
when `videos` is not `None`.
- **second_per_grid_ts** -- list of video seconds per time grid.
Returned when `videos` is not `None`.
"""
output_kwargs = self._merge_kwargs(
Ovis2_5ProcessorKwargs,
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
**kwargs,
)
# Process all images first
visual_features = {}
output = BatchFeature()
if images is not None:
processed_images = []
image_placeholders_list = []
grids = []
# Process each image
for image in images if isinstance(images, list) else [images]:
pixel_values, image_placeholders, grid = (
self.preprocess_multidata(
images=image, **output_kwargs["images_kwargs"]))
processed_images.append(pixel_values)
image_placeholders_list.append(image_placeholders)
grids.append(grid)
# assign all processed images
if processed_images:
visual_features["image_placeholders"] = image_placeholders_list
output["pixel_values"] = processed_images
output["grids"] = grids
if videos is not None:
processed_videos = []
videos_placeholders_list = []
grids = []
# Process each video
for video in videos if isinstance(videos, list) else [videos]:
pixel_values, video_placeholders, grid = (
self.preprocess_multidata(
video=video, **output_kwargs["videos_kwargs"]))
processed_videos.append(pixel_values)
videos_placeholders_list.append(video_placeholders)
grids.append(grid)
# assign all processed videos
if processed_videos:
visual_features[
"video_placeholders"] = videos_placeholders_list
output["video_pixel_values"] = processed_videos
output["video_grids"] = grids
# Process text input
if text is not None:
if not isinstance(text, list):
text = [text]
tokenized_batched_text = self._tokenize_with_visual_symbol(text)
image_token_id = self.get_token_value("image_token")
video_token_id = self.get_token_value("video_token")
replaced_ids_list = []
image_idx = 0
video_idx = 0
for ids_tensor in tokenized_batched_text:
has_image_tokens = (image_token_id in ids_tensor
and "image_placeholders" in visual_features
and image_idx < len(
visual_features["image_placeholders"]))
has_video_tokens = (video_token_id in ids_tensor
and "video_placeholders" in visual_features
and video_idx < len(
visual_features["video_placeholders"]))
if has_image_tokens or has_video_tokens:
# Convert to list for easier manipulation
ids_list = ids_tensor.tolist()
new_ids = []
# Replace placeholders
for token_id in ids_list:
if token_id == image_token_id:
new_ids.extend(
visual_features["image_placeholders"]
[image_idx])
image_idx += 1
elif token_id == video_token_id:
new_ids.extend(
visual_features["video_placeholders"]
[video_idx])
video_idx += 1
else:
new_ids.append(token_id)
# Convert back to tensor
ids_tensor = torch.tensor(new_ids, dtype=torch.long)
replaced_ids_list.append(ids_tensor)
if replaced_ids_list:
replaced_and_tokenized_ids = torch.stack(replaced_ids_list)
else:
replaced_and_tokenized_ids = torch.tensor([], dtype=torch.long)
output["input_ids"] = replaced_and_tokenized_ids
return output
# If only images were provided
return BatchFeature(data=visual_features)
def _tokenize_with_visual_symbol(self,
text_list: list[str]) -> torch.LongTensor:
batch_token_ids = []
for text in text_list:
token_ids = []
video_token_id = self.get_token_value("video_token")
image_token_id = self.get_token_value("image_token")
video_split_texts = text.split(self.video_token)
for j, video_segment in enumerate(video_split_texts):
image_split_texts = video_segment.split(self.image_token)
text_chunks = [
self.tokenizer(chunk, add_special_tokens=False).input_ids
for chunk in image_split_texts
]
segment_tokens = []
for i, chunk in enumerate(text_chunks):
segment_tokens.extend(chunk)
if i < len(text_chunks) - 1:
segment_tokens.append(image_token_id)
token_ids.extend(segment_tokens)
if j < len(video_split_texts) - 1:
token_ids.append(video_token_id)
batch_token_ids.append(token_ids)
return torch.tensor(batch_token_ids, dtype=torch.long)
# Copied from qwen2_vl
def smart_resize(self,
height: int,
width: int,
factor: int = 28,
min_pixels: int = MIN_PIXELS,
max_pixels: int = MAX_PIXELS):
"""Rescales the image so that the following conditions are met:
1. Both dimensions (height and width) are divisible by 'factor'.
2. The total number of pixels is within the range
['min_pixels', 'max_pixels'].
3. The aspect ratio of the image is maintained as closely as possible.
"""
if height < factor or width < factor:
print(f"height:{height} or width:{width} must be "
f"larger than factor:{factor}")
if height < width:
width = round(factor / height * width)
height = factor
else:
height = round(factor / width * height)
width = factor
elif max(height, width) / min(height, width) > 200:
print(f"absolute aspect ratio must be smaller than 200, "
f"got {max(height, width) / min(height, width)}")
if height > width:
height = 200 * width
else:
width = 200 * height
h_bar = round(height / factor) * factor
w_bar = round(width / factor) * factor
if h_bar * w_bar > max_pixels:
beta = math.sqrt((height * width) / max_pixels)
h_bar = math.floor(height / beta / factor) * factor
w_bar = math.floor(width / beta / factor) * factor
elif h_bar * w_bar < min_pixels:
beta = math.sqrt(min_pixels / (height * width))
h_bar = math.ceil(height * beta / factor) * factor
w_bar = math.ceil(width * beta / factor) * factor
return h_bar, w_bar
def get_token_value(self, tok):
return self.extra_special_tokens[tok]
def construct_visual_indicators(self, grid, is_video: bool = False):
if is_video:
start_token = self.get_token_value('video_start')
end_token = self.get_token_value('video_end')
else:
start_token = self.get_token_value('image_start')
end_token = self.get_token_value('image_end')
image_placeholders = [start_token, self.get_token_value('visual_atom')]
if grid[0] * grid[1] > 1:
for r in range(grid[0]):
for c in range(grid[1]):
image_placeholders.append(
self.get_token_value('visual_atom'))
image_placeholders.append(end_token)
return image_placeholders
def construct_visual_placeholders(self, grid, is_video: bool = False):
visual_placeholders = self.construct_visual_indicators((1, 1),
is_video)
image_atom_token_id = self.get_token_value('visual_atom')
# Extract the padding token ID from tokenizer
image_padding_token_id = self.get_token_value('image_pad')
num_image_atoms = grid[0] * grid[1] * grid[2]
num_image_atoms //= self.hidden_stride**2
num_image_atoms //= self.temporal_patch_size
# Create a new list with padding tokens inserted
padded_placeholder_tokens = []
for token in visual_placeholders:
if token == image_atom_token_id:
padded_placeholder_tokens.extend([image_padding_token_id] *
num_image_atoms)
else:
padded_placeholder_tokens.append(image_padding_token_id)
return padded_placeholder_tokens
def preprocess_multidata(
self,
images: Optional[Union[PIL.Image.Image, list[PIL.Image.Image]]] = None,
video: Optional[Union[list[PIL.Image.Image], np.ndarray]] = None,
convert_to_rgb: Optional[bool] = True,
min_pixels: int = MIN_PIXELS,
max_pixels: int = MAX_PIXELS,
return_tensors: Optional[str] = 'pt',
):
is_video = False
if images is not None:
if not isinstance(images, list):
images = [images]
elif video is not None:
is_video = True
# type of vidoe in dummy_mm_data is np.ndarray
if isinstance(video, np.ndarray):
images = []
for i in range(video.shape[0]):
image = PIL.Image.fromarray(video[i].astype(np.uint8))
images.append(image)
elif isinstance(video, list):
images = video
min_pixels = min(max_pixels if max_pixels is not None else MAX_PIXELS,
min_pixels if min_pixels is not None else MIN_PIXELS)
images = [
image.convert("RGB")
if convert_to_rgb and image.mode != 'RGB' else image
for image in images
]
width, height = images[0].size
resized_height, resized_width = height, width
processed_images = []
for image in images:
resized_height, resized_width = self.smart_resize(
height,
width,
factor=self.patch_size * self.hidden_stride,
min_pixels=min_pixels,
max_pixels=max_pixels,
)
new_size = dict(height=resized_height, width=resized_width)
image_pt = self.image_processor.preprocess(
image, size=new_size, return_tensors="np")['pixel_values'][0]
processed_images.append(image_pt)
patches = np.array(processed_images)
if patches.shape[0] % self.temporal_patch_size != 0:
num_to_pad = self.temporal_patch_size - (patches.shape[0] %
self.temporal_patch_size)
repeats = np.repeat(patches[-1][np.newaxis], num_to_pad, axis=0)
patches = np.concatenate([patches, repeats], axis=0)
channel = patches.shape[1]
grid_t = patches.shape[0] // self.temporal_patch_size
grid_h = resized_height // self.patch_size
grid_w = resized_width // self.patch_size
patches = patches.reshape(
grid_t,
self.temporal_patch_size,
channel,
grid_h // self.hidden_stride,
self.hidden_stride,
self.patch_size,
grid_w // self.hidden_stride,
self.hidden_stride,
self.patch_size,
)
patches = patches.transpose(0, 3, 6, 4, 7, 2, 1, 5, 8)
flatten_patches = patches.reshape(
grid_t * grid_h * grid_w, channel * self.temporal_patch_size *
self.patch_size * self.patch_size)
visual_placeholders = self.construct_visual_placeholders(
[grid_t, grid_h, grid_w], is_video)
return torch.tensor(
flatten_patches), visual_placeholders, torch.tensor(
[[grid_t, grid_h, grid_w]])
AutoProcessor.register("Ovis2_5Processor", Ovis2_5Processor)