From 1282bd812ea4e1511378bad5b918d609280d2b89 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=B1=AA=E5=BF=97=E9=B9=8F?=
Date: Tue, 3 Jun 2025 13:13:13 +0800
Subject: [PATCH] Add tarsier model support (#18985)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: 汪志鹏
---
 docs/models/supported_models.md               |   1 +
 examples/offline_inference/vision_language.py |  20 +
 .../vision_language_multi_image.py            |  21 +
 .../multimodal/processing/test_common.py      |   1 +
 tests/models/registry.py                      |   2 +
 vllm/model_executor/models/registry.py        |   1 +
 vllm/model_executor/models/tarsier.py         | 643 ++++++++++++++++++
 7 files changed, 689 insertions(+)
 create mode 100644 vllm/model_executor/models/tarsier.py

diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md
index b60fefdda2793..f2090fe3971e9 100644
--- a/docs/models/supported_models.md
+++ b/docs/models/supported_models.md
@@ -550,6 +550,7 @@ Specified using `--task generate`.
 | `Qwen2_5OmniThinkerForConditionalGeneration` | Qwen2.5-Omni | T + IE+ + VE+ + A+ | `Qwen/Qwen2.5-Omni-7B` | | ✅︎ | ✅︎\* |
 | `SkyworkR1VChatModel` | Skywork-R1V-38B | T + I | `Skywork/Skywork-R1V-38B` | | ✅︎ | ✅︎ |
 | `SmolVLMForConditionalGeneration` | SmolVLM2 | T + I | `SmolVLM2-2.2B-Instruct` | ✅︎ | | ✅︎ |
+| `TarsierForConditionalGeneration` | Tarsier | T + IE+ | `omni-research/Tarsier-7b`, `omni-research/Tarsier-34b` | | ✅︎ | ✅︎ |

 ^ You need to set the architecture name via `--hf-overrides` to match the one in vLLM.
     • For example, to use DeepSeek-VL2 series models:
diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py
index f0504501639d2..2ef87f4f4696e 100644
--- a/examples/offline_inference/vision_language.py
+++ b/examples/offline_inference/vision_language.py
@@ -333,6 +333,25 @@ def run_smolvlm(questions: list[str], modality: str) -> ModelRequestData:
     )


+# omni-research/Tarsier-7b
+def run_tarsier(questions: list[str], modality: str) -> ModelRequestData:
+    assert modality == "image"
+    model_name = "omni-research/Tarsier-7b"
+
+    engine_args = EngineArgs(
+        model=model_name,
+        trust_remote_code=True,
+        max_model_len=4096,
+        limit_mm_per_prompt={modality: 1},
+    )
+    prompts = [(f"USER: <image>\n{question} ASSISTANT:") for question in questions]
+
+    return ModelRequestData(
+        engine_args=engine_args,
+        prompts=prompts,
+    )
+
+
 # InternVL
 def run_internvl(questions: list[str], modality: str) -> ModelRequestData:
     model_name = "OpenGVLab/InternVL3-2B"
@@ -1091,6 +1110,7 @@ model_example_map = {
     "qwen2_5_omni": run_qwen2_5_omni,
     "skywork_chat": run_skyworkr1v,
     "smolvlm": run_smolvlm,
+    "tarsier": run_tarsier,
 }


diff --git a/examples/offline_inference/vision_language_multi_image.py b/examples/offline_inference/vision_language_multi_image.py
index e776ff7fe6aec..7ce28c5a4f09f 100644
--- a/examples/offline_inference/vision_language_multi_image.py
+++ b/examples/offline_inference/vision_language_multi_image.py
@@ -691,6 +691,26 @@ def load_qwen2_5_vl(question: str, image_urls: list[str]) -> ModelRequestData:
     )


+def load_tarsier(question: str, image_urls: list[str]) -> ModelRequestData:
+    model_name = "omni-research/Tarsier-7b"
+
+    engine_args = EngineArgs(
+        model=model_name,
+        trust_remote_code=True,
+        max_model_len=4096,
+        limit_mm_per_prompt={"image": len(image_urls)},
+    )
+
+    prompt = f"USER: {'<image>' * len(image_urls)}\n{question}\n ASSISTANT:"
+    image_data = [fetch_image(url) for url in image_urls]
+
+    return ModelRequestData(
+        engine_args=engine_args,
+        prompt=prompt,
image_data=image_data, + ) + + model_example_map = { "aria": load_aria, "aya_vision": load_aya_vision, @@ -712,6 +732,7 @@ model_example_map = { "qwen2_vl": load_qwen2_vl, "qwen2_5_vl": load_qwen2_5_vl, "smolvlm": load_smolvlm, + "tarsier": load_tarsier, } diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py index d7f950c23d954..2377fef820ed1 100644 --- a/tests/models/multimodal/processing/test_common.py +++ b/tests/models/multimodal/processing/test_common.py @@ -282,6 +282,7 @@ def _test_processing_correctness_one( "Skywork/Skywork-R1V-38B", "fixie-ai/ultravox-v0_5-llama-3_2-1b", "openai/whisper-large-v3", + "omni-research/Tarsier-7b", ]) @pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0]) @pytest.mark.parametrize("num_batches", [32]) diff --git a/tests/models/registry.py b/tests/models/registry.py index fe49d2427c744..182a9668ebef1 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -406,6 +406,8 @@ _MULTIMODAL_EXAMPLE_MODELS = { "SmolVLMForConditionalGeneration": _HfExamplesInfo("HuggingFaceTB/SmolVLM2-2.2B-Instruct"), # noqa: E501 "UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_5-llama-3_2-1b", # noqa: E501 trust_remote_code=True), + "TarsierForConditionalGeneration": _HfExamplesInfo("omni-research/Tarsier-7b", # noqa: E501 + hf_overrides={"architectures": ["TarsierForConditionalGeneration"]}), # noqa: E501 # [Encoder-decoder] # Florence-2 uses BartFastTokenizer which can't be loaded from AutoTokenizer # Therefore, we borrow the BartTokenizer from the original Bart model diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 8efd4825beea9..fcef457a78291 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -211,6 +211,7 @@ _MULTIMODAL_MODELS = { "Qwen2_5OmniForConditionalGeneration": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"), # noqa: E501 "UltravoxModel": ("ultravox", "UltravoxModel"), "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"), + "TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"), # noqa: E501 # [Encoder-decoder] "Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"), # noqa: E501 "MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"), # noqa: E501 diff --git a/vllm/model_executor/models/tarsier.py b/vllm/model_executor/models/tarsier.py new file mode 100644 index 0000000000000..5aa3ddabc19ec --- /dev/null +++ b/vllm/model_executor/models/tarsier.py @@ -0,0 +1,643 @@ +# SPDX-License-Identifier: Apache-2.0 + +import math +from collections.abc import Iterable, Mapping, Sequence +from typing import (Final, Literal, Optional, Protocol, TypedDict, TypeVar, + Union, cast) + +import torch +import torch.nn as nn +from transformers import BatchFeature, CLIPVisionConfig +from transformers import LlavaConfig as HfLlavaConfig +from transformers import PretrainedConfig, SiglipVisionConfig +from transformers.image_utils import ImageInput, get_image_size, to_numpy_array +from transformers.models.llava import LlavaProcessor +from transformers.processing_utils import (ProcessingKwargs, Unpack, + _validate_images_text_input_order) +from transformers.tokenization_utils_base import PreTokenizedInput, TextInput + +from vllm.config import VllmConfig +from vllm.inputs import InputProcessingContext +from vllm.jsontree import json_map_leaves +from vllm.model_executor.layers.activation import get_act_fn 
+from vllm.model_executor.layers.linear import (ColumnParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.models.llava import LlavaDummyInputsBuilder +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs +from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, + ImageSize, MultiModalDataItems) +from vllm.multimodal.processing import (BaseMultiModalProcessor, + BaseProcessingInfo, ProcessingCache, + PromptReplacement, PromptUpdate) +from vllm.multimodal.profiling import BaseDummyInputsBuilder +from vllm.sequence import IntermediateTensors + +from .clip import CLIPVisionModel +from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP +from .siglip import SiglipVisionModel +from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, + maybe_prefix, merge_multimodal_embeddings) +from .vision import VisionEncoderInfo, get_vision_encoder_info + + +class TarsierImagePixelInputs(TypedDict): + type: Literal["pixel_values"] + pixel_values: torch.Tensor + + +class TarsierImageEmbeddingInputs(TypedDict): + type: Literal["image_embeds"] + data: torch.Tensor + + +TarsierImageInputs = Union[TarsierImagePixelInputs, + TarsierImageEmbeddingInputs] + + +class TarsierHfConfig(Protocol): # Based on the Tarsier's LlavaConfig + vision_config: Final[PretrainedConfig] + text_config: Final[PretrainedConfig] # Added from Tarsier's LlavaConfig + image_token_index: Final[int] + vision_feature_select_strategy: Final[str] + vision_feature_layer: Final[Union[int, list[int]]] + projector_hidden_act: Final[str] + image_newline_idx: Final[int] + image_new_idx: Final[int] + multimodal_projector_bias: bool = True + + +class TarsierProcessorKwargs(ProcessingKwargs, total=False): + _defaults = { + "text_kwargs": { + "padding": False, + }, + "images_kwargs": {}, + } + + +class TarsierProcessor(LlavaProcessor): + + def __call__( + self, + images: ImageInput = None, + text: Union[TextInput, PreTokenizedInput, list[TextInput], + list[PreTokenizedInput]] = None, + audio=None, + videos=None, + **kwargs: Unpack[TarsierProcessorKwargs], + ) -> BatchFeature: + if images is None and text is None: + raise ValueError( + "You have to specify at least one of `images` or `text`.") + + # check if images and text inputs are reversed for BC + images, text = _validate_images_text_input_order(images, text) + + output_kwargs = self._merge_kwargs( + TarsierProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + if images is not None: + image_inputs = self.image_processor( + images, **output_kwargs["images_kwargs"]) + else: + image_inputs = {} + + if isinstance(text, str): + text = [text] + elif not isinstance(text, list) and not isinstance(text[0], str): + raise ValueError("Invalid input text. 
Please provide a string," + " or a list of strings") + + # try to expand inputs in processing if we have the necessary parts + prompt_strings = text + if image_inputs.get("pixel_values") is not None: + # Replace the image token with the expanded image token sequence + pixel_values = image_inputs["pixel_values"] + height, width = get_image_size(to_numpy_array(pixel_values[0])) + num_image_tokens = (height // self.patch_size) * ( + width // self.patch_size + + 1) + self.num_additional_image_tokens + 1 + if self.vision_feature_select_strategy == "default": + num_image_tokens -= 1 + + prompt_strings = [] + for sample in text: + sample = sample.replace(self.image_token, + self.image_token * num_image_tokens) + prompt_strings.append(sample) + + return_tensors = output_kwargs["text_kwargs"].pop( + "return_tensors", None) + text_inputs = self.tokenizer(prompt_strings, + **output_kwargs["text_kwargs"]) + return BatchFeature(data={ + **text_inputs, + **image_inputs + }, + tensor_type=return_tensors) + + +class TarsierMultiModalProjector(nn.Module): + + def __init__(self, + vision_hidden_size: int, + text_hidden_size: int, + projector_hidden_act: str, + multimodal_projector_bias: bool, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): + super().__init__() + + self.linear_1 = ColumnParallelLinear(vision_hidden_size, + text_hidden_size, + bias=multimodal_projector_bias, + quant_config=quant_config, + prefix=f"{prefix}.linear_1") + self.act = get_act_fn(projector_hidden_act) + self.linear_2 = RowParallelLinear(text_hidden_size, + text_hidden_size, + bias=multimodal_projector_bias, + quant_config=quant_config, + prefix=f"{prefix}.linear_2") + + def forward(self, image_features: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.linear_1(image_features) + hidden_states = self.act(hidden_states) + hidden_states, _ = self.linear_2(hidden_states) + return hidden_states + + +class TarsierProcessingInfo(BaseProcessingInfo): + + def get_hf_config(self) -> TarsierHfConfig: + return self.ctx.get_hf_config(HfLlavaConfig) + + def get_vision_encoder_info(self) -> VisionEncoderInfo: + return get_vision_encoder_info(self.get_hf_config()) + + def get_hf_processor(self, **kwargs: object) -> TarsierProcessor: + hf_processor = self.ctx.get_hf_processor(TarsierProcessor, **kwargs) + # Patch for patch_size if needed (copied from vLLM LLaVA) + if hasattr(hf_processor, + 'patch_size') and hf_processor.patch_size is None: + patch_size = self.get_vision_encoder_info().get_patch_size() + hf_processor.patch_size = patch_size + return hf_processor + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": None} + + def _apply_feature_select_strategy( + self, + strategy: str, + encoder_num_image_tokens: int, + ) -> int: + if strategy == "default": + return encoder_num_image_tokens - 1 + if strategy == "full": + return encoder_num_image_tokens + msg = f"Unexpected feature select strategy: {strategy!r}" + raise NotImplementedError(msg) + + def get_num_image_tokens( + self, + *, + image_width: int, + image_height: int, + ) -> int: + hf_config = self.get_hf_config() + vision_encoder_info = self.get_vision_encoder_info() + num_projected_patches = self._apply_feature_select_strategy( + hf_config.vision_feature_select_strategy, + vision_encoder_info.get_num_image_tokens( + image_width=image_width, + image_height=image_height, + ), + ) + if num_projected_patches <= 0: + default_size = self.get_image_size_with_most_features() + num_projected_patches_default = 
self._apply_feature_select_strategy( + hf_config.vision_feature_select_strategy, + vision_encoder_info.get_num_image_tokens( + image_width=default_size.width, + image_height=default_size.height, + ), + ) + if num_projected_patches_default <= 0: + raise ValueError( + "Could not determine a valid number of image patches.") + num_projected_patches = num_projected_patches_default + num_height_patches = int(math.sqrt(num_projected_patches)) + total_image_tokens_for_llm = num_projected_patches \ + + num_height_patches + 1 + return total_image_tokens_for_llm + + def get_image_size_with_most_features(self) -> ImageSize: + vision_encoder_info = self.get_vision_encoder_info() + width = height = vision_encoder_info.get_image_size() + return ImageSize(width=width, height=height) + + def get_max_image_tokens(self) -> int: + target_width, target_height = self.get_image_size_with_most_features() + return self.get_num_image_tokens( + image_width=target_width, + image_height=target_height, + ) + + def get_image_newline_idx(self) -> int: + return self.get_hf_config().image_newline_idx + + def get_image_new_idx(self) -> int: + return self.get_hf_config().image_new_idx + + +_I_Tarsier = TypeVar("_I_Tarsier", bound=TarsierProcessingInfo) + + +class TarsierDummyInputsBuilder(LlavaDummyInputsBuilder[_I_Tarsier]): + + pass + + +class TarsierMultiModalProcessor(BaseMultiModalProcessor[_I_Tarsier]): + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict( + pixel_values=MultiModalFieldConfig.batched("image"), + image_embeds=MultiModalFieldConfig.batched("image"), + ) + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargs, + ) -> Sequence[PromptUpdate]: + hf_config = self.info.get_hf_config() + image_token_id = hf_config.image_token_index # The token ID + + def get_replacement(item_idx: int): + images = mm_items.get_items( + "image", (ImageEmbeddingItems, ImageProcessorItems)) + + if isinstance(images, ImageEmbeddingItems): + num_projected_patches = images.get_feature_size(item_idx) + # This assumes num_projected_patches is a perfect square + num_height_patches = int(math.sqrt(num_projected_patches)) + num_final_image_tokens = num_projected_patches \ + + num_height_patches + 1 + else: + image_size = images.get_image_size(item_idx) + num_final_image_tokens = self.info.get_num_image_tokens( + image_width=image_size.width, + image_height=image_size.height, + ) + + return [image_token_id] * num_final_image_tokens + + return [ + PromptReplacement( + modality="image", + target=[image_token_id], # Replace each single token + replacement=get_replacement, + ), + ] + + +def _build_tarsier_hf_info( + ctx: InputProcessingContext) -> TarsierProcessingInfo: + return TarsierProcessingInfo(ctx) + + +def _build_tarsier_hf_processor( + info: _I_Tarsier, + dummy_inputs: BaseDummyInputsBuilder[_I_Tarsier], + *, + cache: Optional[ProcessingCache] = None, +) -> BaseMultiModalProcessor: + if isinstance(info, TarsierProcessingInfo): + return TarsierMultiModalProcessor( + info, + dummy_inputs, + cache=cache, + ) + raise NotImplementedError(type(info)) + + +def init_vision_tower_for_tarsier( + hf_config: TarsierHfConfig, # Use the Tarsier specific config protocol + quant_config: Optional[QuantizationConfig], + *, + require_post_norm: Optional[bool] = None, + prefix: str = "", +) -> Union[CLIPVisionModel, SiglipVisionModel]: + 
vision_config = hf_config.vision_config + + feature_layers = hf_config.vision_feature_layer + base_num_hidden_layers = vision_config.num_hidden_layers + + def _get_layer_index(feature_layer_index: int, + num_hidden_layers_total: int) -> int: + if feature_layer_index < 0: + return num_hidden_layers_total + feature_layer_index + 1 + return feature_layer_index + + if isinstance(feature_layers, int): + num_hidden_layers_to_init = _get_layer_index(feature_layers, + base_num_hidden_layers) + elif isinstance(feature_layers, (list, tuple)): + num_hidden_layers_to_init = max( + _get_layer_index(idx, base_num_hidden_layers) + for idx in feature_layers) + else: + raise TypeError(f"vision_layer_feature type: {type(feature_layers)}" + " is not supported") + + if isinstance(vision_config, CLIPVisionConfig): + return CLIPVisionModel( + vision_config, + quant_config=quant_config, + num_hidden_layers_override=num_hidden_layers_to_init, + require_post_norm=require_post_norm, + prefix=prefix, + ) + elif isinstance(vision_config, SiglipVisionConfig): + return SiglipVisionModel( + vision_config, + quant_config=quant_config, + num_hidden_layers_override=num_hidden_layers_to_init, + require_post_norm=require_post_norm, + prefix=prefix, + ) + + msg = f"Unsupported vision config for Tarsier: {type(vision_config)}" + raise NotImplementedError(msg) + + +@MULTIMODAL_REGISTRY.register_processor(_build_tarsier_hf_processor, + info=_build_tarsier_hf_info, + dummy_inputs=TarsierDummyInputsBuilder) +class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal, + SupportsPP): + packed_modules_mapping = { + "qkv_proj": ["q_proj", "k_proj", "v_proj"], + "gate_up_proj": ["gate_proj", "up_proj"] + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: + super().__init__() + config: TarsierHfConfig = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config # Storing the Tarsier-specific HF config + self.vision_tower = init_vision_tower_for_tarsier( + config, + quant_config, + require_post_norm=False, + prefix=maybe_prefix(prefix, "vision_tower")) + projector_bias = getattr(config, "multimodal_projector_bias", True) + + self.multi_modal_projector = TarsierMultiModalProjector( + vision_hidden_size=config.vision_config.hidden_size, + text_hidden_size=config.text_config.hidden_size, + projector_hidden_act=config.projector_hidden_act, + multimodal_projector_bias=projector_bias, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "multi_modal_projector")) + self.language_model = init_vllm_registered_model( + vllm_config=vllm_config, + hf_config=config. + text_config, # Use text_config from Tarsier's main config + prefix=maybe_prefix(prefix, "language_model"), + ) + self.register_buffer('image_newline_idx_tensor', + torch.tensor([config.image_newline_idx], + dtype=torch.long), + persistent=False) + self.register_buffer('image_new_idx_tensor', + torch.tensor([config.image_new_idx], + dtype=torch.long), + persistent=False) + + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors) + + def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: + h = w = self.config.vision_config.image_size + expected_dims = (3, h, w) # Assuming 3 channels + actual_dims = tuple(data.shape[1:]) + + if actual_dims != expected_dims: + expected_expr = ("batch_size", *map(str, expected_dims)) + raise ValueError( + f"The expected shape of pixel values is {expected_expr}. 
" + f"You supplied {tuple(data.shape)}.") + return data + + def _parse_and_validate_image_input( + self, **kwargs: object) -> Optional[TarsierImageInputs]: + pixel_values = kwargs.pop("pixel_values", None) + image_embeds = kwargs.pop("image_embeds", None) + + if pixel_values is None and image_embeds is None: + return None + + if pixel_values is not None: + if not isinstance(pixel_values, (torch.Tensor, list)): + raise ValueError("Incorrect type of pixel values. " + f"Got type: {type(pixel_values)}") + + return TarsierImagePixelInputs( + type="pixel_values", + pixel_values=self._validate_pixel_values( + flatten_bn(pixel_values, concat=True)), + ) + + if image_embeds is not None: + if not isinstance(image_embeds, (torch.Tensor, list)): + raise ValueError("Incorrect type of image embeddings. " + f"Got type: {type(image_embeds)}") + return TarsierImageEmbeddingInputs( + type="image_embeds", + data=flatten_bn(image_embeds, concat=True), + ) + + raise AssertionError("This line should be unreachable.") + + def _select_image_features(self, image_features: torch.Tensor, *, + strategy: str) -> torch.Tensor: + if strategy == "default": + return image_features[:, 1:] + elif strategy == "full": + return image_features + raise ValueError(f"Unexpected select feature strategy: {strategy}") + + def _image_pixels_to_features( + self, + vision_tower: Union[CLIPVisionModel, SiglipVisionModel], + pixel_values: Union[torch.Tensor, list[torch.Tensor]], + ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]: + # From vLLM LLaVA, vision tower output handling + image_hidden_states = vision_tower(pixel_values) + if not isinstance(image_hidden_states, torch.Tensor): + raise TypeError( + f"image_hidden_states type: {type(image_hidden_states)}" + " is not supported") + + def select_features_fn(leaf: torch.Tensor): + return self._select_image_features( + leaf, + strategy=self.config.vision_feature_select_strategy, + ) + + selected_features = cast( + Union[torch.Tensor, tuple[torch.Tensor, ...]], + json_map_leaves(select_features_fn, image_hidden_states), + ) + return selected_features + + def _add_tarsier_split_tokens( + self, projected_image_features: torch.Tensor) -> torch.Tensor: + """ + Implements Tarsier's `add_split_tokens` logic. + """ + num_images, num_projected_patches, embed_dim = \ + projected_image_features.shape + num_height_patches = int(math.sqrt(num_projected_patches)) + num_width_patches = num_projected_patches // num_height_patches + device = projected_image_features.device + embedding_layer = self.language_model.model.embed_tokens + image_newline_emb = embedding_layer( + self.image_newline_idx_tensor.to(device)).squeeze(0) + image_new_emb = embedding_layer( + self.image_new_idx_tensor.to(device)).squeeze(0) + try: + current_image_features_grid = projected_image_features.view( + num_images, num_height_patches, num_width_patches, embed_dim) + except RuntimeError as e: + raise RuntimeError( + "Cannot reshape projected_image_features" + f" with shape {projected_image_features.shape} " + f"to ({num_images}, {num_height_patches}," + f" {num_width_patches}, {embed_dim}). " + "Ensure num_projected_patches is compatible" + " with a grid structure. " + f"num_projected_patches={num_projected_patches}, " + f"derived num_height_patches={num_height_patches}. 
") from e + + image_newline_expanded = image_newline_emb.expand( + (num_images, num_height_patches, 1, embed_dim)) + features_with_newlines = torch.cat( + [current_image_features_grid, image_newline_expanded], + dim=2 # Concatenate along width dim + ) + new_num_patches_after_newline = num_projected_patches \ + + num_height_patches + features_with_newlines_flat = features_with_newlines.view( + num_images, new_num_patches_after_newline, embed_dim) + image_new_expanded = image_new_emb.expand((num_images, 1, embed_dim)) + final_image_features = torch.cat( + [features_with_newlines_flat, image_new_expanded], + dim=1 # Concatenate along patch sequence dim + ) + return final_image_features + + def _process_image_pixels( + self, + inputs: TarsierImagePixelInputs, + ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]: + assert self.vision_tower is not None + pixel_values = inputs["pixel_values"] + image_features_selected = self._image_pixels_to_features( + self.vision_tower, pixel_values) # type: ignore + if isinstance(image_features_selected, torch.Tensor): + projected_features = self.multi_modal_projector( + image_features_selected) + final_features = self._add_tarsier_split_tokens(projected_features) + return final_features + else: + raise TypeError( + f"_image_pixels_to_features type:" + f" {type(image_features_selected)} is not supported") + + def _process_image_input( + self, + image_input: TarsierImageInputs, + ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]: + if image_input["type"] == "image_embeds": + projected_features = image_input["data"] + if isinstance(projected_features, torch.Tensor): + return self._add_tarsier_split_tokens(projected_features) + else: + raise ValueError("Incorrect type of image_embeds. " + f"Got type: {type(projected_features)}. 
") + assert self.vision_tower is not None + return self._process_image_pixels(image_input) + + def get_language_model(self) -> torch.nn.Module: + return self.language_model + + def get_multimodal_embeddings( + self, **kwargs: object) -> Optional[MultiModalEmbeddings]: + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is None: + return None + return self._process_image_input(image_input) + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + ) -> torch.Tensor: + inputs_embeds = self.language_model.get_input_embeddings(input_ids) + if multimodal_embeddings is not None: + inputs_embeds = merge_multimodal_embeddings( + input_ids, + inputs_embeds, + multimodal_embeddings, + self.config.image_token_index, + ) + return inputs_embeds + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ) -> Union[torch.Tensor, IntermediateTensors]: + if intermediate_tensors is not None: + inputs_embeds = None + elif inputs_embeds is None: + vision_embeddings = self.get_multimodal_embeddings(**kwargs) + inputs_embeds = self.get_input_embeddings(input_ids, + vision_embeddings) + input_ids = None + hidden_states = self.language_model.model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + return self.language_model.compute_logits(hidden_states, + sampling_metadata) + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights)