Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-10 03:15:20 +08:00)

[Model] support new model ovis2.5 (#23084)

Signed-off-by: myselvess <244285088@qq.com>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Isotr0py <2037008807@qq.com>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>

This commit is contained in:
parent f856c33ce9
commit b87cb97a53
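
For context, a minimal offline-inference sketch for the model this commit adds. It mirrors the `run_ovis2_5()` example introduced below: the model name, engine arguments and `<image>` placeholder are taken from that example, while the commented `LLM.generate` call only sketches vLLM's usual multimodal input convention and is an assumption, not part of the diff.

from transformers import AutoTokenizer

from vllm import LLM, SamplingParams

model_name = "AIDC-AI/Ovis2.5-2B"
llm = LLM(
    model=model_name,
    max_model_len=4096,
    max_num_seqs=2,
    trust_remote_code=True,
    dtype="half",
    limit_mm_per_prompt={"image": 1},
)

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
messages = [{"role": "user", "content": "<image>\nDescribe this image."}]
prompt = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True)

# `image` would be a PIL.Image.Image supplied by the caller.
# outputs = llm.generate(
#     {"prompt": prompt, "multi_modal_data": {"image": image}},
#     SamplingParams(max_tokens=64),
# )
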
@@ -641,6 +641,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
 | `MolmoForCausalLM` | Molmo | T + I<sup>+</sup> | `allenai/Molmo-7B-D-0924`, `allenai/Molmo-7B-O-0924`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `NVLM_D_Model` | NVLM-D 1.0 | T + I<sup>+</sup> | `nvidia/NVLM-D-72B`, etc. | | ✅︎ | ✅︎ |
 | `Ovis` | Ovis2, Ovis1.6 | T + I<sup>+</sup> | `AIDC-AI/Ovis2-1B`, `AIDC-AI/Ovis1.6-Llama3.2-3B`, etc. | | ✅︎ | ✅︎ |
+| `Ovis2_5` | Ovis2.5 | T + I<sup>+</sup> + V | `AIDC-AI/Ovis2.5-9B`, etc. | | | ✅︎ |
 | `PaliGemmaForConditionalGeneration` | PaliGemma, PaliGemma 2 | T + I<sup>E</sup> | `google/paligemma-3b-pt-224`, `google/paligemma-3b-mix-224`, `google/paligemma2-3b-ft-docci-448`, etc. | | ✅︎ | ⚠️ |
 | `Phi3VForCausalLM` | Phi-3-Vision, Phi-3.5-Vision | T + I<sup>E+</sup> | `microsoft/Phi-3-vision-128k-instruct`, `microsoft/Phi-3.5-vision-instruct`, etc. | | ✅︎ | ✅︎ |
 | `Phi4MMForCausalLM` | Phi-4-multimodal | T + I<sup>+</sup> / T + A<sup>+</sup> / I<sup>+</sup> + A<sup>+</sup> | `microsoft/Phi-4-multimodal-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |

@@ -1105,6 +1105,38 @@ def run_ovis(questions: list[str], modality: str) -> ModelRequestData:
     )


+# Ovis2_5
+def run_ovis2_5(questions: list[str], modality: str) -> ModelRequestData:
+    model_name = "AIDC-AI/Ovis2.5-2B"
+
+    engine_args = EngineArgs(
+        model=model_name,
+        max_model_len=4096,
+        max_num_seqs=2,
+        trust_remote_code=True,
+        dtype="half",
+        limit_mm_per_prompt={modality: 1},
+    )
+    if modality == "image":
+        placeholder = "<image>"
+    elif modality == "video":
+        placeholder = "<video>"
+
+    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+    messages = [
+        [{"role": "user", "content": f"{placeholder}\n{question}"}]
+        for question in questions
+    ]
+    prompts = tokenizer.apply_chat_template(
+        messages, tokenize=False, add_generation_prompt=True
+    )
+
+    return ModelRequestData(
+        engine_args=engine_args,
+        prompts=prompts,
+    )
+
+
 # PaliGemma
 def run_paligemma(questions: list[str], modality: str) -> ModelRequestData:
     assert modality == "image"

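A hedged sketch of how this example is typically consumed. The driver pattern below (build an `LLM` from the returned `ModelRequestData.engine_args`, then pass each prompt together with its image) follows the existing vision-language example script's conventions rather than anything added in this hunk, so treat the details as assumptions.

from dataclasses import asdict

from vllm import LLM, SamplingParams

req = run_ovis2_5(["What is in the picture?"], modality="image")
llm = LLM(**asdict(req.engine_args))

# `images` would be one PIL image per prompt, loaded by the caller.
# outputs = llm.generate(
#     [{"prompt": p, "multi_modal_data": {"image": img}}
#      for p, img in zip(req.prompts, images)],
#     SamplingParams(temperature=0.0, max_tokens=64),
# )
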
@@ -1579,6 +1611,7 @@ model_example_map = {
     "nemotron_vl": run_nemotron_vl,
     "NVLM_D": run_nvlm_d,
     "ovis": run_ovis,
+    "ovis2_5": run_ovis2_5,
     "paligemma": run_paligemma,
     "paligemma2": run_paligemma2,
     "phi3_v": run_phi3v,

@@ -680,6 +680,36 @@ def load_ovis(question: str, image_urls: list[str]) -> ModelRequestData:
     )


+# ovis2_5
+def load_ovis2_5(question: str, image_urls: list[str]) -> ModelRequestData:
+    model_name = "AIDC-AI/Ovis2.5-2B"
+
+    engine_args = EngineArgs(
+        model=model_name,
+        max_model_len=8192,
+        max_num_seqs=2,
+        trust_remote_code=True,
+        dtype="half",
+        limit_mm_per_prompt={"image": len(image_urls)},
+    )
+
+    placeholders = "\n".join(
+        f"Image-{i}: <image>\n" for i, _ in enumerate(image_urls, start=1)
+    )
+    messages = [{"role": "user", "content": f"{placeholders}\n{question}"}]
+
+    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+    prompt = tokenizer.apply_chat_template(
+        messages, tokenize=False, add_generation_prompt=True
+    )
+
+    return ModelRequestData(
+        engine_args=engine_args,
+        prompt=prompt,
+        image_data=[fetch_image(url) for url in image_urls],
+    )
+
+
 def load_pixtral_hf(question: str, image_urls: list[str]) -> ModelRequestData:
     model_name = "mistral-community/pixtral-12b"

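For clarity, a small standalone illustration (not part of the diff) of the multi-image placeholder block that `load_ovis2_5()` builds before applying the chat template: with two URLs the user message begins with one `Image-N: <image>` line per image.

image_urls = ["url-1", "url-2"]  # hypothetical placeholders
placeholders = "\n".join(
    f"Image-{i}: <image>\n" for i, _ in enumerate(image_urls, start=1)
)
print(placeholders)
# Image-1: <image>
#
# Image-2: <image>
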
@@ -1155,6 +1185,7 @@ model_example_map = {
     "mllama": load_mllama,
     "NVLM_D": load_nvlm_d,
     "ovis": load_ovis,
+    "ovis2_5": load_ovis2_5,
     "phi3_v": load_phi3v,
     "phi4_mm": load_phi4mm,
     "phi4_multimodal": load_phi4_multimodal,

@@ -11,6 +11,7 @@ from pathlib import PosixPath
 import pytest
 from transformers import (AutoModel, AutoModelForImageTextToText,
                           AutoModelForTextToWaveform, AutoModelForVision2Seq)
+from transformers.utils import is_flash_attn_2_available

 from vllm.platforms import current_platform
 from vllm.utils import identity

@@ -621,6 +622,26 @@ VLM_TEST_SETTINGS = {
         hf_model_kwargs={"llm_attn_implementation": "sdpa"},
         patch_hf_runner=model_utils.ovis_patch_hf_runner,
     ),
+    "ovis2_5": VLMTestInfo(
+        models=["AIDC-AI/Ovis2.5-2B"],
+        test_type=(
+            VLMTestType.IMAGE,
+            VLMTestType.MULTI_IMAGE,
+            VLMTestType.VIDEO
+        ),
+        prompt_formatter=lambda img_prompt: f"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n",  # noqa: E501
+        img_idx_to_prompt=lambda idx: "<image>\n",  # noqa: E501
+        video_idx_to_prompt=lambda idx: "<video>\n",
+        max_model_len=4096,
+        max_num_seqs=2,
+        dtype="half",
+        num_logprobs=10,
+        patch_hf_runner=model_utils.ovis2_5_patch_hf_runner,
+        marks=[pytest.mark.skipif(
+            not is_flash_attn_2_available(),
+            reason="HF model needs `flash_attn` installed"
+        )],
+    ),
     "phi3v": VLMTestInfo(
         models=["microsoft/Phi-3.5-vision-instruct"],
         test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),

@@ -10,6 +10,7 @@ from typing import Optional, Union

 import numpy as np
 import numpy.typing as npt
+import PIL.Image
 import pytest
 import regex as re
 import torch

@@ -810,6 +811,63 @@ def ovis_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
     return hf_model


+def ovis2_5_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
+    """Patches and returns an instance of the HfRunner to use for Ovis2.5."""
+    hf_model.model.get_output_embeddings = lambda: \
+        hf_model.model.llm.get_output_embeddings()
+
+    def processor(*args, text="", images=None, videos=None, **kwargs):
+        if images is None:
+            images = []
+        else:
+            images = [images] if isinstance(images, Image) else images
+        if videos is None:
+            videos = []
+        else:
+            videos = [videos] if isinstance(videos, np.ndarray) else videos
+            videos = [[PIL.Image.fromarray(frame) for frame in vid]
+                      for vid in videos]
+
+        prompt_start_and_end = {
+            "qwen2": ("<|im_start|>user\n", "<|im_end|>\n"),
+            "llama":
+            ("<|start_header_id|>user<|end_header_id|>\n\n", "<|eot_id|>"),
+            "gemma2": ("<start_of_turn>user\n", "<end_of_turn>\n"),
+        }
+        for start, end in prompt_start_and_end.values():
+            if start in text and end in text:
+                text = text.split(start)[1].split(end)[0]
+                break
+
+        images_message = [{"type": "image", "image": img} for img in images]
+        videos_message = [{"type": "video", "video": vid} for vid in videos]
+
+        messages = [{
+            "role":
+            "user",
+            "content": [
+                *images_message,
+                *videos_message,
+                {
+                    "type": "text",
+                    "text": text
+                },
+            ],
+        }]
+
+        input_ids, pixel_values, grid_thws = hf_model.model.preprocess_inputs(
+            messages=messages, enable_thinking=True)
+        inputs = {
+            "inputs": input_ids,
+            "pixel_values": pixel_values,
+            "grid_thws": grid_thws,
+        }
+        return BatchFeature(data=inputs, tensor_type="pt")
+
+    hf_model.processor = processor
+    return hf_model
+
+
 def qwen2_5_omni_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
     """Patches and returns an instance of the HfRunner for Qwen2.5-Omni."""
     thinker = hf_model.model.thinker

@@ -162,6 +162,7 @@ def _test_processing_correctness(
 _ADD_SPECIAL_TOKENS_OVERRIDES = {
     "mllama": False,
     "ovis": False,
+    "ovis2_5": False,
     "paligemma": False,
     "ultravox": False,
     "whisper": False,

@@ -301,6 +302,7 @@ def _test_processing_correctness_one(
     "AIDC-AI/Ovis1.6-Gemma2-9B",
     "AIDC-AI/Ovis1.6-Llama3.2-3B",
     "AIDC-AI/Ovis2-1B",
+    "AIDC-AI/Ovis2.5-2B",
     "google/paligemma-3b-mix-224",
     "google/paligemma2-3b-ft-docci-448",
     "microsoft/Phi-3.5-vision-instruct",

@@ -464,6 +464,9 @@ _MULTIMODAL_EXAMPLE_MODELS = {
                          transformers_version_reason="HF model is not compatible",  # noqa: E501
                          extras={"1.6-llama": "AIDC-AI/Ovis1.6-Llama3.2-3B",
                                  "1.6-gemma": "AIDC-AI/Ovis1.6-Gemma2-9B"}),  # noqa: E501
+    "Ovis2_5": _HfExamplesInfo("AIDC-AI/Ovis2.5-2B", trust_remote_code=True,
+                               max_transformers_version="4.53",
+                               transformers_version_reason="HF model is not compatible"),  # noqa: E501
     "PaliGemmaForConditionalGeneration": _HfExamplesInfo("google/paligemma-3b-mix-224",  # noqa: E501
                                                          extras={"v2": "google/paligemma2-3b-ft-docci-448"}),  # noqa: E501
     "Phi3VForCausalLM": _HfExamplesInfo("microsoft/Phi-3-vision-128k-instruct",

vllm/model_executor/models/ovis2_5.py (new file, 570 lines)
@@ -0,0 +1,570 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
""" PyTorch Ovis model."""
from collections.abc import Iterable, Mapping
from functools import partial
from typing import Optional, Union

import torch
import torch.nn as nn
from transformers import BaseImageProcessor, BatchFeature, PretrainedConfig

from vllm.config import VllmConfig
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.models.ovis import (OvisImagePatchInputs,
                                             VisualEmbedding)
from vllm.model_executor.models.siglip2navit import Siglip2NavitModel
from vllm.model_executor.models.utils import (AutoWeightsLoader, flatten_bn,
                                              init_vllm_registered_model,
                                              maybe_prefix)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                    MultiModalKwargsItems)
from vllm.multimodal.parse import ImageSize, MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                        BaseProcessingInfo, PromptReplacement)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.processors.ovis2_5 import Ovis2_5Processor

from .interfaces import MultiModalEmbeddings, SupportsMultiModal

IMAGE_TOKEN = "<image>"
VIDEO_TOKEN = "<video>"
INDICATOR_IDS = [-301, -302, -303, -304]

IMAGE_PAD_TOKEN_MAP = {
    "gemma2": "<unused0>",
    "llama": "<|reserved_special_token_0|>",
    "qwen2": "<|image_pad|>",
    "qwen3": "<|image_pad|>",
}
IMAGE_PAD_TOKEN_ID_MAP = {
    "gemma2": 7,
    "llama": 128002,
    "qwen2": 151655,
    "qwen3": 151655,
}


def _ovis2_5_field_config():
    return dict(pixel_values=MultiModalFieldConfig.batched("image"),
                grids=MultiModalFieldConfig.batched("image"),
                indicator_tokens=MultiModalFieldConfig.batched("image"),
                video_pixel_values=MultiModalFieldConfig.batched("video"),
                video_indicator_tokens=MultiModalFieldConfig.batched("video"),
                video_grids=MultiModalFieldConfig.batched("video"))


class VisualTokenizer(torch.nn.Module):
    """
    VIT
    """

    def __init__(
        self,
        config: PretrainedConfig,
        visual_vocab_size: int,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        self.vit = self._init_backbone(
            config=config,
            quant_config=quant_config,
            prefix=f"{prefix}.vit",
        )
        # reserved tokens for INDICATOR_IDS
        head_dim = visual_vocab_size - len(INDICATOR_IDS)
        self.head = torch.nn.Sequential(
            ReplicatedLinear(
                self.config.hidden_size * self.config.hidden_stride**2,
                head_dim,
                bias=False,
                return_bias=False,
            ), torch.nn.LayerNorm(head_dim))

    def _init_backbone(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        model_type = config.model_type
        if model_type == "siglip2_navit":
            return Siglip2NavitModel(config=config, )
        raise ValueError(
            f"Unsupported visual tokenizer model_type: {model_type}")

    @property
    def dtype(self):
        return next(self.head.parameters()).dtype

    @property
    def device(self):
        return next(self.head.parameters()).device

    def tokenize(self, logits):
        tokens = torch.softmax(logits, dim=-1,
                               dtype=torch.float32).to(logits.dtype)
        return tokens

    def encode(self, pixel_values, grid_thws):
        features = self.vit(pixel_values,
                            grid_thws,
                            output_hidden_states=True,
                            return_dict=True)
        # refer to qwen2.5-vl patchmerger
        seq_len, _ = features.shape
        features = features.reshape(seq_len // (self.config.hidden_stride**2),
                                    -1)

        return features

    def forward(self, pixel_values, grid_thws) -> torch.Tensor:
        features = self.encode(pixel_values, grid_thws)
        logits = self.head(features)
        tokens = self.tokenize(logits)
        # tokens' shape is [#Token, VocabSize-4],
        # so padding with [#Token, 4], after which,
        # tokens' shape should become [#Token, VocabSize];
        tokens = torch.nn.functional.pad(
            tokens,
            (0, len(INDICATOR_IDS)),
            mode="constant",
            value=0,
        )
        return tokens


class Ovis2_5ProcessingInfo(BaseProcessingInfo):

    def get_hf_config(self):
        return self.ctx.get_hf_config()

    def get_hf_processor(self, **kwargs):
        vit_config = self.get_hf_config().vit_config
        return self.ctx.get_hf_processor(
            Ovis2_5Processor,
            image_pad_token=self.get_image_pad_token(),
            patch_size=vit_config.patch_size,
            hidden_stride=vit_config.hidden_stride,
            temporal_patch_size=vit_config.temporal_patch_size,
        )

    def get_image_pad_token(self) -> str:
        hf_text_config = self.get_hf_config().get_text_config()
        text_model_type = hf_text_config.model_type
        return IMAGE_PAD_TOKEN_MAP.get(text_model_type)

    def get_image_processor(self) -> BaseImageProcessor:
        return self.get_hf_processor().image_processor  # type: ignore

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        return {"image": None, "video": 1}

    def get_image_size_with_most_features(self) -> ImageSize:
        # NOTE(myselvess): max_pixels 1792 * 1792 hardcoded in original code
        # TODO(myselvess): Be adjusted based on the max_pixels
        return ImageSize(width=1792, height=1792)

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        num_frames: int = 1,
    ) -> tuple[ImageSize, int]:
        hf_config = self.get_hf_config()
        vit_config = hf_config.vit_config
        patch_size = vit_config.patch_size
        temporal_patch_size = vit_config.temporal_patch_size
        # NOTE: Frames are padded to be divisible by `temporal_patch_size`
        # https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py#L294
        padded_num_frames = num_frames + (-num_frames % temporal_patch_size)
        grid_t = max(padded_num_frames // temporal_patch_size, 1)
        grid_h = image_height // patch_size
        grid_w = image_width // patch_size
        num_patches = grid_t * grid_h * grid_w
        num_vision_tokens = num_patches
        return num_vision_tokens

    def get_max_image_tokens(self) -> int:
        target_width, target_height = self.get_image_size_with_most_features()
        return self.get_num_image_tokens(image_width=target_width,
                                         image_height=target_height)

    def _get_max_video_frames(self, max_tokens: int) -> int:
        target_width, target_height = self.get_image_size_with_most_features()
        num_frames = 0
        while True:
            next_num_frames = num_frames + 1
            next_max_tokens = self.get_num_video_tokens(
                image_width=target_width,
                image_height=target_height,
                num_frames=next_num_frames,
                image_processor=None,
            )
            if next_max_tokens > max_tokens:
                break
            num_frames = next_num_frames
        return num_frames

    def get_num_frames_with_most_features(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> int:
        max_images = mm_counts.get("image", 0)
        max_videos = mm_counts.get("video", 0)
        max_image_tokens = self.get_max_image_tokens() * max_images
        max_total_frames = self._get_max_video_frames(seq_len -
                                                      max_image_tokens)
        max_frames_per_video = max_total_frames // max(max_videos, 1)
        return max(max_frames_per_video, 1)

    def get_num_video_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        num_frames: int,
        image_processor: Optional[BaseImageProcessor],
    ) -> int:
        num_video_tokens = self.get_num_image_tokens(image_width=image_width,
                                                     image_height=image_height,
                                                     num_frames=num_frames)
        return num_video_tokens

    def get_max_video_tokens(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> int:
        target_width, target_height = self.get_image_size_with_most_features()
        return self.get_num_video_tokens(
            image_width=target_width,
            image_height=target_height,
            num_frames=self.get_num_frames_with_most_features(
                seq_len, mm_counts),
            image_processor=None,
        )


class Ovis2_5DummyInputsBuilder(BaseDummyInputsBuilder[Ovis2_5ProcessingInfo]):

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        num_images = mm_counts.get("image", 0)
        num_videos = mm_counts.get("video", 0)
        return IMAGE_TOKEN * num_images + VIDEO_TOKEN * num_videos

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> MultiModalDataDict:
        num_images = mm_counts.get("image", 0)
        num_videos = mm_counts.get("video", 0)

        target_width, target_height = \
            self.info.get_image_size_with_most_features()
        target_num_frames = \
            self.info.get_num_frames_with_most_features(seq_len, mm_counts)
        mm_data = {
            "image":
            self._get_dummy_images(width=target_width,
                                   height=target_height,
                                   num_images=num_images),
            "video":
            self._get_dummy_videos(
                width=target_width,
                height=target_height,
                num_frames=target_num_frames,
                num_videos=num_videos,
            )
        }
        return mm_data


class Ovis2_5MultiModalProcessor(BaseMultiModalProcessor[Ovis2_5ProcessingInfo]
                                 ):

    def visual_indicators_to_visual_tokens(
        self,
        visual_indicators: list[int],
    ) -> list[int]:
        """
        Filter image indicators placeholders and convert them to corresponding
        tokens in visual tokenizer.
        """
        hf_config = self.info.get_hf_config()
        vte_vocab_size = hf_config.visual_vocab_size
        return [
            vte_vocab_size - len(INDICATOR_IDS) + abs(x + 300) - 1
            for x in visual_indicators if x < -300
        ]

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        if not mm_data:
            # Avoid warning from HF logger for text-only input
            tokenizer = self.info.get_tokenizer()
            prompt_ids = tokenizer.encode(prompt, add_special_tokens=False)
            return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")

        processed_outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )
        hf_processor = self.info.get_hf_processor()

        if "videos" in mm_data:
            visual_indicators = [
                hf_processor.construct_visual_indicators((1, 1, 1), True)
                for grid in processed_outputs["video_grids"]
            ]
            indicator_tokens = [
                self.visual_indicators_to_visual_tokens(indicator)
                for indicator in visual_indicators
            ]
            processed_outputs["video_indicator_tokens"] = indicator_tokens
        if "images" in mm_data:
            visual_indicators = [
                hf_processor.construct_visual_indicators((1, 1, 1), False)
                for grid in processed_outputs["grids"]
            ]
            indicator_tokens = [
                self.visual_indicators_to_visual_tokens(indicator)
                for indicator in visual_indicators
            ]

            processed_outputs["indicator_tokens"] = indicator_tokens
        return processed_outputs

    def _apply_hf_processor_tokens_only(
        self,
        prompt_tokens: list[int],
    ) -> list[int]:

        return prompt_tokens

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        return _ovis2_5_field_config()

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> list[PromptReplacement]:

        def get_replacement_ovis(item_idx, modality: str):
            if modality == "image":
                out_item = out_mm_kwargs["image"][item_idx]
                grid = out_item["grids"].data
            elif modality == "video":
                out_item = out_mm_kwargs["video"][item_idx]
                grid = out_item["video_grids"].data
            hf_processor = self.info.get_hf_processor()
            return hf_processor.construct_visual_placeholders(grid[0], )

        return [
            PromptReplacement(
                modality=modality,
                target=IMAGE_TOKEN if modality == "image" else VIDEO_TOKEN,
                replacement=partial(get_replacement_ovis, modality=modality),
            ) for modality in ("image", "video")
        ]


@MULTIMODAL_REGISTRY.register_processor(Ovis2_5MultiModalProcessor,
                                        info=Ovis2_5ProcessingInfo,
                                        dummy_inputs=Ovis2_5DummyInputsBuilder)
class Ovis2_5(nn.Module, SupportsMultiModal):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config

        self.config: PretrainedConfig = config
        self.llm = init_vllm_registered_model(
            vllm_config=vllm_config.with_hf_config(config.text_config),
            prefix=maybe_prefix(prefix, "llm"),
        )

        self.visual_tokenizer = VisualTokenizer(
            config=config.vit_config,
            visual_vocab_size=config.visual_vocab_size,
            quant_config=quant_config,
            prefix=f"{prefix}.visual_tokenizer",
        )

        self.vte = VisualEmbedding(config.visual_vocab_size,
                                   config.hidden_size)

        text_model_type = self.config.get_text_config().model_type
        self.image_pad_token_id = IMAGE_PAD_TOKEN_ID_MAP[text_model_type]

        # TODO(Isotr0py): PP support
        # self.make_empty_intermediate_tensors = (
        #     self.language_model.make_empty_intermediate_tensors)

    def _parse_and_validate_visual_input(
            self, is_video,
            **kwargs: object) -> Optional[OvisImagePatchInputs]:
        if is_video:
            pixel_values = kwargs.pop("video_pixel_values", None)
            indicator_tokens = kwargs.pop("video_indicator_tokens", None)
            grids = kwargs.pop("video_grids", None)
        else:
            pixel_values = kwargs.pop("pixel_values", None)
            indicator_tokens = kwargs.pop("indicator_tokens", None)
            grids = kwargs.pop("grids", None)
        if pixel_values is None and indicator_tokens is None:
            return None

        if pixel_values is not None and indicator_tokens is not None:
            if not isinstance(pixel_values, (torch.Tensor, list)):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")

            if not isinstance(indicator_tokens, (torch.Tensor, list)):
                raise ValueError("Incorrect type of indicator_tokens. "
                                 f"Got type: {type(indicator_tokens)}")

            return OvisImagePatchInputs(
                type="image_patches",
                flat_data=flatten_bn(flatten_bn(pixel_values), concat=True),
                patches_per_image=[
                    x.shape[0] // (self.config.vit_config.hidden_stride**2)
                    for x in flatten_bn(pixel_values)
                ],
                indicator_tokens=flatten_bn(flatten_bn(indicator_tokens),
                                            concat=True),
                grids=flatten_bn(flatten_bn(grids), concat=True),
            )

        raise AssertionError("This line should be unreachable.")

    def _process_image_input(
            self, image_input: OvisImagePatchInputs) -> MultiModalEmbeddings:
        image_patches_flat = image_input["flat_data"]
        patches_per_image = image_input["patches_per_image"]
        indicator_tokens = image_input["indicator_tokens"]
        grid_thws = image_input["grids"]

        indicator_per_image = list(
            map(lambda x: 2 if x > 1 else x + 2, patches_per_image))

        target_dtype = self.visual_tokenizer.dtype
        visual_tokens = self.visual_tokenizer(
            image_patches_flat.to(target_dtype), grid_thws)

        visual_embeds = self.vte(visual_tokens)  # 1:1 numeric eq.
        indicator_embeds = self.vte(indicator_tokens)

        visual_embeds_per_image = visual_embeds.split(patches_per_image, dim=0)
        indicator_embeds_per_image = indicator_embeds.split(
            indicator_per_image)

        vision_embeddings = []
        for indicator, visual in zip(indicator_embeds_per_image,
                                     visual_embeds_per_image):
            vision_embeddings_per_image = []
            visual = visual.unsqueeze(0)
            for i in range(visual.shape[0]):
                vision_embeddings_per_image.append(
                    torch.cat([indicator[i:i + 1], visual[i]], dim=0))
            vision_embeddings_per_image.append(indicator[i + 1:])
            vision_embeddings.append(
                torch.cat(vision_embeddings_per_image, dim=0))
        return tuple(vision_embeddings)

    def get_multimodal_embeddings(
            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
        embeddings = []

        # NOTE: _parse_and_validate_visual_input has side-effects and pops
        # keys from kwargs. We process images first, then videos.
        image_input = self._parse_and_validate_visual_input(False, **kwargs)
        if image_input:
            embeddings.extend(self._process_image_input(image_input))

        video_input = self._parse_and_validate_visual_input(True, **kwargs)
        if video_input:
            embeddings.extend(self._process_image_input(video_input))

        return tuple(embeddings) if embeddings else None

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
    ) -> torch.Tensor:
        inputs_embeds = self.llm.get_input_embeddings(input_ids)
        if multimodal_embeddings is not None:
            tmp = torch.concat(multimodal_embeddings, dim=0)
            inputs_embeds[input_ids == self.image_pad_token_id] = tmp
        return inputs_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        if intermediate_tensors is not None:
            inputs_embeds = None

        # NOTE: In v1, inputs_embeds is always generated at model runner, this
        # condition is for v0 compatibility.
        elif inputs_embeds is None:
            vision_embeddings = self.get_multimodal_embeddings(**kwargs)

            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      vision_embeddings)
            input_ids = None

        # up until here we have a inputs_embeds 100% numerical identity
        # between the OG HF Transformers implementation and ours
        hidden_states = self.llm(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        logits = self.llm.compute_logits(hidden_states, sampling_metadata)
        return logits

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)

    def get_language_model(self) -> torch.nn.Module:
        return self.llm

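A standalone illustration of the indicator-token mapping used by `Ovis2_5MultiModalProcessor.visual_indicators_to_visual_tokens()` above: the four `INDICATOR_IDS` placeholders land in the last four slots of the visual vocabulary, which are the slots reserved by `VisualTokenizer`. The vocabulary size below is a hypothetical value chosen only for the example.

INDICATOR_IDS = [-301, -302, -303, -304]
visual_vocab_size = 65536  # hypothetical value for illustration


def indicators_to_tokens(indicators: list[int]) -> list[int]:
    # Same arithmetic as visual_indicators_to_visual_tokens().
    return [
        visual_vocab_size - len(INDICATOR_IDS) + abs(x + 300) - 1
        for x in indicators if x < -300
    ]


print(indicators_to_tokens(INDICATOR_IDS))  # [65532, 65533, 65534, 65535]
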
@@ -231,6 +231,7 @@ _MULTIMODAL_MODELS = {
     "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
     "NVLM_D": ("nvlm_d", "NVLM_D_Model"),
     "Ovis": ("ovis", "Ovis"),
+    "Ovis2_5": ("ovis2_5", "Ovis2_5"),
     "PaliGemmaForConditionalGeneration": ("paligemma", "PaliGemmaForConditionalGeneration"),  # noqa: E501
     "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
     "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"),

vllm/model_executor/models/siglip2navit.py (new file, 607 lines)
@@ -0,0 +1,607 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Implementation of SiglipVisionModel intended to be only used
within a vision language model."""

from typing import Optional, Union

import torch
from einops import rearrange, repeat
from torch import nn
from torch.nn import functional as F
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_outputs import BaseModelOutputWithNoAttention

from vllm.platforms import _Backend

from .vision import get_vit_attn_backend


class VisionRotaryEmbedding(nn.Module):

    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        super().__init__()
        inv_freq = 1.0 / (theta
                          **(torch.arange(0, dim, 2, dtype=torch.float) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def forward(self, seqlen: int) -> torch.Tensor:
        seq = torch.arange(seqlen,
                           device=self.inv_freq.device,
                           dtype=self.inv_freq.dtype)
        freqs = torch.outer(seq, self.inv_freq)
        return freqs


class Siglip2VisionEmbeddings(nn.Module):

    def __init__(self, config: PretrainedConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.patch_size = config.patch_size
        self.image_size = config.image_size
        self.num_patches = config.num_patches
        self.preserve_original_pe = config.preserve_original_pe
        self.hidden_stride = config.hidden_stride

        # siglip2 naflex
        if self.num_patches > 0:
            self.patch_embedding = nn.Linear(
                in_features=config.num_channels * self.patch_size *
                self.patch_size,
                out_features=self.embed_dim,
            )
            if self.preserve_original_pe:
                self.position_embedding_size = int(self.num_patches**0.5)
                self.position_embedding = nn.Embedding(self.num_patches,
                                                       self.embed_dim)

        else:
            self.patch_embedding = nn.Conv2d(
                in_channels=config.num_channels,
                out_channels=self.embed_dim,
                kernel_size=self.patch_size,
                stride=self.patch_size,
                padding="valid",
            )
            if self.preserve_original_pe:
                self.num_patches = (self.image_size // self.patch_size)**2
                self.position_embedding_size = (self.image_size //
                                                self.patch_size)
                self.position_embedding = nn.Embedding(self.num_patches,
                                                       self.embed_dim)

    def forward(self,
                pixel_values: torch.FloatTensor,
                grid_thws: Optional[torch.LongTensor] = None) -> torch.Tensor:
        """
        Args:
            pixel_values (`torch.FloatTensor`):
                Pixel values of shape (
                    num_patches,
                    num_channels * temporal_patch_size * patch_size * patch_size
                )
            grid_thws: (`torch.LongTensor`):
                grid shape (num_patches, 3)
        """

        # Apply patch embeddings to already patchified pixel values
        target_dtype = self.patch_embedding.weight.dtype
        if isinstance(self.patch_embedding, nn.Linear):
            patch_embeds = self.patch_embedding(
                pixel_values.to(dtype=target_dtype))
        elif isinstance(self.patch_embedding, nn.Conv2d):
            pixel_values = pixel_values.view(
                -1, self.config.num_channels * self.config.temporal_patch_size,
                self.patch_size, self.patch_size)
            patch_embeds = self.patch_embedding(
                pixel_values.to(dtype=target_dtype))
            patch_embeds = patch_embeds.reshape(-1, self.embed_dim)

        if self.preserve_original_pe:
            assert grid_thws is not None
            pos_embed_new = torch.zeros_like(patch_embeds)
            positional_embeddings = self.position_embedding.weight.reshape(
                self.position_embedding_size, self.position_embedding_size,
                -1).unsqueeze(0).permute(0, 3, 1, 2)
            cnt = 0
            for t, h, w in grid_thws:
                volume = t * h * w
                pe = F.interpolate(positional_embeddings,
                                   size=(h, w),
                                   mode='bicubic',
                                   align_corners=False)
                pe = pe.permute(0, 2, 3, 1).reshape(1, h * w, -1)
                pe = pe[0].repeat(t, 1)
                pe = pe.reshape(t, h // self.hidden_stride, self.hidden_stride,
                                w // self.hidden_stride, self.hidden_stride,
                                -1)
                pe = pe.permute(0, 1, 3, 2, 4, 5).reshape(volume, -1)
                pos_embed_new[cnt:cnt + volume] = pe
                cnt += volume
            patch_embeds = patch_embeds + pos_embed_new

        return patch_embeds


# copy from flash_attn/layers/rotary.py
def rotate_half(x, interleaved=False):
    if not interleaved:
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)
    else:
        x1, x2 = x[..., ::2], x[..., 1::2]
        return rearrange(torch.stack((-x2, x1), dim=-1),
                         "... d two -> ... (d two)",
                         two=2)


def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
    """
    x: (batch_size, seqlen, nheads, headdim)
    cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
    """
    ro_dim = cos.shape[-1] * 2
    assert ro_dim <= x.shape[-1]
    cos = repeat(
        cos,
        "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
    sin = repeat(
        sin,
        "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
    return torch.cat(
        [
            x[..., :ro_dim] * cos +
            rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:]
        ],
        dim=-1,
    )


def apply_rotary_pos_emb(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    is_flash_attn_backend: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    cos = cos.chunk(2, dim=-1)[0].contiguous()
    sin = sin.chunk(2, dim=-1)[0].contiguous()
    if is_flash_attn_backend:
        from flash_attn.layers.rotary import apply_rotary_emb
        apply_rotary_emb_func = apply_rotary_emb
    else:
        apply_rotary_emb_func = apply_rotary_emb_torch
    q_embed = apply_rotary_emb_func(q.float(), cos.float(),
                                    sin.float()).type_as(q)
    k_embed = apply_rotary_emb_func(k.float(), cos.float(),
                                    sin.float()).type_as(k)
    return q_embed, k_embed


class Siglip2Attention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads "
                f"(got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads}).")
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout
        self.is_causal = False

        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)

        self.use_rope = config.use_rope

        # Detect attention implementation.
        self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
        if self.attn_backend not in {
                _Backend.FLASH_ATTN, _Backend.TORCH_SDPA,
                _Backend.ROCM_AITER_FA
        }:
            self.attn_backend = _Backend.TORCH_SDPA
        self.is_flash_attn_backend = self.attn_backend in {
            _Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA
        }

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        position_embeddings: Optional[tuple[torch.Tensor,
                                            torch.Tensor]] = None,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Input shape: Batch x Time x Channel"""

        seq_length, embed_dim = hidden_states.shape

        queries = self.q_proj(hidden_states)
        keys = self.k_proj(hidden_states)
        values = self.v_proj(hidden_states)

        queries = queries.view(seq_length, self.num_heads, self.head_dim)
        keys = keys.view(seq_length, self.num_heads, self.head_dim)
        values = values.view(seq_length, self.num_heads, self.head_dim)

        if self.use_rope:
            cos, sin = position_embeddings
            queries, keys = apply_rotary_pos_emb(queries.unsqueeze(0),
                                                 keys.unsqueeze(0), cos, sin,
                                                 self.is_flash_attn_backend)
            queries = queries.squeeze(0)
            keys = keys.squeeze(0)

        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
        if self.is_flash_attn_backend:
            if self.attn_backend == _Backend.ROCM_AITER_FA:
                from aiter import flash_attn_varlen_func
            else:
                from flash_attn import flash_attn_varlen_func
            attn_output = flash_attn_varlen_func(
                queries, keys, values, cu_seqlens, cu_seqlens, max_seqlen,
                max_seqlen).reshape(seq_length, -1)
        elif self.attn_backend == _Backend.TORCH_SDPA:
            # Execute attention entry by entry for speed & less VRAM.
            batch_size = cu_seqlens.shape[0] - 1
            outputs = []
            cu = cu_seqlens.tolist()
            for i in range(batch_size):
                start_idx = cu[i]
                end_idx = cu[i + 1]

                # Each sequence is processed independently.
                q_i = queries[start_idx:end_idx].unsqueeze(0)
                k_i = keys[start_idx:end_idx].unsqueeze(0)
                v_i = values[start_idx:end_idx].unsqueeze(0)

                # (1, seq_len, num_heads, head_dim) ->
                # (1, num_heads, seq_len, head_dim)
                q_i, k_i, v_i = [x.transpose(1, 2) for x in (q_i, k_i, v_i)]

                output_i = F.scaled_dot_product_attention(q_i,
                                                          k_i,
                                                          v_i,
                                                          dropout_p=0.0)
                # (1, num_heads, seq_len, head_dim) -> (seq_len, embed_dim)
                output_i = output_i.transpose(1, 2).reshape(-1, self.embed_dim)
                outputs.append(output_i)

            attn_output = torch.cat(outputs, dim=0)
        attn_output = self.out_proj(attn_output)
        return attn_output


class Siglip2MLP(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.activation_fn = ACT2FN[config.hidden_act]
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states = self.fc2(hidden_states)
        return hidden_states


class Siglip2EncoderLayer(nn.Module):

    def __init__(self, config: PretrainedConfig):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.layer_norm1 = nn.LayerNorm(self.embed_dim,
                                        eps=config.layer_norm_eps)
        self.self_attn = Siglip2Attention(config)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim,
                                        eps=config.layer_norm_eps)
        self.mlp = Siglip2MLP(config)

    def forward(self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor,
                position_embeddings: torch.Tensor) -> tuple[torch.FloatTensor]:
        """
        Args:
            hidden_states (`torch.FloatTensor`):
                Input to the layer of shape `(batch, seq_len, embed_dim)`.
            output_attentions (`bool`, *optional*, defaults to `False`):
                Whether or not to return the attentions tensors of all
                attention layers. See `attentions` under
                returned tensors for more detail.
        """
        residual = hidden_states

        hidden_states = self.layer_norm1(hidden_states)
        hidden_states = self.self_attn(hidden_states=hidden_states,
                                       cu_seqlens=cu_seqlens,
                                       position_embeddings=position_embeddings)
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states


class Siglip2Encoder(nn.Module):
|
||||||
|
"""
|
||||||
|
Transformer encoder consisting of `config.num_hidden_layers`
|
||||||
|
self attention layers. Each layer is a [`Siglip2EncoderLayer`].
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: PretrainedConfig
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config: PretrainedConfig):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.layers = nn.ModuleList([
|
||||||
|
Siglip2EncoderLayer(config)
|
||||||
|
for _ in range(config.num_hidden_layers)
|
||||||
|
])
|
||||||
|
self.gradient_checkpointing = False
|
||||||
|
|
||||||
|
self.rotary_pos_emb = VisionRotaryEmbedding(
|
||||||
|
config.hidden_size // config.num_attention_heads // 2)
|
||||||
|
self.patch_size = config.patch_size
|
||||||
|
self.hidden_stride = config.hidden_stride
|
||||||
|
self.window_size = config.window_size
|
||||||
|
self.spatial_merge_unit = config.hidden_stride * config.hidden_stride
|
||||||
|
if config.fullatt_block_indexes is None:
|
||||||
|
self.fullatt_block_indexes = None
|
||||||
|
else:
|
||||||
|
self.fullatt_block_indexes = [
|
||||||
|
int(i) for i in config.fullatt_block_indexes.split('|')
|
||||||
|
]
|
||||||
|
|
||||||
|
# copied from qwen2.5_vl
|
||||||
|
def rot_pos_emb(self, grid_thw):
|
||||||
|
pos_ids = []
|
||||||
|
for t, h, w in grid_thw:
|
||||||
|
hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
|
||||||
|
hpos_ids = hpos_ids.reshape(
|
||||||
|
h // self.hidden_stride,
|
||||||
|
self.hidden_stride,
|
||||||
|
w // self.hidden_stride,
|
||||||
|
self.hidden_stride,
|
||||||
|
)
|
||||||
|
hpos_ids = hpos_ids.permute(0, 2, 1, 3)
|
||||||
|
hpos_ids = hpos_ids.flatten()
|
||||||
|
|
||||||
|
wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
|
||||||
|
wpos_ids = wpos_ids.reshape(
|
||||||
|
h // self.hidden_stride,
|
||||||
|
self.hidden_stride,
|
||||||
|
w // self.hidden_stride,
|
||||||
|
self.hidden_stride,
|
||||||
|
)
|
||||||
|
wpos_ids = wpos_ids.permute(0, 2, 1, 3)
|
||||||
|
wpos_ids = wpos_ids.flatten()
|
||||||
|
pos_ids.append(
|
||||||
|
torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
|
||||||
|
pos_ids = torch.cat(pos_ids, dim=0)
|
||||||
|
max_grid_size = grid_thw[:, 1:].max()
|
||||||
|
rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
|
||||||
|
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
|
||||||
|
return rotary_pos_emb
|
||||||
|
|
||||||
|
def get_window_index(self, grid_thw):
|
||||||
|
window_index: list = []
|
||||||
|
cu_window_seqlens: list = [0]
|
||||||
|
window_index_id = 0
|
||||||
|
# patch (after merge) number in each window
|
||||||
|
vit_merger_window_size = (self.window_size // self.hidden_stride //
|
||||||
|
self.patch_size)
|
||||||
|
|
||||||
|
for grid_t, grid_h, grid_w in grid_thw:
|
||||||
|
llm_grid_h, llm_grid_w = (
|
||||||
|
grid_h // self.hidden_stride, # number of patch after merge
|
||||||
|
grid_w // self.hidden_stride,
|
||||||
|
)
|
||||||
|
index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(
|
||||||
|
grid_t, llm_grid_h, llm_grid_w)
|
||||||
|
pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
|
||||||
|
pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
|
||||||
|
num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
|
||||||
|
num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
|
||||||
|
index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100)
|
||||||
|
index_padded = index_padded.reshape(
|
||||||
|
grid_t,
|
||||||
|
num_windows_h,
|
||||||
|
vit_merger_window_size,
|
||||||
|
num_windows_w,
|
||||||
|
vit_merger_window_size,
|
||||||
|
)
|
||||||
|
index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
|
||||||
|
grid_t,
|
||||||
|
num_windows_h * num_windows_w,
|
||||||
|
vit_merger_window_size,
|
||||||
|
vit_merger_window_size,
|
||||||
|
)
|
||||||
|
seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
|
||||||
|
index_padded = index_padded.reshape(-1)
|
||||||
|
index_new = index_padded[index_padded != -100]
|
||||||
|
window_index.append(index_new + window_index_id)
|
||||||
|
cu_seqlens_tmp = seqlens.cumsum(
|
||||||
|
0) * self.spatial_merge_unit + cu_window_seqlens[-1]
|
||||||
|
cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
|
||||||
|
window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
|
||||||
|
window_index = torch.cat(window_index, dim=0)
|
||||||
|
|
||||||
|
return window_index, cu_window_seqlens
|
||||||
|
|
    # Ignore copy
    def forward(
        self,
        inputs_embeds,
        grid_thws: torch.Tensor,
        output_hidden_states: bool = False,
    ) -> tuple[torch.Tensor, Optional[tuple[torch.Tensor, ...]]]:
        r"""
        Args:
            inputs_embeds (`torch.FloatTensor` of shape
                `(batch_size, sequence_length, hidden_size)`):
                Optionally, instead of passing `input_ids` you can choose to
                directly pass an embedded representation. This is useful if
                you want more control over how to convert `input_ids` indices
                into associated vectors than the model's internal embedding
                lookup matrix.
            grid_thws (`torch.LongTensor`):
                grid shape (num_patches, 3)
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See
                `hidden_states` under returned tensors for more detail.
        """
        rotary_pos_emb = self.rot_pos_emb(grid_thws)
        window_index, cu_window_seqlens = self.get_window_index(grid_thws)
        cu_window_seqlens = torch.tensor(
            cu_window_seqlens,
            device=inputs_embeds.device,
            dtype=grid_thws.dtype if torch.jit.is_tracing() else torch.int32,
        )
        cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)

        seq_len, _ = inputs_embeds.size()
        inputs_embeds = inputs_embeds.reshape(
            seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
        inputs_embeds = inputs_embeds[window_index, :, :]
        inputs_embeds = inputs_embeds.reshape(seq_len, -1)
        rotary_pos_emb = rotary_pos_emb.reshape(
            seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
        rotary_pos_emb = rotary_pos_emb[window_index, :, :]
        rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
        position_embeddings = (emb.cos(), emb.sin())

        cu_seqlens = torch.repeat_interleave(
            grid_thws[:, 1] * grid_thws[:, 2], grid_thws[:, 0]
        ).cumsum(
            dim=0,
            # Select dtype based on the following factors:
            #  - FA2 requires that cu_seqlens_q must have dtype int32
            #  - torch.onnx.export requires that cu_seqlens_q must have
            #    same dtype as grid_thw
            # See https://github.com/huggingface/transformers/pull/34852
            # for more information
            dtype=grid_thws.dtype if torch.jit.is_tracing() else torch.int32,
        )
        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)

        reverse_indices = torch.argsort(window_index)
        encoder_states = () if output_hidden_states else None

        hidden_states = inputs_embeds
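        # Layers whose index appears in fullatt_block_indexes attend over the
        # full per-image sequences (cu_seqlens); all other layers attend only
        # within the local windows described by cu_window_seqlens. An empty
        # fullatt_block_indexes makes every layer use full attention.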
        for index, block in enumerate(self.layers):
            if (not self.fullatt_block_indexes
                    or index in self.fullatt_block_indexes):
                cu_seqlens_tmp = cu_seqlens
            else:
                cu_seqlens_tmp = cu_window_seqlens
            hidden_states = block(hidden_states, cu_seqlens_tmp,
                                  position_embeddings)
            if output_hidden_states:
                hidden_states_ = hidden_states.reshape(
                    seq_len // self.spatial_merge_unit,
                    self.spatial_merge_unit, -1)
                encoder_states += (hidden_states_[reverse_indices, :].reshape(
                    seq_len, -1), )
        # tokens = self.post_trunk_norm(tokens)
        hidden_states = hidden_states.reshape(
            seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
        hidden_states = hidden_states[reverse_indices, :].reshape(seq_len, -1)

        return hidden_states, encoder_states


class Siglip2VisionTransformer(nn.Module):

    def __init__(self, config: PretrainedConfig):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size

        self.embeddings = Siglip2VisionEmbeddings(config)
        self.encoder = Siglip2Encoder(config)
        self.post_layernorm = nn.LayerNorm(embed_dim,
                                           eps=config.layer_norm_eps)
        self._use_flash_attention_2 = \
            (config._attn_implementation == "flash_attention_2")

    def forward(
        self,
        pixel_values: torch.FloatTensor,
        grid_thws: torch.LongTensor,
        output_hidden_states: Optional[bool] = True,
        return_dict: Optional[bool] = True,
    ) -> Union[
            tuple[torch.Tensor],
            tuple[torch.Tensor, tuple[torch.Tensor, ...]],
            BaseModelOutputWithNoAttention,
    ]:
        r"""
        grid_thws (`torch.LongTensor` of shape `(num_images, 3)`):
            Tensor containing the temporal and spatial grid sizes
            (t, h, w) of the input images.
        """
        hidden_states = self.embeddings(pixel_values, grid_thws)

        last_hidden_state, hidden_states = self.encoder(
            hidden_states, grid_thws, output_hidden_states)
        last_hidden_state = self.post_layernorm(last_hidden_state)

        if not return_dict:
            output = (last_hidden_state, )
            output += (hidden_states, ) if output_hidden_states else ()
            return output

        return last_hidden_state


class Siglip2NavitModel(torch.nn.Module):

    def __init__(self, config: PretrainedConfig):
        super().__init__()
        self.config = config

        self.vision_model = Siglip2VisionTransformer(config)

    def forward(
        self,
        pixel_values: torch.FloatTensor,
        grid_thws: torch.LongTensor,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[
            tuple[torch.Tensor],
            tuple[torch.Tensor, tuple[torch.Tensor, ...]],
            BaseModelOutputWithNoAttention,
    ]:

        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        return self.vision_model(
            pixel_values=pixel_values,
            grid_thws=grid_thws,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

@ -11,5 +11,6 @@ reasons:
from vllm.transformers_utils.processors.deepseek_vl2 import (
    DeepseekVLV2Processor)
from vllm.transformers_utils.processors.ovis import OvisProcessor
from vllm.transformers_utils.processors.ovis2_5 import Ovis2_5Processor

__all__ = ["DeepseekVLV2Processor", "OvisProcessor", "Ovis2_5Processor"]

vllm/transformers_utils/processors/ovis2_5.py (new file, 458 lines)
@ -0,0 +1,458 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from functools import cached_property
from typing import Optional, Union

import numpy as np
import PIL
import torch
from transformers import AutoProcessor, BatchFeature
from transformers.image_utils import ImageInput
from transformers.processing_utils import (ProcessingKwargs, ProcessorMixin,
                                            Unpack)
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput

__all__ = ['Ovis2_5Processor']
IMAGE_TOKEN = "<image>"
VIDEO_TOKEN = "<video>"
MIN_PIXELS = 448 * 448
MAX_PIXELS = 1792 * 1792


class Ovis2_5ProcessorKwargs(ProcessingKwargs,
                             total=False):  # type: ignore[call-arg]
    _defaults = {
        "text_kwargs": {
            "padding": False,
        },
        "images_kwargs": {
            'convert_to_rgb': True,
            'min_pixels': MIN_PIXELS,
            'max_pixels': MAX_PIXELS,
        },
        "videos_kwargs": {
            'convert_to_rgb': True,
            'min_pixels': MIN_PIXELS,
            'max_pixels': MAX_PIXELS,
        }
    }


class Ovis2_5Processor(ProcessorMixin):
    r"""
    Constructs an Ovis processor which wraps an Ovis image processor
    and a Qwen2 tokenizer into a single processor.
    [`OvisProcessor`] offers all the functionalities of
    [`Qwen2VLImageProcessor`] and [`Qwen2TokenizerFast`].
    See the [`~OvisProcessor.__call__`] and [`~OvisProcessor.decode`]
    for more information.
    Args:
        image_processor ([`Qwen2VLImageProcessor`], *optional*):
            The image processor is a required input.
        tokenizer ([`Qwen2TokenizerFast`], *optional*):
            The tokenizer is a required input.
        chat_template (`str`, *optional*): A Jinja template which will
            be used to convert lists of messages in a chat into
            a tokenizable string.
    """

    attributes = ["image_processor", "tokenizer"]
    valid_kwargs = ["chat_template", "image_pad_token"]

    image_processor_class = "AutoImageProcessor"
    tokenizer_class = "AutoTokenizer"

    def __init__(
        self,
        image_processor=None,
        tokenizer=None,
        chat_template=None,
        image_pad_token=None,
        patch_size=16,
        hidden_stride=2,
        temporal_patch_size=1,
        **kwargs,
    ):
        self.image_token = IMAGE_TOKEN
        self.video_token = VIDEO_TOKEN
        self.image_pad_token = "<|image_pad|>"

        self.patch_size = patch_size
        self.hidden_stride = hidden_stride
        self.temporal_patch_size = temporal_patch_size
        super().__init__(image_processor,
                         tokenizer,
                         chat_template=chat_template)

    @cached_property
    def extra_special_tokens(self):
        image_pad_token_id = self.tokenizer.get_vocab()[self.image_pad_token]
        extra_special_tokens = {
            "image_token": -200,
            "video_token": -201,
            "visual_atom": -300,
            "image_start": -301,
            "image_end": -302,
            "video_start": -303,
            "video_end": -304,
            'image_pad': image_pad_token_id,
        }
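        # Note: the negative ids are out-of-vocabulary sentinels used only
        # while the prompt is assembled; __call__ later swaps the image and
        # video sentinels for placeholder sequences built entirely from the
        # real `<|image_pad|>` id, so the returned input_ids stay inside the
        # tokenizer vocabulary (when the matching images/videos are passed).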
        return extra_special_tokens

    def __call__(
        self,
        images: ImageInput = None,
        videos: Union[np.ndarray, list[ImageInput]] = None,
        text: Union[TextInput, PreTokenizedInput, list[TextInput],
                    list[PreTokenizedInput]] = None,
        **kwargs: Unpack[Ovis2_5ProcessorKwargs],
    ) -> BatchFeature:
        """
        Main method to prepare one or several sequence(s) and image(s) for
        the model. This method forwards the `text` and `kwargs` arguments
        to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text`
        is not `None` to encode the text. To prepare the vision inputs,
        this method forwards the `vision_infos` and `kwargs` arguments to
        Qwen2VLImageProcessor's [`~Qwen2VLImageProcessor.__call__`]
        if `vision_infos` is not `None`.
        Args:
            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`,
                `list[PIL.Image.Image]`, `list[np.ndarray]`,
                `list[torch.Tensor]`):
                The image or batch of images to be prepared.
                Each image can be a PIL image, NumPy array or PyTorch
                tensor. Both channels-first and channels-last formats
                are supported.
            text (`str`, `list[str]`, `list[list[str]]`):
                The sequence or batch of sequences to be encoded.
                Each sequence can be a string or a list of strings
                (pretokenized string). If the sequences are provided as
                list of strings (pretokenized), you must set
                `is_split_into_words=True` (to lift the ambiguity with
                a batch of sequences).
            videos (`np.ndarray`, `torch.Tensor`, `list[np.ndarray]`,
                `list[torch.Tensor]`):
                The video or batch of videos to be prepared. Each video
                can be a 4D NumPy array or PyTorch tensor, or a nested
                list of 3D frames. Both channels-first and channels-last
                formats are supported.
            return_tensors (`str` or [`~utils.TensorType`], *optional*):
                If set, will return tensors of a particular framework.
                Acceptable values are:
                - `'tf'`: Return TensorFlow `tf.constant` objects.
                - `'pt'`: Return PyTorch `torch.Tensor` objects.
                - `'np'`: Return NumPy `np.ndarray` objects.
                - `'jax'`: Return JAX `jnp.ndarray` objects.
        Returns:
            [`BatchFeature`]: A [`BatchFeature`] with the following fields:
            - **input_ids** -- list of token ids to be fed to a model.
              Returned when `text` is not `None`.
            - **pixel_values** -- Pixel values to be fed to a model.
              Returned when `images` is not `None`.
            - **grids** -- list of image 3D grids (t, h, w). Returned
              when `images` is not `None`.
            - **video_pixel_values** -- Pixel values of videos to be fed to
              a model. Returned when `videos` is not `None`.
            - **video_grids** -- list of video 3D grids (t, h, w). Returned
              when `videos` is not `None`.
        """
        output_kwargs = self._merge_kwargs(
            Ovis2_5ProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )
        # Process all images first
        visual_features = {}
        output = BatchFeature()
        if images is not None:
            processed_images = []
            image_placeholders_list = []
            grids = []
            # Process each image
            for image in images if isinstance(images, list) else [images]:
                pixel_values, image_placeholders, grid = (
                    self.preprocess_multidata(
                        images=image, **output_kwargs["images_kwargs"]))
                processed_images.append(pixel_values)
                image_placeholders_list.append(image_placeholders)
                grids.append(grid)

            # assign all processed images
            if processed_images:
                visual_features["image_placeholders"] = image_placeholders_list
                output["pixel_values"] = processed_images
                output["grids"] = grids

        if videos is not None:
            processed_videos = []
            videos_placeholders_list = []
            grids = []
            # Process each video
            for video in videos if isinstance(videos, list) else [videos]:
                pixel_values, video_placeholders, grid = (
                    self.preprocess_multidata(
                        video=video, **output_kwargs["videos_kwargs"]))
                processed_videos.append(pixel_values)
                videos_placeholders_list.append(video_placeholders)
                grids.append(grid)
            # assign all processed videos
            if processed_videos:
                visual_features[
                    "video_placeholders"] = videos_placeholders_list
                output["video_pixel_values"] = processed_videos
                output["video_grids"] = grids

        # Process text input
        if text is not None:
            if not isinstance(text, list):
                text = [text]
            tokenized_batched_text = self._tokenize_with_visual_symbol(text)
            image_token_id = self.get_token_value("image_token")
            video_token_id = self.get_token_value("video_token")
            replaced_ids_list = []
            image_idx = 0
            video_idx = 0
            for ids_tensor in tokenized_batched_text:
                has_image_tokens = (image_token_id in ids_tensor
                                    and "image_placeholders" in visual_features
                                    and image_idx < len(
                                        visual_features["image_placeholders"]))
                has_video_tokens = (video_token_id in ids_tensor
                                    and "video_placeholders" in visual_features
                                    and video_idx < len(
                                        visual_features["video_placeholders"]))
                if has_image_tokens or has_video_tokens:
                    # Convert to list for easier manipulation
                    ids_list = ids_tensor.tolist()
                    new_ids = []

                    # Replace placeholders
                    for token_id in ids_list:
                        if token_id == image_token_id:
                            new_ids.extend(
                                visual_features["image_placeholders"]
                                [image_idx])
                            image_idx += 1
                        elif token_id == video_token_id:
                            new_ids.extend(
                                visual_features["video_placeholders"]
                                [video_idx])
                            video_idx += 1
                        else:
                            new_ids.append(token_id)
                    # Convert back to tensor
                    ids_tensor = torch.tensor(new_ids, dtype=torch.long)
                replaced_ids_list.append(ids_tensor)
            if replaced_ids_list:
                replaced_and_tokenized_ids = torch.stack(replaced_ids_list)
            else:
                replaced_and_tokenized_ids = torch.tensor([], dtype=torch.long)
            output["input_ids"] = replaced_and_tokenized_ids

            return output
        # If only images were provided
        return BatchFeature(data=visual_features)

    def _tokenize_with_visual_symbol(self,
                                     text_list: list[str]) -> torch.LongTensor:
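        # Illustrative example: "<image>\nDescribe the image." becomes
        # [image_token_id (-200), <ids of "\nDescribe the image.">]; the
        # sentinel is expanded into `<|image_pad|>` placeholders by __call__.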
        batch_token_ids = []
        for text in text_list:
            token_ids = []
            video_token_id = self.get_token_value("video_token")
            image_token_id = self.get_token_value("image_token")
            video_split_texts = text.split(self.video_token)

            for j, video_segment in enumerate(video_split_texts):
                image_split_texts = video_segment.split(self.image_token)
                text_chunks = [
                    self.tokenizer(chunk, add_special_tokens=False).input_ids
                    for chunk in image_split_texts
                ]
                segment_tokens = []
                for i, chunk in enumerate(text_chunks):
                    segment_tokens.extend(chunk)
                    if i < len(text_chunks) - 1:
                        segment_tokens.append(image_token_id)
                token_ids.extend(segment_tokens)
                if j < len(video_split_texts) - 1:
                    token_ids.append(video_token_id)

            batch_token_ids.append(token_ids)
        return torch.tensor(batch_token_ids, dtype=torch.long)

    # Copied from qwen2_vl
    def smart_resize(self,
                     height: int,
                     width: int,
                     factor: int = 28,
                     min_pixels: int = MIN_PIXELS,
                     max_pixels: int = MAX_PIXELS):
        """Rescales the image so that the following conditions are met:
        1. Both dimensions (height and width) are divisible by 'factor'.
        2. The total number of pixels is within the range
           ['min_pixels', 'max_pixels'].
        3. The aspect ratio of the image is maintained as closely as possible.
        """
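        # Worked example (with the factor preprocess_multidata passes in,
        # patch_size * hidden_stride = 32): a 1000x700 input rounds to
        # 992x704, and 992 * 704 = 698368 already lies within
        # [MIN_PIXELS, MAX_PIXELS] = [200704, 3211264], so no further
        # rescaling is applied.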
        if height < factor or width < factor:
            print(f"height:{height} or width:{width} must be "
                  f"larger than factor:{factor}")
            if height < width:
                width = round(factor / height * width)
                height = factor
            else:
                height = round(factor / width * height)
                width = factor

        elif max(height, width) / min(height, width) > 200:
            print(f"absolute aspect ratio must be smaller than 200, "
                  f"got {max(height, width) / min(height, width)}")
            if height > width:
                height = 200 * width
            else:
                width = 200 * height

        h_bar = round(height / factor) * factor
        w_bar = round(width / factor) * factor
        if h_bar * w_bar > max_pixels:
            beta = math.sqrt((height * width) / max_pixels)
            h_bar = math.floor(height / beta / factor) * factor
            w_bar = math.floor(width / beta / factor) * factor
        elif h_bar * w_bar < min_pixels:
            beta = math.sqrt(min_pixels / (height * width))
            h_bar = math.ceil(height * beta / factor) * factor
            w_bar = math.ceil(width * beta / factor) * factor
        return h_bar, w_bar

    def get_token_value(self, tok):
        return self.extra_special_tokens[tok]

    def construct_visual_indicators(self, grid, is_video: bool = False):
        if is_video:
            start_token = self.get_token_value('video_start')
            end_token = self.get_token_value('video_end')
        else:
            start_token = self.get_token_value('image_start')
            end_token = self.get_token_value('image_end')

        image_placeholders = [start_token, self.get_token_value('visual_atom')]
        if grid[0] * grid[1] > 1:
            for r in range(grid[0]):
                for c in range(grid[1]):
                    image_placeholders.append(
                        self.get_token_value('visual_atom'))

        image_placeholders.append(end_token)
        return image_placeholders

    def construct_visual_placeholders(self, grid, is_video: bool = False):
        visual_placeholders = self.construct_visual_indicators((1, 1),
                                                               is_video)

        image_atom_token_id = self.get_token_value('visual_atom')
        # Extract the padding token ID from tokenizer
        image_padding_token_id = self.get_token_value('image_pad')

        num_image_atoms = grid[0] * grid[1] * grid[2]
        num_image_atoms //= self.hidden_stride**2
        num_image_atoms //= self.temporal_patch_size
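        # Illustrative example (default hidden_stride=2,
        # temporal_patch_size=1): a 448x448 image gives grid = [1, 28, 28],
        # so num_image_atoms = 784 // 4 // 1 = 196 and the final placeholder
        # sequence holds 196 + 2 = 198 `<|image_pad|>` ids (the start/end
        # indicators are also mapped to the pad id in the loop below).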

        # Create a new list with padding tokens inserted
        padded_placeholder_tokens = []
        for token in visual_placeholders:
            if token == image_atom_token_id:
                padded_placeholder_tokens.extend([image_padding_token_id] *
                                                 num_image_atoms)
            else:
                padded_placeholder_tokens.append(image_padding_token_id)
        return padded_placeholder_tokens

    def preprocess_multidata(
        self,
        images: Optional[Union[PIL.Image.Image, list[PIL.Image.Image]]] = None,
        video: Optional[Union[list[PIL.Image.Image], np.ndarray]] = None,
        convert_to_rgb: Optional[bool] = True,
        min_pixels: int = MIN_PIXELS,
        max_pixels: int = MAX_PIXELS,
        return_tensors: Optional[str] = 'pt',
    ):
        is_video = False
        if images is not None:
            if not isinstance(images, list):
                images = [images]
        elif video is not None:
            is_video = True
            # type of video in dummy_mm_data is np.ndarray
            if isinstance(video, np.ndarray):
                images = []
                for i in range(video.shape[0]):
                    image = PIL.Image.fromarray(video[i].astype(np.uint8))
                    images.append(image)
            elif isinstance(video, list):
                images = video
        min_pixels = min(max_pixels if max_pixels is not None else MAX_PIXELS,
                         min_pixels if min_pixels is not None else MIN_PIXELS)
        images = [
            image.convert("RGB")
            if convert_to_rgb and image.mode != 'RGB' else image
            for image in images
        ]

        width, height = images[0].size
        resized_height, resized_width = height, width
        processed_images = []
        for image in images:
            resized_height, resized_width = self.smart_resize(
                height,
                width,
                factor=self.patch_size * self.hidden_stride,
                min_pixels=min_pixels,
                max_pixels=max_pixels,
            )
            new_size = dict(height=resized_height, width=resized_width)
            image_pt = self.image_processor.preprocess(
                image, size=new_size, return_tensors="np")['pixel_values'][0]

            processed_images.append(image_pt)

        patches = np.array(processed_images)
        if patches.shape[0] % self.temporal_patch_size != 0:
            num_to_pad = self.temporal_patch_size - (patches.shape[0] %
                                                     self.temporal_patch_size)
            repeats = np.repeat(patches[-1][np.newaxis], num_to_pad, axis=0)
            patches = np.concatenate([patches, repeats], axis=0)
        channel = patches.shape[1]
        grid_t = patches.shape[0] // self.temporal_patch_size
        grid_h = resized_height // self.patch_size
        grid_w = resized_width // self.patch_size

        patches = patches.reshape(
            grid_t,
            self.temporal_patch_size,
            channel,
            grid_h // self.hidden_stride,
            self.hidden_stride,
            self.patch_size,
            grid_w // self.hidden_stride,
            self.hidden_stride,
            self.patch_size,
        )
        patches = patches.transpose(0, 3, 6, 4, 7, 2, 1, 5, 8)
        flatten_patches = patches.reshape(
            grid_t * grid_h * grid_w, channel * self.temporal_patch_size *
            self.patch_size * self.patch_size)

        visual_placeholders = self.construct_visual_placeholders(
            [grid_t, grid_h, grid_w], is_video)
        return torch.tensor(
            flatten_patches), visual_placeholders, torch.tensor(
                [[grid_t, grid_h, grid_w]])


AutoProcessor.register("Ovis2_5Processor", Ovis2_5Processor)
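
As a rough guide (worked out from the defaults above, not taken from the commit itself), the quantities preprocess_multidata produces for a single 448x448 image with patch_size=16, hidden_stride=2 and temporal_patch_size=1 are:

# Shape arithmetic only; mirrors the reshaping in preprocess_multidata.
grid_t, grid_h, grid_w = 1, 448 // 16, 448 // 16    # grid tensor [[1, 28, 28]]
num_patches = grid_t * grid_h * grid_w              # 784 rows in pixel_values
patch_dim = 3 * 1 * 16 * 16                         # 768 values per flattened patch
num_pad_tokens = num_patches // (2 * 2) // 1 + 2    # 198 <|image_pad|> placeholder ids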