diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md index 8c6e7b04de85..48fc24f3447a 100644 --- a/docs/source/models/supported_models.md +++ b/docs/source/models/supported_models.md @@ -1045,10 +1045,10 @@ Specified using `--task generate`. * * ✅︎ * ✅︎ -- * `Ovis2ForConditionalGeneration`^ - * Ovis2 +- * `Ovis` + * Ovis2, Ovis1.6 * T + I+ - * `AIDC-AI/Ovis2-1B`, `AIDC-AI/Ovis2-2B`, etc. + * `AIDC-AI/Ovis2-1B`, `AIDC-AI/Ovis1.6-Llama3.2-3B`, etc. * * * ✅︎ diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py index 5c173ab1abb9..c54f328c7a38 100644 --- a/examples/offline_inference/vision_language.py +++ b/examples/offline_inference/vision_language.py @@ -725,8 +725,8 @@ def run_nvlm_d(questions: list[str], modality: str) -> ModelRequestData: ) -# Ovis2 -def run_ovis2(questions: list[str], modality: str) -> ModelRequestData: +# Ovis +def run_ovis(questions: list[str], modality: str) -> ModelRequestData: assert modality == "image" model_name = "AIDC-AI/Ovis2-1B" @@ -737,15 +737,18 @@ def run_ovis2(questions: list[str], modality: str) -> ModelRequestData: max_num_seqs=2, trust_remote_code=True, dtype="half", - hf_overrides={"architectures": ["Ovis2ForConditionalGeneration"]}, limit_mm_per_prompt={modality: 1}, ) - placeholder = "\n" - prompts = [("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n" - f"<|im_start|>user\n{placeholder}" - f"{question}<|im_end|>\n" - "<|im_start|>assistant\n") for question in questions] + tokenizer = AutoTokenizer.from_pretrained(model_name, + trust_remote_code=True) + messages = [[{ + 'role': 'user', + 'content': f"\n{question}" + }] for question in questions] + prompts = tokenizer.apply_chat_template(messages, + tokenize=False, + add_generation_prompt=True) return ModelRequestData( engine_args=engine_args, @@ -1069,7 +1072,7 @@ model_example_map = { "llama4": run_llama4, "molmo": run_molmo, "NVLM_D": run_nvlm_d, - "ovis2": run_ovis2, + "ovis": run_ovis, "paligemma": run_paligemma, "paligemma2": run_paligemma2, "phi3_v": run_phi3v, diff --git a/examples/offline_inference/vision_language_multi_image.py b/examples/offline_inference/vision_language_multi_image.py index 48d590b05b06..20a8e635e322 100644 --- a/examples/offline_inference/vision_language_multi_image.py +++ b/examples/offline_inference/vision_language_multi_image.py @@ -436,8 +436,8 @@ def load_nvlm_d(question: str, image_urls: list[str]) -> ModelRequestData: ) -# Ovis2 -def load_ovis2(question: str, image_urls: list[str]) -> ModelRequestData: +# Ovis +def load_ovis(question: str, image_urls: list[str]) -> ModelRequestData: model_name = "AIDC-AI/Ovis2-1B" engine_args = EngineArgs( @@ -447,15 +447,17 @@ def load_ovis2(question: str, image_urls: list[str]) -> ModelRequestData: trust_remote_code=True, dtype="half", limit_mm_per_prompt={"image": len(image_urls)}, - hf_overrides={"architectures": ["Ovis2ForConditionalGeneration"]}, ) - placeholder = '\n'.join( - [f'Image {i+1}: ' for i in range(len(image_urls))]) + '\n' - prompt = ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n" - f"<|im_start|>user\n{placeholder}" - f"{question}<|im_end|>\n" - "<|im_start|>assistant\n") + placeholders = "\n".join(f"Image-{i}: \n" + for i, _ in enumerate(image_urls, start=1)) + messages = [{'role': 'user', 'content': f"{placeholders}\n{question}"}] + + tokenizer = AutoTokenizer.from_pretrained(model_name, + trust_remote_code=True) + prompt = tokenizer.apply_chat_template(messages, + tokenize=False, + 
add_generation_prompt=True) return ModelRequestData( engine_args=engine_args, @@ -713,7 +715,7 @@ model_example_map = { "mistral3": load_mistral3, "mllama": load_mllama, "NVLM_D": load_nvlm_d, - "ovis2": load_ovis2, + "ovis": load_ovis, "phi3_v": load_phi3v, "phi4_mm": load_phi4mm, "pixtral_hf": load_pixtral_hf, diff --git a/tests/conftest.py b/tests/conftest.py index fa979f1093be..c5700179c228 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -355,10 +355,16 @@ class HfRunner: **model_kwargs, ) + # in case some unquantized custom models are not in same dtype + if (getattr(model, "quantization_method", None) is None + and any(p.dtype != self.dtype + for p in model.parameters())): + model = model.to(dtype=self.dtype) + if (getattr(model, "quantization_method", None) != "bitsandbytes" and len({p.device for p in model.parameters()}) < 2): - model = model.to(self.device) + model = model.to(device=self.device) self.model = model diff --git a/tests/models/multimodal/generation/test_common.py b/tests/models/multimodal/generation/test_common.py index 6e915a9f6005..dead2edc4fa3 100644 --- a/tests/models/multimodal/generation/test_common.py +++ b/tests/models/multimodal/generation/test_common.py @@ -476,6 +476,31 @@ VLM_TEST_SETTINGS = { max_num_seqs=2, patch_hf_runner=model_utils.molmo_patch_hf_runner, ), + "ovis1_6-gemma2": VLMTestInfo( + models=["AIDC-AI/Ovis1.6-Gemma2-9B"], + test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), + prompt_formatter=lambda img_prompt: f"user\n{img_prompt}\nmodel\n", # noqa: E501 + img_idx_to_prompt=lambda idx: "\n", # noqa: E501 + max_model_len=4096, + max_num_seqs=2, + dtype="half", + # use sdpa mode for hf runner since ovis2 didn't work with flash_attn + hf_model_kwargs={"llm_attn_implementation": "sdpa"}, + patch_hf_runner=model_utils.ovis_patch_hf_runner, + marks=[large_gpu_mark(min_gb=32)], + ), + "ovis1_6": VLMTestInfo( + models=["AIDC-AI/Ovis1.6-Llama3.2-3B"], + test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), + prompt_formatter=lambda img_prompt: f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nYou are a helpful and honest multimodal assistant.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{img_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", # noqa: E501 + img_idx_to_prompt=lambda idx: "\n", # noqa: E501 + max_model_len=4096, + max_num_seqs=2, + dtype="half", + # use sdpa mode for hf runner since ovis2 didn't work with flash_attn + hf_model_kwargs={"llm_attn_implementation": "sdpa"}, + patch_hf_runner=model_utils.ovis_patch_hf_runner, + ), "ovis2": VLMTestInfo( models=["AIDC-AI/Ovis2-1B"], test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), @@ -486,7 +511,7 @@ VLM_TEST_SETTINGS = { dtype="half", # use sdpa mode for hf runner since ovis2 didn't work with flash_attn hf_model_kwargs={"llm_attn_implementation": "sdpa"}, - patch_hf_runner=model_utils.ovis2_patch_hf_runner, + patch_hf_runner=model_utils.ovis_patch_hf_runner, ), "phi3v": VLMTestInfo( models=["microsoft/Phi-3.5-vision-instruct"], diff --git a/tests/models/multimodal/generation/vlm_utils/model_utils.py b/tests/models/multimodal/generation/vlm_utils/model_utils.py index f0f4ed989241..e31408d6063f 100644 --- a/tests/models/multimodal/generation/vlm_utils/model_utils.py +++ b/tests/models/multimodal/generation/vlm_utils/model_utils.py @@ -678,12 +678,8 @@ def molmo_patch_hf_runner(hf_model: HfRunner) -> HfRunner: return hf_model -def ovis2_patch_hf_runner(hf_model: HfRunner) -> HfRunner: +def ovis_patch_hf_runner(hf_model: HfRunner) -> 
HfRunner: """Patches and returns an instance of the HfRunner to use for Ovis2.""" - hf_model.model.visual_tokenizer.to(hf_model.dtype) - hf_model.model.vte.to(hf_model.dtype) - hf_model.model.llm.to(hf_model.dtype) - hf_model.model.get_output_embeddings = lambda: \ hf_model.model.llm.get_output_embeddings() @@ -691,7 +687,16 @@ def ovis2_patch_hf_runner(hf_model: HfRunner) -> HfRunner: text_tokenizer = hf_model.model.get_text_tokenizer() images = [images] if isinstance(images, Image) else images - text = text.split("<|im_start|>user\n")[1].split("<|im_end|>\n")[0] + prompt_start_and_end = { + "qwen2": ("<|im_start|>user\n", "<|im_end|>\n"), + "llama": + ("<|start_header_id|>user<|end_header_id|>\n\n", "<|eot_id|>"), + "gemma2": ("user\n", "\n"), + } + for start, end in prompt_start_and_end.values(): + if start in text and end in text: + text = text.split(start)[1].split(end)[0] + break prompt, input_ids, pixel_values = hf_model.model.preprocess_inputs( text_or_conversations=text, images=images) diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py index 772a2db3e48a..e6b70a4438e9 100644 --- a/tests/models/multimodal/processing/test_common.py +++ b/tests/models/multimodal/processing/test_common.py @@ -146,7 +146,8 @@ def _test_processing_correctness_hf( batch_idx: int, ignore_mm_keys: Optional[set[str]] = None, ): - if model_config.hf_config.model_type in ("mllama", "whisper", "ultravox"): + if model_config.hf_config.model_type in ("mllama", "ovis", "ultravox", + "whisper"): # For some multimodal models, tokenizer will always add bos_token # at the beginning of prompt by default, causing hf_processor outputs # incorrect token ids. So we need use `add_special_tokens=False` here @@ -274,6 +275,8 @@ def _test_processing_correctness_mistral( "allenai/Molmo-7B-D-0924", "allenai/Molmo-7B-O-0924", "nvidia/NVLM-D-72B", + "AIDC-AI/Ovis1.6-Gemma2-9B", + "AIDC-AI/Ovis1.6-Llama3.2-3B", "AIDC-AI/Ovis2-1B", "google/paligemma-3b-mix-224", "google/paligemma2-3b-ft-docci-448", diff --git a/tests/models/registry.py b/tests/models/registry.py index a1f2edac02b9..683d15d508ec 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -355,9 +355,9 @@ _MULTIMODAL_EXAMPLE_MODELS = { max_transformers_version="4.48", transformers_version_reason="Use of deprecated imports which have been removed.", # noqa: E501 extras={"phi3.5": "microsoft/Phi-3.5-vision-instruct"}), # noqa: E501 - "Ovis2ForConditionalGeneration": _HfExamplesInfo("AIDC-AI/Ovis2-1B", - trust_remote_code=True, - hf_overrides={"architectures": ["Ovis2ForConditionalGeneration"]}), # noqa: E501 + "Ovis": _HfExamplesInfo("AIDC-AI/Ovis2-1B", trust_remote_code=True, + extras={"1.6-llama": "AIDC-AI/Ovis1.6-Llama3.2-3B", + "1.6-gemma": "AIDC-AI/Ovis1.6-Gemma2-9B"}), # noqa: E501 "Phi4MMForCausalLM": _HfExamplesInfo("microsoft/Phi-4-multimodal-instruct", trust_remote_code=True), "PixtralForConditionalGeneration": _HfExamplesInfo("mistralai/Pixtral-12B-2409", # noqa: E501 diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index 38fe98572178..db43b2dd295d 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -512,7 +512,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): hf_config.image_token_index) if model_type in ("aya_vision", "chameleon", "deepseek_vl_v2", - "internvl_chat", "ovis2", "skywork_chat", + "internvl_chat", "ovis", "skywork_chat", "NVLM_D", "h2ovl_chat", "idefics3", "smolvlm"): return "" if model_type in 
("mllama", "llama4"): diff --git a/vllm/model_executor/models/aimv2.py b/vllm/model_executor/models/aimv2.py index 730e770dc3d6..aefd6c973755 100644 --- a/vllm/model_executor/models/aimv2.py +++ b/vllm/model_executor/models/aimv2.py @@ -5,129 +5,14 @@ from typing import Optional import torch -from torch import nn, softmax +import torch.nn as nn from torch.nn import functional as F -from torch.nn.functional import gumbel_softmax, pad from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ReplicatedLinear from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) -from vllm.transformers_utils.configs.ovis2 import (AIMv2Config, - Aimv2VisualTokenizerConfig) - -IMAGE_INDICATOR_IDS = [-301, -302, -303, -304, - -305] # kept for vocab prefixed tokens - - -def st_argmax(y_soft: torch.Tensor, dim: int): # straight-through softmax - index = y_soft.max(dim, keepdim=True)[1] - y_hard = torch.zeros_like( - y_soft, memory_format=torch.legacy_contiguous_format).scatter_( - dim, index, 1.0) - ret = y_hard - y_soft.detach() + y_soft - return ret - - -class Aimv2VisualTokenizer(torch.nn.Module): - - def __init__(self, - config: Aimv2VisualTokenizerConfig, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - **kwargs): - super().__init__() - self.config = config - self.backbone = AIMv2Model( - config=config.backbone_config, # noqa - quant_config=quant_config, - prefix=f"{prefix}.visual_tokenizer") - # reserved tokens for IMAGE_INDICATORS - head_dim = config.vocab_size - len(IMAGE_INDICATOR_IDS) - self.head = torch.nn.Sequential( - ReplicatedLinear( - config.backbone_config.hidden_size * config.hidden_stride * - config.hidden_stride, - head_dim, - bias=False, - ), torch.nn.LayerNorm(head_dim)) - - @property - def dtype(self): - return self.backbone.dtype - - @property - def device(self): - return self.backbone.device - - def tokenize(self, logits): - if self.config.tokenize_function == 'softmax': - tokens = softmax(logits, dim=-1) - elif self.config.tokenize_function == 'gumbel_argmax': - tokens = gumbel_softmax(logits, tau=self.config.tau, hard=True) - elif self.config.tokenize_function == 'st_argmax': - tokens = st_argmax(logits, dim=-1) - else: - raise ValueError( - 'Invalid `max_type`, expected softmax or gumbel_argmax ' - f'or st_argmax, but got {self.config.tokenize_function}') - return tokens - - def encode(self, pixel_values): - features = self.backbone(pixel_values) - if self.config.drop_cls_token: - features = features[:, 1:, :] - - # merge number of `hidden_stride * hidden_stride` hidden states together - # to reduce token sequence length - # e.g., for hidden_stride=2, this leads to a token length reduction: - # 1024 -> 256 for aimv2 - if self.config.hidden_stride > 1: - # this `d` maybe different from the above `d`` - n, L, d = features.shape - sqrt_l = int(L**0.5) - assert sqrt_l**2 == L, ( - "The token sequence length should be a perfect square.") - features = features.reshape(n, sqrt_l, sqrt_l, d) - pl = (self.config.hidden_stride - - (sqrt_l % - self.config.hidden_stride)) % self.config.hidden_stride - features = pad(features, (0, 0, 0, pl, 0, pl), "constant", 0) - sqrt_l += pl - features = features.reshape(n, sqrt_l // self.config.hidden_stride, - self.config.hidden_stride, - sqrt_l // self.config.hidden_stride, - self.config.hidden_stride, d) - # [n, sqrt_l/hs, sqrt_l/hs, hs, hs, d] - features = features.permute(0, 1, 3, 2, 4, 5) - # [n, sqrt_l/hs, sqrt_l/hs, hs*hs*d] - features = 
features.flatten(3) - # [n, sqrt_l/hs*sqrt_l/hs, hs*hs*d] - features = features.reshape( - n, -1, - self.config.hidden_stride * self.config.hidden_stride * d) - - return features - - def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: - """[BatchSize, ImageShape] -> [BatchSize, Token, VocabSize]""" - features = self.encode(pixel_values) - logits, _ = self.head[0]( - features) # we spllit the sequncial here for not throwing an error - logits = self.head[1](logits) - tokens = self.tokenize(logits) - # tokens' shape is [BatchSize, #Token, VocabSize-5], so padding with - # [BatchSize, #Token, 5], after which, tokens' shape should become - # [BatchSize, #Token, VocabSize] - batch_size, token_len, _ = tokens.shape - padding_tensor = torch.zeros(size=(batch_size, token_len, - len(IMAGE_INDICATOR_IDS)), - dtype=tokens.dtype, - device=tokens.device, - layout=tokens.layout, - requires_grad=False) - tokens = torch.cat((tokens, padding_tensor), dim=2) - return tokens +from vllm.transformers_utils.configs.ovis import AIMv2Config class AIMv2SwiGLUFFN(nn.Module): @@ -302,14 +187,6 @@ class AIMv2Model(torch.nn.Module): quant_config=quant_config, prefix=f"{prefix}.trunk") - @property - def dtype(self): - return self.trunk.blocks[0].attn.qkv.weight.dtype - - @property - def device(self): - return self.trunk.blocks[0].attn.qkv.device - def forward( self, pixel_values: torch.Tensor, diff --git a/vllm/model_executor/models/ovis2.py b/vllm/model_executor/models/ovis.py similarity index 59% rename from vllm/model_executor/models/ovis2.py rename to vllm/model_executor/models/ovis.py index 67cc86e7fc82..5204c751216f 100644 --- a/vllm/model_executor/models/ovis2.py +++ b/vllm/model_executor/models/ovis.py @@ -15,17 +15,23 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" PyTorch Ovis2 model.""" +""" PyTorch Ovis model.""" +import math from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple, TypedDict, Union) import torch import torch.nn as nn from torch import Tensor -from transformers import BatchFeature +from torch.nn.functional import gumbel_softmax, pad, softmax +from transformers import BaseImageProcessor, BatchFeature from vllm.config import VllmConfig -from vllm.model_executor.models.aimv2 import Aimv2VisualTokenizer +from vllm.model_executor.layers.linear import ReplicatedLinear +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.models.aimv2 import AIMv2Model +from vllm.model_executor.models.siglip import SiglipVisionModel from vllm.model_executor.models.utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, maybe_prefix) @@ -38,19 +44,160 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, PromptReplacement) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors -from vllm.transformers_utils.configs.ovis2 import OvisConfig -from vllm.transformers_utils.processors.ovis2 import OvisProcessor +from vllm.transformers_utils.configs.ovis import (BaseVisualTokenizerConfig, + OvisConfig) +from vllm.transformers_utils.processors.ovis import OvisProcessor from .interfaces import MultiModalEmbeddings, SupportsMultiModal from .utils import merge_multimodal_embeddings # Cannot find the following number from hf config. 
IMAGE_TOKEN = "" -IMAGE_PAD_TOKEN_ID = 151655 -NUMBER_OF_TOKEN_TO_RESERVE_FOR_SEGMENT = 256 +IMAGE_INDICATOR_IDS = [-301, -302, -303, -304, -305] + +IMAGE_PAD_TOKEN_MAP = { + "gemma2": "", + "llama": "<|reserved_special_token_0|>", + "qwen2": "<|image_pad|>", +} +IMAGE_PAD_TOKEN_ID_MAP = { + "gemma2": 7, + "llama": 128002, + "qwen2": 151655, +} -class Ovis2ImagePatchInputs(TypedDict): +def st_argmax(y_soft: torch.Tensor, dim: int): # straight-through softmax + index = y_soft.argmax(dim, keepdim=True) + return torch.zeros_like( + y_soft, + memory_format=torch.legacy_contiguous_format, + ).scatter_(dim, index, 1.0) + + +class VisualTokenizer(torch.nn.Module): + + def __init__( + self, + config: BaseVisualTokenizerConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.config = config + self.backbone = self._init_backbone( + config=config, + quant_config=quant_config, + prefix=f"{prefix}.backbone", + ) + # reserved tokens for IMAGE_INDICATORS + head_dim = config.vocab_size - len(IMAGE_INDICATOR_IDS) + self.head = torch.nn.Sequential( + ReplicatedLinear( + config.backbone_config.hidden_size * config.hidden_stride * + config.hidden_stride, + head_dim, + bias=False, + return_bias=False, + ), torch.nn.LayerNorm(head_dim)) + + def _init_backbone( + self, + config: BaseVisualTokenizerConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + model_type = config.backbone_config.model_type + if model_type == "aimv2": + return AIMv2Model( + config=config.backbone_config, + quant_config=quant_config, + prefix=prefix, + ) + elif model_type == "siglip_vision_model": + return SiglipVisionModel( + config=config.backbone_config, + quant_config=quant_config, + prefix=prefix, + ) + raise ValueError( + f"Unsupported visual tokenizer model_type: {model_type}") + + @property + def dtype(self): + return next(self.head.parameters()).dtype + + @property + def device(self): + return next(self.head.parameters()).device + + def tokenize(self, logits): + if self.config.tokenize_function == 'softmax': + tokens = softmax(logits, dim=-1) + elif self.config.tokenize_function == 'gumbel_argmax': + tokens = gumbel_softmax(logits, tau=self.config.tau, hard=True) + elif self.config.tokenize_function == 'st_argmax': + tokens = st_argmax(logits, dim=-1) + else: + raise ValueError( + 'Invalid `max_type`, expected softmax or gumbel_argmax ' + f'or st_argmax, but got {self.config.tokenize_function}') + return tokens + + def encode(self, pixel_values): + features = self.backbone(pixel_values) + if self.config.drop_cls_token: + features = features[:, 1:, :] + + # merge number of `hidden_stride * hidden_stride` hidden states together + # to reduce token sequence length + # e.g., for hidden_stride=2, this leads to a token length reduction: + # 1024 -> 256 for aimv2 + if self.config.hidden_stride > 1: + # this `d` maybe different from the above `d`` + n, L, d = features.shape + sqrt_l = int(L**0.5) + assert sqrt_l**2 == L, ( + "The token sequence length should be a perfect square.") + features = features.reshape(n, sqrt_l, sqrt_l, d) + pl = (self.config.hidden_stride - + (sqrt_l % + self.config.hidden_stride)) % self.config.hidden_stride + features = pad(features, (0, 0, 0, pl, 0, pl), "constant", 0) + sqrt_l += pl + features = features.reshape(n, sqrt_l // self.config.hidden_stride, + self.config.hidden_stride, + sqrt_l // self.config.hidden_stride, + self.config.hidden_stride, d) + # [n, sqrt_l/hs, sqrt_l/hs, hs, hs, d] + features = 
features.permute(0, 1, 3, 2, 4, 5) + # [n, sqrt_l/hs, sqrt_l/hs, hs*hs*d] + features = features.flatten(3) + # [n, sqrt_l/hs*sqrt_l/hs, hs*hs*d] + features = features.reshape( + n, -1, + self.config.hidden_stride * self.config.hidden_stride * d) + + return features + + def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: + """[BatchSize, ImageShape] -> [BatchSize, Token, VocabSize]""" + features = self.encode(pixel_values) + logits = self.head(features) + tokens = self.tokenize(logits) + # tokens' shape is [BatchSize, #Token, VocabSize-5], so padding with + # [BatchSize, #Token, 5], after which, tokens' shape should become + # [BatchSize, #Token, VocabSize] + tokens = torch.nn.functional.pad( + tokens, + (0, len(IMAGE_INDICATOR_IDS)), + mode="constant", + value=0, + ) + return tokens + + +class OvisImagePatchInputs(TypedDict): type: Literal["image_patches"] flat_data: torch.Tensor """ @@ -92,31 +239,50 @@ class VisualEmbedding(torch.nn.Embedding): return self.weight.dtype -class Ovis2ProcessingInfo(BaseProcessingInfo): +class OvisProcessingInfo(BaseProcessingInfo): def get_hf_config(self): return self.ctx.get_hf_config(OvisConfig) def get_hf_processor(self, **kwargs): - return self.ctx.get_hf_processor(OvisProcessor) + return self.ctx.get_hf_processor( + OvisProcessor, + image_pad_token=self.get_image_pad_token(), + image_segment_len=self.get_image_segment_len(), + ) - def get_image_processor(self) -> OvisProcessor: + def get_image_segment_len(self) -> int: + visual_tokenizer_config = self.get_hf_config().visual_tokenizer_config + image_size = visual_tokenizer_config.backbone_config.image_size + patch_size = visual_tokenizer_config.backbone_config.patch_size + hidden_stride = visual_tokenizer_config.hidden_stride + patch_grid_length = math.ceil(image_size / patch_size) + assert patch_grid_length % hidden_stride == 0, ( + f"patch_grid_length {patch_grid_length} is not divisible by " + f"hidden_stride {hidden_stride}") + # minus 1 for presented image token + return (patch_grid_length // hidden_stride)**2 - 1 + + def get_image_pad_token(self) -> str: + hf_text_config = self.get_hf_config().get_text_config() + text_model_type = hf_text_config.model_type + return IMAGE_PAD_TOKEN_MAP.get(text_model_type) + + def get_image_processor(self) -> BaseImageProcessor: return self.get_hf_processor().image_processor # type: ignore def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: - return { # 32k is model token limit at the moment - "image": - self.get_hf_config().multimodal_max_length // - ((9 + 1) * NUMBER_OF_TOKEN_TO_RESERVE_FOR_SEGMENT) - } + return {"image": None} def get_image_size_with_most_features(self) -> ImageSize: - image_processor = self.get_image_processor() - return ImageSize(width=image_processor.size['shortest_edge'] * 9 * 2, - height=image_processor.size['shortest_edge'] * 9 * 2) + height, width = self.get_hf_processor().get_image_size() + hs = self.get_hf_config().visual_tokenizer_config.hidden_stride + # NOTE(Isotr0py): 9 is `max_partion` hardcoded in original code + # https://huggingface.co/AIDC-AI/Ovis2-1B/blob/main/modeling_ovis.py#L96 + return ImageSize(width=width * hs * 9, height=height * hs * 9) -class Ovis2DummyInputsBuilder(BaseDummyInputsBuilder[Ovis2ProcessingInfo]): +class OvisDummyInputsBuilder(BaseDummyInputsBuilder[OvisProcessingInfo]): def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: num_images = mm_counts.get("image", 0) @@ -141,7 +307,7 @@ class Ovis2DummyInputsBuilder(BaseDummyInputsBuilder[Ovis2ProcessingInfo]): return mm_data 
-class Ovis2MultiModalProcessor(BaseMultiModalProcessor[Ovis2ProcessingInfo]): +class OvisMultiModalProcessor(BaseMultiModalProcessor[OvisProcessingInfo]): def image_indicators_to_visual_tokens( self, @@ -165,9 +331,9 @@ class Ovis2MultiModalProcessor(BaseMultiModalProcessor[Ovis2ProcessingInfo]): mm_kwargs: Mapping[str, object], ) -> BatchFeature: if not mm_data: - # # Avoid warning from HF logger for text-only input - prompt_ids = self.info.get_tokenizer().encode(prompt) - # prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids) nope + # Avoid warning from HF logger for text-only input + tokenizer = self.info.get_tokenizer() + prompt_ids = tokenizer.encode(prompt, add_special_tokens=False) return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt") processed_outputs = super()._call_hf_processor( @@ -226,10 +392,10 @@ class Ovis2MultiModalProcessor(BaseMultiModalProcessor[Ovis2ProcessingInfo]): ] -@MULTIMODAL_REGISTRY.register_processor(Ovis2MultiModalProcessor, - info=Ovis2ProcessingInfo, - dummy_inputs=Ovis2DummyInputsBuilder) -class Ovis2ForConditionalGeneration(nn.Module, SupportsMultiModal): +@MULTIMODAL_REGISTRY.register_processor(OvisMultiModalProcessor, + info=OvisProcessingInfo, + dummy_inputs=OvisDummyInputsBuilder) +class Ovis(nn.Module, SupportsMultiModal): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -242,24 +408,25 @@ class Ovis2ForConditionalGeneration(nn.Module, SupportsMultiModal): prefix=maybe_prefix(prefix, "llm"), ) - self.visual_tokenizer = Aimv2VisualTokenizer( + self.visual_tokenizer = VisualTokenizer( config=config.visual_tokenizer_config, quant_config=quant_config, prefix=f"{prefix}.visual_tokenizer", - image_processor_name_or_path=config.visual_tokenizer_config. - backbone_config.name_or_path, ) self.vte = VisualEmbedding( self.config.visual_tokenizer_config.vocab_size, self.config.hidden_size) + text_model_type = self.config.get_text_config().model_type + self.image_pad_token_id = IMAGE_PAD_TOKEN_ID_MAP[text_model_type] + # TODO(Isotr0py): PP support # self.make_empty_intermediate_tensors = ( # self.language_model.make_empty_intermediate_tensors) def _parse_and_validate_image_input( - self, **kwargs: object) -> Optional[Ovis2ImagePatchInputs]: + self, **kwargs: object) -> Optional[OvisImagePatchInputs]: pixel_values = kwargs.pop("pixel_values", None) indicator_tokens = kwargs.pop("indicator_tokens", None) @@ -275,7 +442,7 @@ class Ovis2ForConditionalGeneration(nn.Module, SupportsMultiModal): raise ValueError("Incorrect type of indicator_tokens. 
" f"Got type: {type(pixel_values)}") - return Ovis2ImagePatchInputs( + return OvisImagePatchInputs( type="image_patches", flat_data=flatten_bn(flatten_bn(pixel_values), concat=True), patches_per_image=[ @@ -288,7 +455,7 @@ class Ovis2ForConditionalGeneration(nn.Module, SupportsMultiModal): raise AssertionError("This line should be unreachable.") def _process_image_input( - self, image_input: Ovis2ImagePatchInputs) -> MultiModalEmbeddings: + self, image_input: OvisImagePatchInputs) -> MultiModalEmbeddings: image_patches_flat = image_input["flat_data"] patches_per_image = image_input["patches_per_image"] indicator_tokens = image_input["indicator_tokens"] @@ -338,7 +505,7 @@ class Ovis2ForConditionalGeneration(nn.Module, SupportsMultiModal): if multimodal_embeddings is not None: inputs_embeds = merge_multimodal_embeddings( input_ids, inputs_embeds, multimodal_embeddings, - [IMAGE_PAD_TOKEN_ID]) + self.image_pad_token_id) return inputs_embeds def forward( @@ -375,8 +542,7 @@ class Ovis2ForConditionalGeneration(nn.Module, SupportsMultiModal): hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.llm.logits_processor(self.llm.lm_head, hidden_states, - sampling_metadata) + logits = self.llm.compute_logits(hidden_states, sampling_metadata) return logits def load_weights(self, weights: Iterable[Tuple[str, diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index aef4566193c8..c5414e129dd1 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -195,7 +195,7 @@ _MULTIMODAL_MODELS = { "Mistral3ForConditionalGeneration": ("mistral3", "Mistral3ForConditionalGeneration"), # noqa: E501 "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"), "NVLM_D": ("nvlm_d", "NVLM_D_Model"), - "Ovis2ForConditionalGeneration": ("ovis2", "Ovis2ForConditionalGeneration"), + "Ovis": ("ovis", "Ovis"), "PaliGemmaForConditionalGeneration": ("paligemma", "PaliGemmaForConditionalGeneration"), # noqa: E501 "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"), "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"), # noqa: E501 diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index db3efafeef96..ed10c22c84f0 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -23,7 +23,7 @@ from vllm.transformers_utils.configs.moonvit import MoonViTConfig from vllm.transformers_utils.configs.mpt import MPTConfig from vllm.transformers_utils.configs.nemotron import NemotronConfig from vllm.transformers_utils.configs.nvlm_d import NVLM_D_Config -from vllm.transformers_utils.configs.ovis2 import OvisConfig +from vllm.transformers_utils.configs.ovis import OvisConfig from vllm.transformers_utils.configs.skyworkr1v import SkyworkR1VChatConfig from vllm.transformers_utils.configs.solar import SolarConfig from vllm.transformers_utils.configs.telechat2 import Telechat2Config diff --git a/vllm/transformers_utils/configs/ovis2.py b/vllm/transformers_utils/configs/ovis.py similarity index 93% rename from vllm/transformers_utils/configs/ovis2.py rename to vllm/transformers_utils/configs/ovis.py index 437a16e778c2..0ec224214f06 100644 --- a/vllm/transformers_utils/configs/ovis2.py +++ b/vllm/transformers_utils/configs/ovis.py @@ -123,6 +123,19 @@ class Aimv2VisualTokenizerConfig(BaseVisualTokenizerConfig): self.backbone_kwargs['num_hidden_layers'] = self.depths[0] +class 
SiglipVisualTokenizerConfig(BaseVisualTokenizerConfig): + model_type = "siglip_visual_tokenizer" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + if self.drop_cls_token: + self.drop_cls_token = False + if self.depths: + assert len(self.depths) == 1 + self.backbone_kwargs['num_hidden_layers'] = self.depths[0] + + +AutoConfig.register("siglip_visual_tokenizer", SiglipVisualTokenizerConfig) AutoConfig.register("aimv2_visual_tokenizer", Aimv2VisualTokenizerConfig) diff --git a/vllm/transformers_utils/processors/__init__.py b/vllm/transformers_utils/processors/__init__.py index 2e9cf3e4d90b..2bd9ab1f099b 100644 --- a/vllm/transformers_utils/processors/__init__.py +++ b/vllm/transformers_utils/processors/__init__.py @@ -2,6 +2,6 @@ from vllm.transformers_utils.processors.deepseek_vl2 import ( DeepseekVLV2Processor) -from vllm.transformers_utils.processors.ovis2 import OvisProcessor +from vllm.transformers_utils.processors.ovis import OvisProcessor __all__ = ["DeepseekVLV2Processor", "OvisProcessor"] diff --git a/vllm/transformers_utils/processors/ovis2.py b/vllm/transformers_utils/processors/ovis.py similarity index 94% rename from vllm/transformers_utils/processors/ovis2.py rename to vllm/transformers_utils/processors/ovis.py index a633256ec12c..48e786792cf5 100644 --- a/vllm/transformers_utils/processors/ovis2.py +++ b/vllm/transformers_utils/processors/ovis.py @@ -22,6 +22,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from functools import cached_property from typing import List, Union import PIL @@ -32,7 +33,7 @@ from transformers.processing_utils import (ProcessingKwargs, ProcessorMixin, Unpack) from transformers.tokenization_utils_base import PreTokenizedInput, TextInput -__all__ = [ 'OvisProcessor'] +__all__ = ['OvisProcessor'] IGNORE_ID = -100 class OvisProcessorKwargs(ProcessingKwargs, total=False): # type: ignore[call-arg] @@ -64,18 +65,29 @@ class OvisProcessor(ProcessorMixin): """ attributes = ["image_processor", "tokenizer"] - valid_kwargs = ["chat_template"] + valid_kwargs = ["chat_template", "image_pad_token", "image_segement_len"] image_processor_class = "AutoImageProcessor" - tokenizer_class = "Qwen2Tokenizer" + tokenizer_class = "AutoTokenizer" - def __init__(self, image_processor=None, tokenizer=None, chat_template=None, image_pad_token=None, **kwargs): + def __init__( + self, + image_processor=None, + tokenizer=None, + chat_template=None, + image_pad_token=None, + image_segment_len=255, + **kwargs, + ): self.image_token = "" - self.image_pad_token = "<|image_pad|>" if image_pad_token is None else image_pad_token + self.image_pad_token = image_pad_token + self.image_segment_len = image_segment_len super().__init__(image_processor, tokenizer, chat_template=chat_template) - self.image_pad_token_id = self.tokenizer.get_vocab()[self.image_pad_token] - self.extra_special_tokens = { + @cached_property + def extra_special_tokens(self): + image_pad_token_id = self.tokenizer.get_vocab()[self.image_pad_token] + extra_special_tokens = { "image_token": -200, "image_atom": -300, "image_start": -301, @@ -83,8 +95,9 @@ class OvisProcessor(ProcessorMixin): "image_col_sep": -303, "image_row_sep": -304, "image_end": -305, - 'image_pad': self.image_pad_token_id, + 'image_pad': image_pad_token_id, } + return extra_special_tokens def __call__( self, @@ -224,8 +237,14 @@ class OvisProcessor(ProcessorMixin): return torch.tensor(batch_token_ids, 
dtype=torch.long) def get_image_size(self): - height = self.image_processor.crop_size["height"] - width = self.image_processor.crop_size["width"] + size = self.image_processor.size + if 'shortest_edge' in size: + width = height = size['shortest_edge'] + elif "height" in size and "width" in size: + width = size['width'] + height = size['height'] + else: + raise ValueError( "Can't parse image size from image_processor config.") return height, width def get_token_value(self, tok): @@ -259,8 +278,7 @@ class OvisProcessor(ProcessorMixin): for token in image_placeholders: padded_placeholder_tokens.append(image_padding_token_id) if token == image_atom_token_id: - # Add 255 padding tokens after each image atom token - padded_placeholder_tokens.extend([image_padding_token_id] * 255) + padded_placeholder_tokens.extend([image_padding_token_id] * self.image_segment_len) return padded_placeholder_tokens def preprocess_image(self, image: PIL.Image.Image, max_partition, covering_threshold, convert_to_rgb, return_tensors):
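
Example usage (not part of the diff): a minimal offline-inference sketch assembled from the updated `run_ovis()` example above. It assumes the renamed `Ovis` architecture registered by this PR, uses `AIDC-AI/Ovis2-1B` exactly as the example does (the Ovis1.6 checkpoints load the same way), and assumes the image placeholder token is `<image>`; the angle-bracket tokens appear stripped in the rendered diff text, so treat that placeholder, the local image path, and the engine settings as illustrative rather than authoritative.

```python
from PIL import Image
from transformers import AutoTokenizer

from vllm import LLM, SamplingParams

# Ovis1.6-Llama3.2-3B / Ovis1.6-Gemma2-9B are loaded the same way.
model_name = "AIDC-AI/Ovis2-1B"

# Engine arguments mirror run_ovis(): remote code is required for the HF
# config/tokenizer, and the per-prompt image limit matches the single-image
# example above.
llm = LLM(
    model=model_name,
    trust_remote_code=True,
    dtype="half",
    max_model_len=4096,
    max_num_seqs=2,
    limit_mm_per_prompt={"image": 1},
)

# Build the prompt with the checkpoint's own chat template, replacing the
# hardcoded Qwen2-style template that this PR removes from the examples.
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
# "<image>" is assumed to be the placeholder vLLM expands for the "ovis"
# model type; it is stripped to an empty string in the diff text above.
messages = [{
    "role": "user",
    "content": "<image>\nWhat is shown in this image?",
}]
prompt = tokenizer.apply_chat_template(messages,
                                       tokenize=False,
                                       add_generation_prompt=True)

image = Image.open("example.jpg").convert("RGB")  # any local test image
outputs = llm.generate(
    {"prompt": prompt, "multi_modal_data": {"image": image}},
    SamplingParams(temperature=0.0, max_tokens=64),
)
print(outputs[0].outputs[0].text)
```

For multi-image prompts, the same pattern applies with `limit_mm_per_prompt={"image": N}` and one placeholder per image, as in the updated `load_ovis()` example in `vision_language_multi_image.py`.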