[Model] Broadcast Ovis2 implementation to fit Ovis1.6 (#17861)
Signed-off-by: Isotr0py <2037008807@qq.com>
commit 021c16c7ca (parent 7de18d541b)
@@ -1045,10 +1045,10 @@ Specified using `--task generate`.
   *
   * ✅︎
   * ✅︎
-- * `Ovis2ForConditionalGeneration`<sup>^</sup>
-  * Ovis2
+- * `Ovis`
+  * Ovis2, Ovis1.6
   * T + I<sup>+</sup>
-  * `AIDC-AI/Ovis2-1B`, `AIDC-AI/Ovis2-2B`, etc.
+  * `AIDC-AI/Ovis2-1B`, `AIDC-AI/Ovis1.6-Llama3.2-3B`, etc.
   *
   *
   * ✅︎
@@ -725,8 +725,8 @@ def run_nvlm_d(questions: list[str], modality: str) -> ModelRequestData:
     )


-# Ovis2
-def run_ovis2(questions: list[str], modality: str) -> ModelRequestData:
+# Ovis
+def run_ovis(questions: list[str], modality: str) -> ModelRequestData:
     assert modality == "image"

     model_name = "AIDC-AI/Ovis2-1B"
@@ -737,15 +737,18 @@ def run_ovis2(questions: list[str], modality: str) -> ModelRequestData:
         max_num_seqs=2,
         trust_remote_code=True,
         dtype="half",
-        hf_overrides={"architectures": ["Ovis2ForConditionalGeneration"]},
         limit_mm_per_prompt={modality: 1},
     )

-    placeholder = "<image>\n"
-    prompts = [("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
-                f"<|im_start|>user\n{placeholder}"
-                f"{question}<|im_end|>\n"
-                "<|im_start|>assistant\n") for question in questions]
+    tokenizer = AutoTokenizer.from_pretrained(model_name,
+                                              trust_remote_code=True)
+    messages = [[{
+        'role': 'user',
+        'content': f"<image>\n{question}"
+    }] for question in questions]
+    prompts = tokenizer.apply_chat_template(messages,
+                                            tokenize=False,
+                                            add_generation_prompt=True)

     return ModelRequestData(
         engine_args=engine_args,
@@ -1069,7 +1072,7 @@ model_example_map = {
     "llama4": run_llama4,
     "molmo": run_molmo,
     "NVLM_D": run_nvlm_d,
-    "ovis2": run_ovis2,
+    "ovis": run_ovis,
     "paligemma": run_paligemma,
     "paligemma2": run_paligemma2,
     "phi3_v": run_phi3v,
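The rewritten example leans on each checkpoint's own chat template instead of hard-coded Qwen2 tags, which is what lets a single `run_ovis` cover Ovis2 and both Ovis1.6 variants. A standalone sketch of that prompt-building pattern (the checkpoint name and question are only examples):

```python
from transformers import AutoTokenizer

# Sketch of the chat-template prompt building used by run_ovis; any Ovis
# checkpoint that ships a chat template should behave the same way.
model_name = "AIDC-AI/Ovis2-1B"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

questions = ["What is the content of this image?"]
messages = [[{"role": "user", "content": f"<image>\n{question}"}]
            for question in questions]

# tokenize=False returns formatted prompt strings (one per conversation)
# instead of token ids, ready to hand to vLLM together with the image.
prompts = tokenizer.apply_chat_template(messages,
                                        tokenize=False,
                                        add_generation_prompt=True)
print(prompts[0])
```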
@@ -436,8 +436,8 @@ def load_nvlm_d(question: str, image_urls: list[str]) -> ModelRequestData:
     )


-# Ovis2
-def load_ovis2(question: str, image_urls: list[str]) -> ModelRequestData:
+# Ovis
+def load_ovis(question: str, image_urls: list[str]) -> ModelRequestData:
     model_name = "AIDC-AI/Ovis2-1B"

     engine_args = EngineArgs(
@@ -447,15 +447,17 @@ def load_ovis2(question: str, image_urls: list[str]) -> ModelRequestData:
         trust_remote_code=True,
         dtype="half",
         limit_mm_per_prompt={"image": len(image_urls)},
-        hf_overrides={"architectures": ["Ovis2ForConditionalGeneration"]},
     )

-    placeholder = '\n'.join(
-        [f'Image {i+1}: <image>' for i in range(len(image_urls))]) + '\n'
-    prompt = ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
-              f"<|im_start|>user\n{placeholder}"
-              f"{question}<|im_end|>\n"
-              "<|im_start|>assistant\n")
+    placeholders = "\n".join(f"Image-{i}: <image>\n"
+                             for i, _ in enumerate(image_urls, start=1))
+    messages = [{'role': 'user', 'content': f"{placeholders}\n{question}"}]
+
+    tokenizer = AutoTokenizer.from_pretrained(model_name,
+                                              trust_remote_code=True)
+    prompt = tokenizer.apply_chat_template(messages,
+                                           tokenize=False,
+                                           add_generation_prompt=True)

     return ModelRequestData(
         engine_args=engine_args,
@@ -713,7 +715,7 @@ model_example_map = {
     "mistral3": load_mistral3,
     "mllama": load_mllama,
     "NVLM_D": load_nvlm_d,
-    "ovis2": load_ovis2,
+    "ovis": load_ovis,
     "phi3_v": load_phi3v,
     "phi4_mm": load_phi4mm,
     "pixtral_hf": load_pixtral_hf,
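For multi-image requests the example now labels each image before handing the text to the chat template. A quick illustration of the user message it assembles (the URLs are dummies; only their count matters):

```python
# Illustration of the placeholder string built by load_ovis for two images.
image_urls = ["https://example.com/duck.jpg", "https://example.com/lion.jpg"]

placeholders = "\n".join(f"Image-{i}: <image>\n"
                         for i, _ in enumerate(image_urls, start=1))
question = "What is the content of each image?"
content = f"{placeholders}\n{question}"
print(content)
# Image-1: <image>
#
# Image-2: <image>
#
# What is the content of each image?
```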
@@ -355,10 +355,16 @@ class HfRunner:
             **model_kwargs,
         )

+        # in case some unquantized custom models are not in same dtype
+        if (getattr(model, "quantization_method", None) is None
+                and any(p.dtype != self.dtype
+                        for p in model.parameters())):
+            model = model.to(dtype=self.dtype)
+
         if (getattr(model, "quantization_method", None) != "bitsandbytes"
                 and len({p.device
                          for p in model.parameters()}) < 2):
-            model = model.to(self.device)
+            model = model.to(device=self.device)

         self.model = model

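The new guard keeps the HF reference model in a single dtype before outputs are compared. A standalone sketch of the same check (the helper name is ours, not part of the test suite):

```python
import torch


def ensure_uniform_dtype(model: torch.nn.Module,
                         dtype: torch.dtype) -> torch.nn.Module:
    """Cast an unquantized model whose parameters drifted from `dtype`.

    Mirrors the check added to HfRunner: quantized models are left alone,
    everything else is cast so vLLM and HF run in the same precision.
    """
    if (getattr(model, "quantization_method", None) is None
            and any(p.dtype != dtype for p in model.parameters())):
        model = model.to(dtype=dtype)
    return model
```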
@@ -476,6 +476,31 @@ VLM_TEST_SETTINGS = {
         max_num_seqs=2,
         patch_hf_runner=model_utils.molmo_patch_hf_runner,
     ),
+    "ovis1_6-gemma2": VLMTestInfo(
+        models=["AIDC-AI/Ovis1.6-Gemma2-9B"],
+        test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
+        prompt_formatter=lambda img_prompt: f"<bos><start_of_turn>user\n{img_prompt}<end_of_turn>\n<start_of_turn>model\n",  # noqa: E501
+        img_idx_to_prompt=lambda idx: "<image>\n",  # noqa: E501
+        max_model_len=4096,
+        max_num_seqs=2,
+        dtype="half",
+        # use sdpa mode for hf runner since ovis2 didn't work with flash_attn
+        hf_model_kwargs={"llm_attn_implementation": "sdpa"},
+        patch_hf_runner=model_utils.ovis_patch_hf_runner,
+        marks=[large_gpu_mark(min_gb=32)],
+    ),
+    "ovis1_6": VLMTestInfo(
+        models=["AIDC-AI/Ovis1.6-Llama3.2-3B"],
+        test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
+        prompt_formatter=lambda img_prompt: f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nYou are a helpful and honest multimodal assistant.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{img_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",  # noqa: E501
+        img_idx_to_prompt=lambda idx: "<image>\n",  # noqa: E501
+        max_model_len=4096,
+        max_num_seqs=2,
+        dtype="half",
+        # use sdpa mode for hf runner since ovis2 didn't work with flash_attn
+        hf_model_kwargs={"llm_attn_implementation": "sdpa"},
+        patch_hf_runner=model_utils.ovis_patch_hf_runner,
+    ),
     "ovis2": VLMTestInfo(
         models=["AIDC-AI/Ovis2-1B"],
         test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
@@ -486,7 +511,7 @@ VLM_TEST_SETTINGS = {
         dtype="half",
         # use sdpa mode for hf runner since ovis2 didn't work with flash_attn
         hf_model_kwargs={"llm_attn_implementation": "sdpa"},
-        patch_hf_runner=model_utils.ovis2_patch_hf_runner,
+        patch_hf_runner=model_utils.ovis_patch_hf_runner,
     ),
     "phi3v": VLMTestInfo(
         models=["microsoft/Phi-3.5-vision-instruct"],
@@ -678,12 +678,8 @@ def molmo_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
     return hf_model


-def ovis2_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
+def ovis_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
     """Patches and returns an instance of the HfRunner to use for Ovis2."""
-    hf_model.model.visual_tokenizer.to(hf_model.dtype)
-    hf_model.model.vte.to(hf_model.dtype)
-    hf_model.model.llm.to(hf_model.dtype)
-
     hf_model.model.get_output_embeddings = lambda: \
         hf_model.model.llm.get_output_embeddings()

@@ -691,7 +687,16 @@ def ovis2_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
         text_tokenizer = hf_model.model.get_text_tokenizer()
         images = [images] if isinstance(images, Image) else images

-        text = text.split("<|im_start|>user\n")[1].split("<|im_end|>\n")[0]
+        prompt_start_and_end = {
+            "qwen2": ("<|im_start|>user\n", "<|im_end|>\n"),
+            "llama":
+            ("<|start_header_id|>user<|end_header_id|>\n\n", "<|eot_id|>"),
+            "gemma2": ("<start_of_turn>user\n", "<end_of_turn>\n"),
+        }
+        for start, end in prompt_start_and_end.values():
+            if start in text and end in text:
+                text = text.split(start)[1].split(end)[0]
+                break

         prompt, input_ids, pixel_values = hf_model.model.preprocess_inputs(
             text_or_conversations=text, images=images)
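The patched runner now strips whichever chat wrapper the backbone's template produced before handing the bare user turn to Ovis' own preprocessing. An illustration of that extraction (the sample prompts are hand-written stand-ins for template output):

```python
# Illustration of the user-turn extraction added to ovis_patch_hf_runner.
PROMPT_START_AND_END = {
    "qwen2": ("<|im_start|>user\n", "<|im_end|>\n"),
    "llama": ("<|start_header_id|>user<|end_header_id|>\n\n", "<|eot_id|>"),
    "gemma2": ("<start_of_turn>user\n", "<end_of_turn>\n"),
}


def extract_user_turn(text: str) -> str:
    for start, end in PROMPT_START_AND_END.values():
        if start in text and end in text:
            return text.split(start)[1].split(end)[0]
    return text  # no known wrapper found: leave the text untouched


print(extract_user_turn(
    "<|im_start|>user\n<image>\nWhat is this?<|im_end|>\n"
    "<|im_start|>assistant\n"))        # -> "<image>\nWhat is this?"
print(extract_user_turn(
    "<bos><start_of_turn>user\n<image>\nDescribe it.<end_of_turn>\n"
    "<start_of_turn>model\n"))         # -> "<image>\nDescribe it."
```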
@@ -146,7 +146,8 @@ def _test_processing_correctness_hf(
     batch_idx: int,
     ignore_mm_keys: Optional[set[str]] = None,
 ):
-    if model_config.hf_config.model_type in ("mllama", "whisper", "ultravox"):
+    if model_config.hf_config.model_type in ("mllama", "ovis", "ultravox",
+                                             "whisper"):
         # For some multimodal models, tokenizer will always add bos_token
         # at the beginning of prompt by default, causing hf_processor outputs
         # incorrect token ids. So we need use `add_special_tokens=False` here
@@ -274,6 +275,8 @@ def _test_processing_correctness_mistral(
     "allenai/Molmo-7B-D-0924",
     "allenai/Molmo-7B-O-0924",
     "nvidia/NVLM-D-72B",
+    "AIDC-AI/Ovis1.6-Gemma2-9B",
+    "AIDC-AI/Ovis1.6-Llama3.2-3B",
     "AIDC-AI/Ovis2-1B",
     "google/paligemma-3b-mix-224",
     "google/paligemma2-3b-ft-docci-448",
@@ -355,9 +355,9 @@ _MULTIMODAL_EXAMPLE_MODELS = {
                                          max_transformers_version="4.48",
                                          transformers_version_reason="Use of deprecated imports which have been removed.",  # noqa: E501
                                          extras={"phi3.5": "microsoft/Phi-3.5-vision-instruct"}),  # noqa: E501
-    "Ovis2ForConditionalGeneration": _HfExamplesInfo("AIDC-AI/Ovis2-1B",
-                                         trust_remote_code=True,
-                                         hf_overrides={"architectures": ["Ovis2ForConditionalGeneration"]}),  # noqa: E501
+    "Ovis": _HfExamplesInfo("AIDC-AI/Ovis2-1B", trust_remote_code=True,
+                            extras={"1.6-llama": "AIDC-AI/Ovis1.6-Llama3.2-3B",
+                                    "1.6-gemma": "AIDC-AI/Ovis1.6-Gemma2-9B"}),  # noqa: E501
     "Phi4MMForCausalLM": _HfExamplesInfo("microsoft/Phi-4-multimodal-instruct",
                                          trust_remote_code=True),
     "PixtralForConditionalGeneration": _HfExamplesInfo("mistralai/Pixtral-12B-2409",  # noqa: E501
@@ -512,7 +512,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
                                hf_config.image_token_index)

         if model_type in ("aya_vision", "chameleon", "deepseek_vl_v2",
-                          "internvl_chat", "ovis2", "skywork_chat",
+                          "internvl_chat", "ovis", "skywork_chat",
                           "NVLM_D", "h2ovl_chat", "idefics3", "smolvlm"):
             return "<image>"
         if model_type in ("mllama", "llama4"):
@@ -5,129 +5,14 @@
 from typing import Optional

 import torch
-from torch import nn, softmax
+import torch.nn as nn
 from torch.nn import functional as F
-from torch.nn.functional import gumbel_softmax, pad

 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import ReplicatedLinear
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
-from vllm.transformers_utils.configs.ovis2 import (AIMv2Config,
-                                                   Aimv2VisualTokenizerConfig)
-
-IMAGE_INDICATOR_IDS = [-301, -302, -303, -304,
-                       -305]  # kept for vocab prefixed tokens
-
-
-def st_argmax(y_soft: torch.Tensor, dim: int):  # straight-through softmax
-    index = y_soft.max(dim, keepdim=True)[1]
-    y_hard = torch.zeros_like(
-        y_soft, memory_format=torch.legacy_contiguous_format).scatter_(
-            dim, index, 1.0)
-    ret = y_hard - y_soft.detach() + y_soft
-    return ret
-
-
-class Aimv2VisualTokenizer(torch.nn.Module):
-
-    def __init__(self,
-                 config: Aimv2VisualTokenizerConfig,
-                 quant_config: Optional[QuantizationConfig] = None,
-                 prefix: str = "",
-                 **kwargs):
-        super().__init__()
-        self.config = config
-        self.backbone = AIMv2Model(
-            config=config.backbone_config,  # noqa
-            quant_config=quant_config,
-            prefix=f"{prefix}.visual_tokenizer")
-        # reserved tokens for IMAGE_INDICATORS
-        head_dim = config.vocab_size - len(IMAGE_INDICATOR_IDS)
-        self.head = torch.nn.Sequential(
-            ReplicatedLinear(
-                config.backbone_config.hidden_size * config.hidden_stride *
-                config.hidden_stride,
-                head_dim,
-                bias=False,
-            ), torch.nn.LayerNorm(head_dim))
-
-    @property
-    def dtype(self):
-        return self.backbone.dtype
-
-    @property
-    def device(self):
-        return self.backbone.device
-
-    def tokenize(self, logits):
-        if self.config.tokenize_function == 'softmax':
-            tokens = softmax(logits, dim=-1)
-        elif self.config.tokenize_function == 'gumbel_argmax':
-            tokens = gumbel_softmax(logits, tau=self.config.tau, hard=True)
-        elif self.config.tokenize_function == 'st_argmax':
-            tokens = st_argmax(logits, dim=-1)
-        else:
-            raise ValueError(
-                'Invalid `max_type`, expected softmax or gumbel_argmax '
-                f'or st_argmax, but got {self.config.tokenize_function}')
-        return tokens
-
-    def encode(self, pixel_values):
-        features = self.backbone(pixel_values)
-        if self.config.drop_cls_token:
-            features = features[:, 1:, :]
-
-        # merge number of `hidden_stride * hidden_stride` hidden states together
-        # to reduce token sequence length
-        # e.g., for hidden_stride=2, this leads to a token length reduction:
-        # 1024 -> 256 for aimv2
-        if self.config.hidden_stride > 1:
-            # this `d` maybe different from the above `d``
-            n, L, d = features.shape
-            sqrt_l = int(L**0.5)
-            assert sqrt_l**2 == L, (
-                "The token sequence length should be a perfect square.")
-            features = features.reshape(n, sqrt_l, sqrt_l, d)
-            pl = (self.config.hidden_stride -
-                  (sqrt_l %
-                   self.config.hidden_stride)) % self.config.hidden_stride
-            features = pad(features, (0, 0, 0, pl, 0, pl), "constant", 0)
-            sqrt_l += pl
-            features = features.reshape(n, sqrt_l // self.config.hidden_stride,
-                                        self.config.hidden_stride,
-                                        sqrt_l // self.config.hidden_stride,
-                                        self.config.hidden_stride, d)
-            # [n, sqrt_l/hs, sqrt_l/hs, hs, hs, d]
-            features = features.permute(0, 1, 3, 2, 4, 5)
-            # [n, sqrt_l/hs, sqrt_l/hs, hs*hs*d]
-            features = features.flatten(3)
-            # [n, sqrt_l/hs*sqrt_l/hs, hs*hs*d]
-            features = features.reshape(
-                n, -1,
-                self.config.hidden_stride * self.config.hidden_stride * d)
-
-        return features
-
-    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
-        """[BatchSize, ImageShape] -> [BatchSize, Token, VocabSize]"""
-        features = self.encode(pixel_values)
-        logits, _ = self.head[0](
-            features)  # we spllit the sequncial here for not throwing an error
-        logits = self.head[1](logits)
-        tokens = self.tokenize(logits)
-        # tokens' shape is [BatchSize, #Token, VocabSize-5], so padding with
-        # [BatchSize, #Token, 5], after which, tokens' shape should become
-        # [BatchSize, #Token, VocabSize]
-        batch_size, token_len, _ = tokens.shape
-        padding_tensor = torch.zeros(size=(batch_size, token_len,
-                                           len(IMAGE_INDICATOR_IDS)),
-                                     dtype=tokens.dtype,
-                                     device=tokens.device,
-                                     layout=tokens.layout,
-                                     requires_grad=False)
-        tokens = torch.cat((tokens, padding_tensor), dim=2)
-        return tokens
+from vllm.transformers_utils.configs.ovis import AIMv2Config


 class AIMv2SwiGLUFFN(nn.Module):
@@ -302,14 +187,6 @@ class AIMv2Model(torch.nn.Module):
             quant_config=quant_config,
             prefix=f"{prefix}.trunk")

-    @property
-    def dtype(self):
-        return self.trunk.blocks[0].attn.qkv.weight.dtype
-
-    @property
-    def device(self):
-        return self.trunk.blocks[0].attn.qkv.device
-
     def forward(
         self,
         pixel_values: torch.Tensor,
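The visual tokenizer removed here (and re-added in ovis.py below) folds each `hidden_stride x hidden_stride` window of patch features into one token before the vocabulary head. A toy reproduction of that reshape with random data shows the 4x sequence-length reduction for `hidden_stride=2`:

```python
import torch

# Toy check of the hidden_stride merge used by the visual tokenizer:
# 1024 patch features of width 64 become 256 features of width 256.
n, L, d, hs = 1, 1024, 64, 2
features = torch.randn(n, L, d)

sqrt_l = int(L**0.5)                             # 32, assumed to be exact here
features = features.reshape(n, sqrt_l, sqrt_l, d)
features = features.reshape(n, sqrt_l // hs, hs, sqrt_l // hs, hs, d)
features = features.permute(0, 1, 3, 2, 4, 5)    # group each hs*hs window last
features = features.flatten(3).reshape(n, -1, hs * hs * d)

print(features.shape)  # torch.Size([1, 256, 256])
```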
@@ -15,17 +15,23 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-""" PyTorch Ovis2 model."""
+""" PyTorch Ovis model."""
+import math
 from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
                     TypedDict, Union)

 import torch
 import torch.nn as nn
 from torch import Tensor
-from transformers import BatchFeature
+from torch.nn.functional import gumbel_softmax, pad, softmax
+from transformers import BaseImageProcessor, BatchFeature

 from vllm.config import VllmConfig
-from vllm.model_executor.models.aimv2 import Aimv2VisualTokenizer
+from vllm.model_executor.layers.linear import ReplicatedLinear
+from vllm.model_executor.layers.quantization.base_config import (
+    QuantizationConfig)
+from vllm.model_executor.models.aimv2 import AIMv2Model
+from vllm.model_executor.models.siglip import SiglipVisionModel
 from vllm.model_executor.models.utils import (AutoWeightsLoader, flatten_bn,
                                               init_vllm_registered_model,
                                               maybe_prefix)
@@ -38,19 +44,160 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         BaseProcessingInfo, PromptReplacement)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
-from vllm.transformers_utils.configs.ovis2 import OvisConfig
-from vllm.transformers_utils.processors.ovis2 import OvisProcessor
+from vllm.transformers_utils.configs.ovis import (BaseVisualTokenizerConfig,
+                                                  OvisConfig)
+from vllm.transformers_utils.processors.ovis import OvisProcessor

 from .interfaces import MultiModalEmbeddings, SupportsMultiModal
 from .utils import merge_multimodal_embeddings

 # Cannot find the following number from hf config.
 IMAGE_TOKEN = "<image>"
-IMAGE_PAD_TOKEN_ID = 151655
-NUMBER_OF_TOKEN_TO_RESERVE_FOR_SEGMENT = 256
+IMAGE_INDICATOR_IDS = [-301, -302, -303, -304, -305]
+
+IMAGE_PAD_TOKEN_MAP = {
+    "gemma2": "<unused0>",
+    "llama": "<|reserved_special_token_0|>",
+    "qwen2": "<|image_pad|>",
+}
+IMAGE_PAD_TOKEN_ID_MAP = {
+    "gemma2": 7,
+    "llama": 128002,
+    "qwen2": 151655,
+}
+
+
+def st_argmax(y_soft: torch.Tensor, dim: int):  # straight-through softmax
+    index = y_soft.argmax(dim, keepdim=True)
+    return torch.zeros_like(
+        y_soft,
+        memory_format=torch.legacy_contiguous_format,
+    ).scatter_(dim, index, 1.0)
+
+
+class VisualTokenizer(torch.nn.Module):
+
+    def __init__(
+        self,
+        config: BaseVisualTokenizerConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.config = config
+        self.backbone = self._init_backbone(
+            config=config,
+            quant_config=quant_config,
+            prefix=f"{prefix}.backbone",
+        )
+        # reserved tokens for IMAGE_INDICATORS
+        head_dim = config.vocab_size - len(IMAGE_INDICATOR_IDS)
+        self.head = torch.nn.Sequential(
+            ReplicatedLinear(
+                config.backbone_config.hidden_size * config.hidden_stride *
+                config.hidden_stride,
+                head_dim,
+                bias=False,
+                return_bias=False,
+            ), torch.nn.LayerNorm(head_dim))
+
+    def _init_backbone(
+        self,
+        config: BaseVisualTokenizerConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        model_type = config.backbone_config.model_type
+        if model_type == "aimv2":
+            return AIMv2Model(
+                config=config.backbone_config,
+                quant_config=quant_config,
+                prefix=prefix,
+            )
+        elif model_type == "siglip_vision_model":
+            return SiglipVisionModel(
+                config=config.backbone_config,
+                quant_config=quant_config,
+                prefix=prefix,
+            )
+        raise ValueError(
+            f"Unsupported visual tokenizer model_type: {model_type}")
+
+    @property
+    def dtype(self):
+        return next(self.head.parameters()).dtype
+
+    @property
+    def device(self):
+        return next(self.head.parameters()).device
+
+    def tokenize(self, logits):
+        if self.config.tokenize_function == 'softmax':
+            tokens = softmax(logits, dim=-1)
+        elif self.config.tokenize_function == 'gumbel_argmax':
+            tokens = gumbel_softmax(logits, tau=self.config.tau, hard=True)
+        elif self.config.tokenize_function == 'st_argmax':
+            tokens = st_argmax(logits, dim=-1)
+        else:
+            raise ValueError(
+                'Invalid `max_type`, expected softmax or gumbel_argmax '
+                f'or st_argmax, but got {self.config.tokenize_function}')
+        return tokens
+
+    def encode(self, pixel_values):
+        features = self.backbone(pixel_values)
+        if self.config.drop_cls_token:
+            features = features[:, 1:, :]
+
+        # merge number of `hidden_stride * hidden_stride` hidden states together
+        # to reduce token sequence length
+        # e.g., for hidden_stride=2, this leads to a token length reduction:
+        # 1024 -> 256 for aimv2
+        if self.config.hidden_stride > 1:
+            # this `d` maybe different from the above `d``
+            n, L, d = features.shape
+            sqrt_l = int(L**0.5)
+            assert sqrt_l**2 == L, (
+                "The token sequence length should be a perfect square.")
+            features = features.reshape(n, sqrt_l, sqrt_l, d)
+            pl = (self.config.hidden_stride -
+                  (sqrt_l %
+                   self.config.hidden_stride)) % self.config.hidden_stride
+            features = pad(features, (0, 0, 0, pl, 0, pl), "constant", 0)
+            sqrt_l += pl
+            features = features.reshape(n, sqrt_l // self.config.hidden_stride,
+                                        self.config.hidden_stride,
+                                        sqrt_l // self.config.hidden_stride,
+                                        self.config.hidden_stride, d)
+            # [n, sqrt_l/hs, sqrt_l/hs, hs, hs, d]
+            features = features.permute(0, 1, 3, 2, 4, 5)
+            # [n, sqrt_l/hs, sqrt_l/hs, hs*hs*d]
+            features = features.flatten(3)
+            # [n, sqrt_l/hs*sqrt_l/hs, hs*hs*d]
+            features = features.reshape(
+                n, -1,
+                self.config.hidden_stride * self.config.hidden_stride * d)
+
+        return features
+
+    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
+        """[BatchSize, ImageShape] -> [BatchSize, Token, VocabSize]"""
+        features = self.encode(pixel_values)
+        logits = self.head(features)
+        tokens = self.tokenize(logits)
+        # tokens' shape is [BatchSize, #Token, VocabSize-5], so padding with
+        # [BatchSize, #Token, 5], after which, tokens' shape should become
+        # [BatchSize, #Token, VocabSize]
+        tokens = torch.nn.functional.pad(
+            tokens,
+            (0, len(IMAGE_INDICATOR_IDS)),
+            mode="constant",
+            value=0,
+        )
+        return tokens


-class Ovis2ImagePatchInputs(TypedDict):
+class OvisImagePatchInputs(TypedDict):
     type: Literal["image_patches"]
     flat_data: torch.Tensor
     """
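The new `st_argmax` simply materialises a one-hot vector at the argmax index; the straight-through gradient trick kept in the old aimv2.py version only matters for training, not inference. A tiny check of its behaviour:

```python
import torch

# Tiny check of the one-hot behaviour of st_argmax as defined above.
def st_argmax(y_soft: torch.Tensor, dim: int):
    index = y_soft.argmax(dim, keepdim=True)
    return torch.zeros_like(
        y_soft,
        memory_format=torch.legacy_contiguous_format,
    ).scatter_(dim, index, 1.0)

logits = torch.tensor([[0.1, 2.0, -1.0], [3.0, 0.0, 0.5]])
print(st_argmax(logits, dim=-1))
# tensor([[0., 1., 0.],
#         [1., 0., 0.]])
```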
@@ -92,31 +239,50 @@ class VisualEmbedding(torch.nn.Embedding):
         return self.weight.dtype


-class Ovis2ProcessingInfo(BaseProcessingInfo):
+class OvisProcessingInfo(BaseProcessingInfo):

     def get_hf_config(self):
         return self.ctx.get_hf_config(OvisConfig)

     def get_hf_processor(self, **kwargs):
-        return self.ctx.get_hf_processor(OvisProcessor)
+        return self.ctx.get_hf_processor(
+            OvisProcessor,
+            image_pad_token=self.get_image_pad_token(),
+            image_segment_len=self.get_image_segment_len(),
+        )
+
+    def get_image_segment_len(self) -> int:
+        visual_tokenizer_config = self.get_hf_config().visual_tokenizer_config
+        image_size = visual_tokenizer_config.backbone_config.image_size
+        patch_size = visual_tokenizer_config.backbone_config.patch_size
+        hidden_stride = visual_tokenizer_config.hidden_stride
+        patch_grid_length = math.ceil(image_size / patch_size)
+        assert patch_grid_length % hidden_stride == 0, (
+            f"patch_grid_length {patch_grid_length} is not divisible by "
+            f"hidden_stride {hidden_stride}")
+        # minus 1 for presented image token
+        return (patch_grid_length // hidden_stride)**2 - 1
+
+    def get_image_pad_token(self) -> str:
+        hf_text_config = self.get_hf_config().get_text_config()
+        text_model_type = hf_text_config.model_type
+        return IMAGE_PAD_TOKEN_MAP.get(text_model_type)

-    def get_image_processor(self) -> OvisProcessor:
+    def get_image_processor(self) -> BaseImageProcessor:
         return self.get_hf_processor().image_processor  # type: ignore

     def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
-        return {  # 32k is model token limit at the moment
-            "image":
-            self.get_hf_config().multimodal_max_length //
-            ((9 + 1) * NUMBER_OF_TOKEN_TO_RESERVE_FOR_SEGMENT)
-        }
+        return {"image": None}

     def get_image_size_with_most_features(self) -> ImageSize:
-        image_processor = self.get_image_processor()
-        return ImageSize(width=image_processor.size['shortest_edge'] * 9 * 2,
-                         height=image_processor.size['shortest_edge'] * 9 * 2)
+        height, width = self.get_hf_processor().get_image_size()
+        hs = self.get_hf_config().visual_tokenizer_config.hidden_stride
+        # NOTE(Isotr0py): 9 is `max_partion` hardcoded in original code
+        # https://huggingface.co/AIDC-AI/Ovis2-1B/blob/main/modeling_ovis.py#L96
+        return ImageSize(width=width * hs * 9, height=height * hs * 9)


-class Ovis2DummyInputsBuilder(BaseDummyInputsBuilder[Ovis2ProcessingInfo]):
+class OvisDummyInputsBuilder(BaseDummyInputsBuilder[OvisProcessingInfo]):

     def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
         num_images = mm_counts.get("image", 0)
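`get_image_segment_len` is plain grid arithmetic over the visual backbone's geometry. A worked example with assumed numbers (448px images, 14px patches, hidden_stride 2; these are illustrative, not read from any config):

```python
import math

# Worked example of get_image_segment_len with assumed, not config-derived,
# backbone geometry.
image_size, patch_size, hidden_stride = 448, 14, 2

patch_grid_length = math.ceil(image_size / patch_size)   # 32 patches per side
assert patch_grid_length % hidden_stride == 0
segment_len = (patch_grid_length // hidden_stride) ** 2 - 1
print(segment_len)  # 255, matching the processor's default image_segment_len
```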
@@ -141,7 +307,7 @@ class Ovis2DummyInputsBuilder(BaseDummyInputsBuilder[Ovis2ProcessingInfo]):
         return mm_data


-class Ovis2MultiModalProcessor(BaseMultiModalProcessor[Ovis2ProcessingInfo]):
+class OvisMultiModalProcessor(BaseMultiModalProcessor[OvisProcessingInfo]):

     def image_indicators_to_visual_tokens(
         self,
@@ -165,9 +331,9 @@ class Ovis2MultiModalProcessor(BaseMultiModalProcessor[Ovis2ProcessingInfo]):
         mm_kwargs: Mapping[str, object],
     ) -> BatchFeature:
         if not mm_data:
-            # # Avoid warning from HF logger for text-only input
-            prompt_ids = self.info.get_tokenizer().encode(prompt)
-            # prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids) nope
+            # Avoid warning from HF logger for text-only input
+            tokenizer = self.info.get_tokenizer()
+            prompt_ids = tokenizer.encode(prompt, add_special_tokens=False)
             return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")

         processed_outputs = super()._call_hf_processor(
@@ -226,10 +392,10 @@ class Ovis2MultiModalProcessor(BaseMultiModalProcessor[Ovis2ProcessingInfo]):
     ]


-@MULTIMODAL_REGISTRY.register_processor(Ovis2MultiModalProcessor,
-                                        info=Ovis2ProcessingInfo,
-                                        dummy_inputs=Ovis2DummyInputsBuilder)
-class Ovis2ForConditionalGeneration(nn.Module, SupportsMultiModal):
+@MULTIMODAL_REGISTRY.register_processor(OvisMultiModalProcessor,
+                                        info=OvisProcessingInfo,
+                                        dummy_inputs=OvisDummyInputsBuilder)
+class Ovis(nn.Module, SupportsMultiModal):

     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
@@ -242,24 +408,25 @@ class Ovis2ForConditionalGeneration(nn.Module, SupportsMultiModal):
             prefix=maybe_prefix(prefix, "llm"),
         )

-        self.visual_tokenizer = Aimv2VisualTokenizer(
+        self.visual_tokenizer = VisualTokenizer(
             config=config.visual_tokenizer_config,
             quant_config=quant_config,
             prefix=f"{prefix}.visual_tokenizer",
-            image_processor_name_or_path=config.visual_tokenizer_config.
-            backbone_config.name_or_path,
         )

         self.vte = VisualEmbedding(
             self.config.visual_tokenizer_config.vocab_size,
             self.config.hidden_size)

+        text_model_type = self.config.get_text_config().model_type
+        self.image_pad_token_id = IMAGE_PAD_TOKEN_ID_MAP[text_model_type]
+
         # TODO(Isotr0py): PP support
         # self.make_empty_intermediate_tensors = (
         #     self.language_model.make_empty_intermediate_tensors)

     def _parse_and_validate_image_input(
-            self, **kwargs: object) -> Optional[Ovis2ImagePatchInputs]:
+            self, **kwargs: object) -> Optional[OvisImagePatchInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
         indicator_tokens = kwargs.pop("indicator_tokens", None)

@@ -275,7 +442,7 @@ class Ovis2ForConditionalGeneration(nn.Module, SupportsMultiModal):
             raise ValueError("Incorrect type of indicator_tokens. "
                              f"Got type: {type(pixel_values)}")

-        return Ovis2ImagePatchInputs(
+        return OvisImagePatchInputs(
             type="image_patches",
             flat_data=flatten_bn(flatten_bn(pixel_values), concat=True),
             patches_per_image=[
@@ -288,7 +455,7 @@ class Ovis2ForConditionalGeneration(nn.Module, SupportsMultiModal):
         raise AssertionError("This line should be unreachable.")

     def _process_image_input(
-            self, image_input: Ovis2ImagePatchInputs) -> MultiModalEmbeddings:
+            self, image_input: OvisImagePatchInputs) -> MultiModalEmbeddings:
         image_patches_flat = image_input["flat_data"]
         patches_per_image = image_input["patches_per_image"]
         indicator_tokens = image_input["indicator_tokens"]
@@ -338,7 +505,7 @@ class Ovis2ForConditionalGeneration(nn.Module, SupportsMultiModal):
         if multimodal_embeddings is not None:
             inputs_embeds = merge_multimodal_embeddings(
                 input_ids, inputs_embeds, multimodal_embeddings,
-                [IMAGE_PAD_TOKEN_ID])
+                self.image_pad_token_id)
         return inputs_embeds

     def forward(
@@ -375,8 +542,7 @@ class Ovis2ForConditionalGeneration(nn.Module, SupportsMultiModal):
         hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> Optional[torch.Tensor]:
-        logits = self.llm.logits_processor(self.llm.lm_head, hidden_states,
-                                           sampling_metadata)
+        logits = self.llm.compute_logits(hidden_states, sampling_metadata)
         return logits

     def load_weights(self, weights: Iterable[Tuple[str,
@@ -195,7 +195,7 @@ _MULTIMODAL_MODELS = {
     "Mistral3ForConditionalGeneration": ("mistral3", "Mistral3ForConditionalGeneration"),  # noqa: E501
     "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
     "NVLM_D": ("nvlm_d", "NVLM_D_Model"),
-    "Ovis2ForConditionalGeneration": ("ovis2", "Ovis2ForConditionalGeneration"),
+    "Ovis": ("ovis", "Ovis"),
     "PaliGemmaForConditionalGeneration": ("paligemma", "PaliGemmaForConditionalGeneration"),  # noqa: E501
     "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
     "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"),  # noqa: E501
@@ -23,7 +23,7 @@ from vllm.transformers_utils.configs.moonvit import MoonViTConfig
 from vllm.transformers_utils.configs.mpt import MPTConfig
 from vllm.transformers_utils.configs.nemotron import NemotronConfig
 from vllm.transformers_utils.configs.nvlm_d import NVLM_D_Config
-from vllm.transformers_utils.configs.ovis2 import OvisConfig
+from vllm.transformers_utils.configs.ovis import OvisConfig
 from vllm.transformers_utils.configs.skyworkr1v import SkyworkR1VChatConfig
 from vllm.transformers_utils.configs.solar import SolarConfig
 from vllm.transformers_utils.configs.telechat2 import Telechat2Config
@@ -123,6 +123,19 @@ class Aimv2VisualTokenizerConfig(BaseVisualTokenizerConfig):
             self.backbone_kwargs['num_hidden_layers'] = self.depths[0]


+class SiglipVisualTokenizerConfig(BaseVisualTokenizerConfig):
+    model_type = "siglip_visual_tokenizer"
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        if self.drop_cls_token:
+            self.drop_cls_token = False
+        if self.depths:
+            assert len(self.depths) == 1
+            self.backbone_kwargs['num_hidden_layers'] = self.depths[0]
+
+
+AutoConfig.register("siglip_visual_tokenizer", SiglipVisualTokenizerConfig)
 AutoConfig.register("aimv2_visual_tokenizer", Aimv2VisualTokenizerConfig)

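Registering the visual-tokenizer configs with transformers' `AutoConfig` is what lets the checkpoints' `model_type` strings resolve to the right config class when vLLM parses them. A generic sketch of that mechanism (the "demo" names below are invented, not vLLM or Ovis classes):

```python
# Generic illustration of the AutoConfig registration pattern used above.
from transformers import AutoConfig, PretrainedConfig


class DemoVisualTokenizerConfig(PretrainedConfig):
    model_type = "demo_visual_tokenizer"

    def __init__(self, hidden_stride: int = 2, **kwargs):
        super().__init__(**kwargs)
        self.hidden_stride = hidden_stride


AutoConfig.register("demo_visual_tokenizer", DemoVisualTokenizerConfig)

# After registration, any config whose "model_type" is "demo_visual_tokenizer"
# deserializes into DemoVisualTokenizerConfig.
cfg = AutoConfig.for_model("demo_visual_tokenizer", hidden_stride=2)
print(type(cfg).__name__, cfg.hidden_stride)  # DemoVisualTokenizerConfig 2
```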
@@ -2,6 +2,6 @@

 from vllm.transformers_utils.processors.deepseek_vl2 import (
     DeepseekVLV2Processor)
-from vllm.transformers_utils.processors.ovis2 import OvisProcessor
+from vllm.transformers_utils.processors.ovis import OvisProcessor

 __all__ = ["DeepseekVLV2Processor", "OvisProcessor"]
@@ -22,6 +22,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from functools import cached_property
 from typing import List, Union

 import PIL
@@ -64,18 +65,29 @@ class OvisProcessor(ProcessorMixin):
     """

     attributes = ["image_processor", "tokenizer"]
-    valid_kwargs = ["chat_template"]
+    valid_kwargs = ["chat_template", "image_pad_token", "image_segement_len"]

     image_processor_class = "AutoImageProcessor"
-    tokenizer_class = "Qwen2Tokenizer"
+    tokenizer_class = "AutoTokenizer"

-    def __init__(self, image_processor=None, tokenizer=None, chat_template=None, image_pad_token=None, **kwargs):
+    def __init__(
+        self,
+        image_processor=None,
+        tokenizer=None,
+        chat_template=None,
+        image_pad_token=None,
+        image_segment_len=255,
+        **kwargs,
+    ):
         self.image_token = "<image>"
-        self.image_pad_token = "<|image_pad|>" if image_pad_token is None else image_pad_token
+        self.image_pad_token = image_pad_token
+        self.image_segment_len = image_segment_len
         super().__init__(image_processor, tokenizer, chat_template=chat_template)

-        self.image_pad_token_id = self.tokenizer.get_vocab()[self.image_pad_token]
-        self.extra_special_tokens = {
+    @cached_property
+    def extra_special_tokens(self):
+        image_pad_token_id = self.tokenizer.get_vocab()[self.image_pad_token]
+        extra_special_tokens = {
             "image_token": -200,
             "image_atom": -300,
             "image_start": -301,
@@ -83,8 +95,9 @@ class OvisProcessor(ProcessorMixin):
             "image_col_sep": -303,
             "image_row_sep": -304,
             "image_end": -305,
-            'image_pad': self.image_pad_token_id,
+            'image_pad': image_pad_token_id,
         }
+        return extra_special_tokens

     def __call__(
         self,
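Because the pad token now arrives as a constructor argument rather than being hard-coded to Qwen2's `<|image_pad|>`, one processor class serves all three text backbones. A sketch of the per-backbone selection it relies on (the map is copied from ovis.py above; the model names are examples):

```python
# Sketch of the per-backbone pad-token selection feeding OvisProcessor.
IMAGE_PAD_TOKEN_MAP = {
    "gemma2": "<unused0>",
    "llama": "<|reserved_special_token_0|>",
    "qwen2": "<|image_pad|>",
}


def pad_token_for(text_model_type: str) -> str:
    # e.g. "qwen2" for AIDC-AI/Ovis2-1B, "llama" for Ovis1.6-Llama3.2-3B,
    # "gemma2" for Ovis1.6-Gemma2-9B
    return IMAGE_PAD_TOKEN_MAP[text_model_type]


print(pad_token_for("llama"))  # <|reserved_special_token_0|>
```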
@@ -224,8 +237,14 @@ class OvisProcessor(ProcessorMixin):
         return torch.tensor(batch_token_ids, dtype=torch.long)

     def get_image_size(self):
-        height = self.image_processor.crop_size["height"]
-        width = self.image_processor.crop_size["width"]
+        size = self.image_processor.size
+        if 'shortest_edge' in size:
+            width = height = size['shortest_edge']
+        elif "height" in size and "width" in size:
+            width = size['width']
+            height = size['height']
+        else:
+            raise ValueError("Can't parse image size from image_processor config.")
         return height, width

     def get_token_value(self, tok):
@@ -259,8 +278,7 @@ class OvisProcessor(ProcessorMixin):
         for token in image_placeholders:
             padded_placeholder_tokens.append(image_padding_token_id)
             if token == image_atom_token_id:
-                # Add 255 padding tokens after each image atom token
-                padded_placeholder_tokens.extend([image_padding_token_id] * 255)
+                padded_placeholder_tokens.extend([image_padding_token_id] * self.image_segment_len)
         return padded_placeholder_tokens

     def preprocess_image(self, image: PIL.Image.Image, max_partition, covering_threshold, convert_to_rgb, return_tensors):
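Placeholder expansion is the step that now depends on the backbone-specific segment length instead of a fixed 255. A toy run with invented token ids shows how each image atom is followed by `image_segment_len` pad tokens:

```python
# Toy illustration of the placeholder padding above; token ids are invented
# and image_segment_len is shrunk to 3 to keep the output readable.
IMAGE_ATOM = -300
IMAGE_START, IMAGE_END = -301, -305
PAD = 151655            # e.g. the Qwen2 <|image_pad|> id
image_segment_len = 3   # 255 for the assumed 448px/14px/stride-2 backbone

image_placeholders = [IMAGE_START, IMAGE_ATOM, IMAGE_END]
padded = []
for token in image_placeholders:
    padded.append(PAD)
    if token == IMAGE_ATOM:
        padded.extend([PAD] * image_segment_len)

print(len(padded))  # 6: one pad per placeholder plus 3 after the atom
```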