Add tarsier model support (#18985)

Signed-off-by: 汪志鹏 <wangzhipeng628@gmail.com>

parent: bdce64f236
commit: 1282bd812e
@@ -550,6 +550,7 @@ Specified using `--task generate`.

| `Qwen2_5OmniThinkerForConditionalGeneration` | Qwen2.5-Omni | T + I<sup>E+</sup> + V<sup>E+</sup> + A<sup>+</sup> | `Qwen/Qwen2.5-Omni-7B` | | ✅︎ | ✅︎\* |
| `SkyworkR1VChatModel` | Skywork-R1V-38B | T + I | `Skywork/Skywork-R1V-38B` | | ✅︎ | ✅︎ |
| `SmolVLMForConditionalGeneration` | SmolVLM2 | T + I | `HuggingFaceTB/SmolVLM2-2.2B-Instruct` | ✅︎ | | ✅︎ |
| `TarsierForConditionalGeneration` | Tarsier | T + I<sup>E+</sup> | `omni-research/Tarsier-7b`, `omni-research/Tarsier-34b` | | ✅︎ | ✅︎ |

<sup>^</sup> You need to set the architecture name via `--hf-overrides` to match the one in vLLM.

- For example, to use DeepSeek-VL2 series models:
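
A minimal sketch of the `--hf-overrides` mechanism from the footnote above, applied to the new Tarsier entry (illustrative, not part of this commit; it assumes `LLM` forwards `hf_overrides` to the engine config, and it mirrors the override used in the example-model registry entry added below):

```python
from vllm import LLM

# Load Tarsier-7b and force the architecture name expected by vLLM,
# matching the TarsierForConditionalGeneration registration below.
llm = LLM(
    model="omni-research/Tarsier-7b",
    trust_remote_code=True,
    max_model_len=4096,
    hf_overrides={"architectures": ["TarsierForConditionalGeneration"]},
)
```
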
@@ -333,6 +333,25 @@ def run_smolvlm(questions: list[str], modality: str) -> ModelRequestData:
    )


# omni-research/Tarsier-7b
def run_tarsier(questions: list[str], modality: str) -> ModelRequestData:
    assert modality == "image"
    model_name = "omni-research/Tarsier-7b"

    engine_args = EngineArgs(
        model=model_name,
        trust_remote_code=True,
        max_model_len=4096,
        limit_mm_per_prompt={modality: 1},
    )
    prompts = [(f"USER: <image>\n{question} ASSISTANT:") for question in questions]

    return ModelRequestData(
        engine_args=engine_args,
        prompts=prompts,
    )


# InternVL
def run_internvl(questions: list[str], modality: str) -> ModelRequestData:
    model_name = "OpenGVLab/InternVL3-2B"

@@ -1091,6 +1110,7 @@ model_example_map = {
    "qwen2_5_omni": run_qwen2_5_omni,
    "skywork_chat": run_skyworkr1v,
    "smolvlm": run_smolvlm,
    "tarsier": run_tarsier,
}
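
For reference, a standalone single-image sketch outside the example script (illustrative, not part of this commit; it reuses the prompt template from `run_tarsier` above and the architecture override noted earlier, with a blank PIL image standing in for real input):

```python
from PIL import Image

from vllm import LLM, SamplingParams

llm = LLM(
    model="omni-research/Tarsier-7b",
    trust_remote_code=True,
    max_model_len=4096,
    limit_mm_per_prompt={"image": 1},
    hf_overrides={"architectures": ["TarsierForConditionalGeneration"]},
)
outputs = llm.generate(
    {
        "prompt": "USER: <image>\nDescribe the image. ASSISTANT:",
        # A blank image as a placeholder for real input data.
        "multi_modal_data": {"image": Image.new("RGB", (336, 336))},
    },
    SamplingParams(max_tokens=64),
)
print(outputs[0].outputs[0].text)
```
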
@@ -691,6 +691,26 @@ def load_qwen2_5_vl(question: str, image_urls: list[str]) -> ModelRequestData:
    )


def load_tarsier(question: str, image_urls: list[str]) -> ModelRequestData:
    model_name = "omni-research/Tarsier-7b"

    engine_args = EngineArgs(
        model=model_name,
        trust_remote_code=True,
        max_model_len=4096,
        limit_mm_per_prompt={"image": len(image_urls)},
    )

    prompt = f"USER: {'<image>' * len(image_urls)}\n{question}\n ASSISTANT:"
    image_data = [fetch_image(url) for url in image_urls]

    return ModelRequestData(
        engine_args=engine_args,
        prompt=prompt,
        image_data=image_data,
    )


model_example_map = {
    "aria": load_aria,
    "aya_vision": load_aya_vision,

@@ -712,6 +732,7 @@ model_example_map = {
    "qwen2_vl": load_qwen2_vl,
    "qwen2_5_vl": load_qwen2_5_vl,
    "smolvlm": load_smolvlm,
    "tarsier": load_tarsier,
}
@@ -282,6 +282,7 @@ def _test_processing_correctness_one(
    "Skywork/Skywork-R1V-38B",
    "fixie-ai/ultravox-v0_5-llama-3_2-1b",
    "openai/whisper-large-v3",
    "omni-research/Tarsier-7b",
])
@pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0])
@pytest.mark.parametrize("num_batches", [32])
@@ -406,6 +406,8 @@ _MULTIMODAL_EXAMPLE_MODELS = {
    "SmolVLMForConditionalGeneration": _HfExamplesInfo("HuggingFaceTB/SmolVLM2-2.2B-Instruct"),  # noqa: E501
    "UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_5-llama-3_2-1b",  # noqa: E501
                                     trust_remote_code=True),
    "TarsierForConditionalGeneration": _HfExamplesInfo("omni-research/Tarsier-7b",  # noqa: E501
                                                       hf_overrides={"architectures": ["TarsierForConditionalGeneration"]}),  # noqa: E501
    # [Encoder-decoder]
    # Florence-2 uses BartFastTokenizer which can't be loaded from AutoTokenizer
    # Therefore, we borrow the BartTokenizer from the original Bart model
@@ -211,6 +211,7 @@ _MULTIMODAL_MODELS = {
    "Qwen2_5OmniForConditionalGeneration": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"),  # noqa: E501
    "UltravoxModel": ("ultravox", "UltravoxModel"),
    "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"),
    "TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"),  # noqa: E501
    # [Encoder-decoder]
    "Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"),  # noqa: E501
    "MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"),  # noqa: E501
vllm/model_executor/models/tarsier.py (new file, +643 lines)

@@ -0,0 +1,643 @@
# SPDX-License-Identifier: Apache-2.0

import math
from collections.abc import Iterable, Mapping, Sequence
from typing import (Final, Literal, Optional, Protocol, TypedDict, TypeVar,
                    Union, cast)

import torch
import torch.nn as nn
from transformers import BatchFeature, CLIPVisionConfig
from transformers import LlavaConfig as HfLlavaConfig
from transformers import PretrainedConfig, SiglipVisionConfig
from transformers.image_utils import ImageInput, get_image_size, to_numpy_array
from transformers.models.llava import LlavaProcessor
from transformers.processing_utils import (ProcessingKwargs, Unpack,
                                           _validate_images_text_input_order)
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput

from vllm.config import VllmConfig
from vllm.inputs import InputProcessingContext
from vllm.jsontree import json_map_leaves
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.models.llava import LlavaDummyInputsBuilder
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
                                   ImageSize, MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                        BaseProcessingInfo, ProcessingCache,
                                        PromptReplacement, PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors

from .clip import CLIPVisionModel
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
                    maybe_prefix, merge_multimodal_embeddings)
from .vision import VisionEncoderInfo, get_vision_encoder_info


class TarsierImagePixelInputs(TypedDict):
    type: Literal["pixel_values"]
    pixel_values: torch.Tensor


class TarsierImageEmbeddingInputs(TypedDict):
    type: Literal["image_embeds"]
    data: torch.Tensor


TarsierImageInputs = Union[TarsierImagePixelInputs,
                           TarsierImageEmbeddingInputs]


class TarsierHfConfig(Protocol):  # Based on Tarsier's LlavaConfig
    vision_config: Final[PretrainedConfig]
    text_config: Final[PretrainedConfig]  # Added from Tarsier's LlavaConfig
    image_token_index: Final[int]
    vision_feature_select_strategy: Final[str]
    vision_feature_layer: Final[Union[int, list[int]]]
    projector_hidden_act: Final[str]
    image_newline_idx: Final[int]
    image_new_idx: Final[int]
    multimodal_projector_bias: bool = True


class TarsierProcessorKwargs(ProcessingKwargs, total=False):
    _defaults = {
        "text_kwargs": {
            "padding": False,
        },
        "images_kwargs": {},
    }


class TarsierProcessor(LlavaProcessor):

    def __call__(
        self,
        images: ImageInput = None,
        text: Union[TextInput, PreTokenizedInput, list[TextInput],
                    list[PreTokenizedInput]] = None,
        audio=None,
        videos=None,
        **kwargs: Unpack[TarsierProcessorKwargs],
    ) -> BatchFeature:
        if images is None and text is None:
            raise ValueError(
                "You have to specify at least one of `images` or `text`.")

        # check if images and text inputs are reversed for BC
        images, text = _validate_images_text_input_order(images, text)

        output_kwargs = self._merge_kwargs(
            TarsierProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )
        if images is not None:
            image_inputs = self.image_processor(
                images, **output_kwargs["images_kwargs"])
        else:
            image_inputs = {}

        if isinstance(text, str):
            text = [text]
        elif not isinstance(text, list) and not isinstance(text[0], str):
            raise ValueError("Invalid input text. Please provide a string,"
                             " or a list of strings")

        # try to expand inputs in processing if we have the necessary parts
        prompt_strings = text
        if image_inputs.get("pixel_values") is not None:
            # Replace the image token with the expanded image token sequence
            pixel_values = image_inputs["pixel_values"]
            height, width = get_image_size(to_numpy_array(pixel_values[0]))
            num_image_tokens = (height // self.patch_size) * (
                width // self.patch_size +
                1) + self.num_additional_image_tokens + 1
            if self.vision_feature_select_strategy == "default":
                num_image_tokens -= 1

            prompt_strings = []
            for sample in text:
                sample = sample.replace(self.image_token,
                                        self.image_token * num_image_tokens)
                prompt_strings.append(sample)

        return_tensors = output_kwargs["text_kwargs"].pop(
            "return_tensors", None)
        text_inputs = self.tokenizer(prompt_strings,
                                     **output_kwargs["text_kwargs"])
        return BatchFeature(data={
            **text_inputs,
            **image_inputs
        },
                            tensor_type=return_tensors)


class TarsierMultiModalProjector(nn.Module):

    def __init__(self,
                 vision_hidden_size: int,
                 text_hidden_size: int,
                 projector_hidden_act: str,
                 multimodal_projector_bias: bool,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()

        self.linear_1 = ColumnParallelLinear(vision_hidden_size,
                                             text_hidden_size,
                                             bias=multimodal_projector_bias,
                                             quant_config=quant_config,
                                             prefix=f"{prefix}.linear_1")
        self.act = get_act_fn(projector_hidden_act)
        self.linear_2 = RowParallelLinear(text_hidden_size,
                                          text_hidden_size,
                                          bias=multimodal_projector_bias,
                                          quant_config=quant_config,
                                          prefix=f"{prefix}.linear_2")

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        hidden_states, _ = self.linear_1(image_features)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.linear_2(hidden_states)
        return hidden_states


class TarsierProcessingInfo(BaseProcessingInfo):

    def get_hf_config(self) -> TarsierHfConfig:
        return self.ctx.get_hf_config(HfLlavaConfig)

    def get_vision_encoder_info(self) -> VisionEncoderInfo:
        return get_vision_encoder_info(self.get_hf_config())

    def get_hf_processor(self, **kwargs: object) -> TarsierProcessor:
        hf_processor = self.ctx.get_hf_processor(TarsierProcessor, **kwargs)
        # Patch for patch_size if needed (copied from vLLM LLaVA)
        if hasattr(hf_processor,
                   'patch_size') and hf_processor.patch_size is None:
            patch_size = self.get_vision_encoder_info().get_patch_size()
            hf_processor.patch_size = patch_size
        return hf_processor

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        return {"image": None}

    def _apply_feature_select_strategy(
        self,
        strategy: str,
        encoder_num_image_tokens: int,
    ) -> int:
        if strategy == "default":
            return encoder_num_image_tokens - 1
        if strategy == "full":
            return encoder_num_image_tokens
        msg = f"Unexpected feature select strategy: {strategy!r}"
        raise NotImplementedError(msg)

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        hf_config = self.get_hf_config()
        vision_encoder_info = self.get_vision_encoder_info()
        num_projected_patches = self._apply_feature_select_strategy(
            hf_config.vision_feature_select_strategy,
            vision_encoder_info.get_num_image_tokens(
                image_width=image_width,
                image_height=image_height,
            ),
        )
        if num_projected_patches <= 0:
            default_size = self.get_image_size_with_most_features()
            num_projected_patches_default = self._apply_feature_select_strategy(
                hf_config.vision_feature_select_strategy,
                vision_encoder_info.get_num_image_tokens(
                    image_width=default_size.width,
                    image_height=default_size.height,
                ),
            )
            if num_projected_patches_default <= 0:
                raise ValueError(
                    "Could not determine a valid number of image patches.")
            num_projected_patches = num_projected_patches_default
        num_height_patches = int(math.sqrt(num_projected_patches))
        total_image_tokens_for_llm = num_projected_patches \
            + num_height_patches + 1
        return total_image_tokens_for_llm

    def get_image_size_with_most_features(self) -> ImageSize:
        vision_encoder_info = self.get_vision_encoder_info()
        width = height = vision_encoder_info.get_image_size()
        return ImageSize(width=width, height=height)

    def get_max_image_tokens(self) -> int:
        target_width, target_height = self.get_image_size_with_most_features()
        return self.get_num_image_tokens(
            image_width=target_width,
            image_height=target_height,
        )

    def get_image_newline_idx(self) -> int:
        return self.get_hf_config().image_newline_idx

    def get_image_new_idx(self) -> int:
        return self.get_hf_config().image_new_idx
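

# Illustrative note (not part of the original commit): a worked example of the
# token count computed by TarsierProcessingInfo.get_num_image_tokens above,
# assuming a CLIP ViT-L/14 tower at 336x336 input and the "default"
# feature-select strategy:
#   encoder tokens                     = (336 // 14) ** 2 + 1 = 577
#   after dropping the CLS token       = 576 projected patches
#   per-row image_newline tokens       = sqrt(576) = 24
#   trailing image_new token           = 1
#   tokens reserved per image in LLM   = 576 + 24 + 1 = 601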


_I_Tarsier = TypeVar("_I_Tarsier", bound=TarsierProcessingInfo)


class TarsierDummyInputsBuilder(LlavaDummyInputsBuilder[_I_Tarsier]):

    pass


class TarsierMultiModalProcessor(BaseMultiModalProcessor[_I_Tarsier]):

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        return dict(
            pixel_values=MultiModalFieldConfig.batched("image"),
            image_embeds=MultiModalFieldConfig.batched("image"),
        )

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> Sequence[PromptUpdate]:
        hf_config = self.info.get_hf_config()
        image_token_id = hf_config.image_token_index  # The <IMAGE> token ID

        def get_replacement(item_idx: int):
            images = mm_items.get_items(
                "image", (ImageEmbeddingItems, ImageProcessorItems))

            if isinstance(images, ImageEmbeddingItems):
                num_projected_patches = images.get_feature_size(item_idx)
                # This assumes num_projected_patches is a perfect square
                num_height_patches = int(math.sqrt(num_projected_patches))
                num_final_image_tokens = num_projected_patches \
                    + num_height_patches + 1
            else:
                image_size = images.get_image_size(item_idx)
                num_final_image_tokens = self.info.get_num_image_tokens(
                    image_width=image_size.width,
                    image_height=image_size.height,
                )

            return [image_token_id] * num_final_image_tokens

        return [
            PromptReplacement(
                modality="image",
                target=[image_token_id],  # Replace each single <IMAGE> token
                replacement=get_replacement,
            ),
        ]


def _build_tarsier_hf_info(
        ctx: InputProcessingContext) -> TarsierProcessingInfo:
    return TarsierProcessingInfo(ctx)


def _build_tarsier_hf_processor(
    info: _I_Tarsier,
    dummy_inputs: BaseDummyInputsBuilder[_I_Tarsier],
    *,
    cache: Optional[ProcessingCache] = None,
) -> BaseMultiModalProcessor:
    if isinstance(info, TarsierProcessingInfo):
        return TarsierMultiModalProcessor(
            info,
            dummy_inputs,
            cache=cache,
        )
    raise NotImplementedError(type(info))


def init_vision_tower_for_tarsier(
    hf_config: TarsierHfConfig,  # Use the Tarsier-specific config protocol
    quant_config: Optional[QuantizationConfig],
    *,
    require_post_norm: Optional[bool] = None,
    prefix: str = "",
) -> Union[CLIPVisionModel, SiglipVisionModel]:
    vision_config = hf_config.vision_config

    feature_layers = hf_config.vision_feature_layer
    base_num_hidden_layers = vision_config.num_hidden_layers

    def _get_layer_index(feature_layer_index: int,
                         num_hidden_layers_total: int) -> int:
        if feature_layer_index < 0:
            return num_hidden_layers_total + feature_layer_index + 1
        return feature_layer_index

    if isinstance(feature_layers, int):
        num_hidden_layers_to_init = _get_layer_index(feature_layers,
                                                     base_num_hidden_layers)
    elif isinstance(feature_layers, (list, tuple)):
        num_hidden_layers_to_init = max(
            _get_layer_index(idx, base_num_hidden_layers)
            for idx in feature_layers)
    else:
        raise TypeError(f"vision_feature_layer type: {type(feature_layers)}"
                        " is not supported")

    if isinstance(vision_config, CLIPVisionConfig):
        return CLIPVisionModel(
            vision_config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers_to_init,
            require_post_norm=require_post_norm,
            prefix=prefix,
        )
    elif isinstance(vision_config, SiglipVisionConfig):
        return SiglipVisionModel(
            vision_config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers_to_init,
            require_post_norm=require_post_norm,
            prefix=prefix,
        )

    msg = f"Unsupported vision config for Tarsier: {type(vision_config)}"
    raise NotImplementedError(msg)


@MULTIMODAL_REGISTRY.register_processor(_build_tarsier_hf_processor,
                                        info=_build_tarsier_hf_info,
                                        dummy_inputs=TarsierDummyInputsBuilder)
class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal,
                                      SupportsPP):
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"]
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__()
        config: TarsierHfConfig = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config  # Storing the Tarsier-specific HF config
        self.vision_tower = init_vision_tower_for_tarsier(
            config,
            quant_config,
            require_post_norm=False,
            prefix=maybe_prefix(prefix, "vision_tower"))
        projector_bias = getattr(config, "multimodal_projector_bias", True)

        self.multi_modal_projector = TarsierMultiModalProjector(
            vision_hidden_size=config.vision_config.hidden_size,
            text_hidden_size=config.text_config.hidden_size,
            projector_hidden_act=config.projector_hidden_act,
            multimodal_projector_bias=projector_bias,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "multi_modal_projector"))
        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,  # text_config from Tarsier's main config
            prefix=maybe_prefix(prefix, "language_model"),
        )
        self.register_buffer('image_newline_idx_tensor',
                             torch.tensor([config.image_newline_idx],
                                          dtype=torch.long),
                             persistent=False)
        self.register_buffer('image_new_idx_tensor',
                             torch.tensor([config.image_new_idx],
                                          dtype=torch.long),
                             persistent=False)

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)  # Assuming 3 channels
        actual_dims = tuple(data.shape[1:])

        if actual_dims != expected_dims:
            expected_expr = ("batch_size", *map(str, expected_dims))
            raise ValueError(
                f"The expected shape of pixel values is {expected_expr}. "
                f"You supplied {tuple(data.shape)}.")
        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[TarsierImageInputs]:
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            if not isinstance(pixel_values, (torch.Tensor, list)):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")

            return TarsierImagePixelInputs(
                type="pixel_values",
                pixel_values=self._validate_pixel_values(
                    flatten_bn(pixel_values, concat=True)),
            )

        if image_embeds is not None:
            if not isinstance(image_embeds, (torch.Tensor, list)):
                raise ValueError("Incorrect type of image embeddings. "
                                 f"Got type: {type(image_embeds)}")
            return TarsierImageEmbeddingInputs(
                type="image_embeds",
                data=flatten_bn(image_embeds, concat=True),
            )

        raise AssertionError("This line should be unreachable.")

    def _select_image_features(self, image_features: torch.Tensor, *,
                               strategy: str) -> torch.Tensor:
        if strategy == "default":
            return image_features[:, 1:]
        elif strategy == "full":
            return image_features
        raise ValueError(f"Unexpected select feature strategy: {strategy}")

    def _image_pixels_to_features(
        self,
        vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
        pixel_values: Union[torch.Tensor, list[torch.Tensor]],
    ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
        # From vLLM LLaVA, vision tower output handling
        image_hidden_states = vision_tower(pixel_values)
        if not isinstance(image_hidden_states, torch.Tensor):
            raise TypeError(
                f"image_hidden_states type: {type(image_hidden_states)}"
                " is not supported")

        def select_features_fn(leaf: torch.Tensor):
            return self._select_image_features(
                leaf,
                strategy=self.config.vision_feature_select_strategy,
            )

        selected_features = cast(
            Union[torch.Tensor, tuple[torch.Tensor, ...]],
            json_map_leaves(select_features_fn, image_hidden_states),
        )
        return selected_features

    def _add_tarsier_split_tokens(
            self, projected_image_features: torch.Tensor) -> torch.Tensor:
        """
        Implements Tarsier's `add_split_tokens` logic: append an image_newline
        embedding after each row of the patch grid and a single image_new
        embedding after the last row.
        """
        num_images, num_projected_patches, embed_dim = \
            projected_image_features.shape
        num_height_patches = int(math.sqrt(num_projected_patches))
        num_width_patches = num_projected_patches // num_height_patches
        device = projected_image_features.device
        embedding_layer = self.language_model.model.embed_tokens
        image_newline_emb = embedding_layer(
            self.image_newline_idx_tensor.to(device)).squeeze(0)
        image_new_emb = embedding_layer(
            self.image_new_idx_tensor.to(device)).squeeze(0)
        try:
            current_image_features_grid = projected_image_features.view(
                num_images, num_height_patches, num_width_patches, embed_dim)
        except RuntimeError as e:
            raise RuntimeError(
                "Cannot reshape projected_image_features"
                f" with shape {projected_image_features.shape} "
                f"to ({num_images}, {num_height_patches},"
                f" {num_width_patches}, {embed_dim}). "
                "Ensure num_projected_patches is compatible"
                " with a grid structure. "
                f"num_projected_patches={num_projected_patches}, "
                f"derived num_height_patches={num_height_patches}. ") from e

        image_newline_expanded = image_newline_emb.expand(
            (num_images, num_height_patches, 1, embed_dim))
        features_with_newlines = torch.cat(
            [current_image_features_grid, image_newline_expanded],
            dim=2  # Concatenate along the width dim
        )
        new_num_patches_after_newline = num_projected_patches \
            + num_height_patches
        features_with_newlines_flat = features_with_newlines.view(
            num_images, new_num_patches_after_newline, embed_dim)
        image_new_expanded = image_new_emb.expand((num_images, 1, embed_dim))
        final_image_features = torch.cat(
            [features_with_newlines_flat, image_new_expanded],
            dim=1  # Concatenate along the patch sequence dim
        )
        return final_image_features
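
    # Illustrative walk-through (not part of the original commit): with 4
    # projected patches per image, _add_tarsier_split_tokens views them as a
    # 2x2 grid, appends one image_newline embedding to each of the 2 rows
    # (4 -> 6 tokens), then appends a single image_new embedding after the
    # last row (6 -> 7 tokens). For a 24x24 grid this gives
    # 576 + 24 + 1 = 601 tokens, matching
    # TarsierProcessingInfo.get_num_image_tokens earlier in the file.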

    def _process_image_pixels(
        self,
        inputs: TarsierImagePixelInputs,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
        assert self.vision_tower is not None
        pixel_values = inputs["pixel_values"]
        image_features_selected = self._image_pixels_to_features(
            self.vision_tower, pixel_values)  # type: ignore
        if isinstance(image_features_selected, torch.Tensor):
            projected_features = self.multi_modal_projector(
                image_features_selected)
            final_features = self._add_tarsier_split_tokens(projected_features)
            return final_features
        else:
            raise TypeError(
                f"_image_pixels_to_features type:"
                f" {type(image_features_selected)} is not supported")

    def _process_image_input(
        self,
        image_input: TarsierImageInputs,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
        if image_input["type"] == "image_embeds":
            projected_features = image_input["data"]
            if isinstance(projected_features, torch.Tensor):
                return self._add_tarsier_split_tokens(projected_features)
            else:
                raise ValueError("Incorrect type of image_embeds. "
                                 f"Got type: {type(projected_features)}. ")
        assert self.vision_tower is not None
        return self._process_image_pixels(image_input)

    def get_language_model(self) -> torch.nn.Module:
        return self.language_model

    def get_multimodal_embeddings(
            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return None
        return self._process_image_input(image_input)

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
    ) -> torch.Tensor:
        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
        if multimodal_embeddings is not None:
            inputs_embeds = merge_multimodal_embeddings(
                input_ids,
                inputs_embeds,
                multimodal_embeddings,
                self.config.image_token_index,
            )
        return inputs_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        if intermediate_tensors is not None:
            inputs_embeds = None
        elif inputs_embeds is None:
            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      vision_embeddings)
            input_ids = None
        hidden_states = self.language_model.model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)