diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md
index 2bffb28d31934..4e43db5b88fad 100644
--- a/docs/models/supported_models.md
+++ b/docs/models/supported_models.md
@@ -800,12 +800,13 @@ These models primarily support the [`LLM.embed`](./pooling_models.md#llmembed) A
 
 The following table lists those that are tested in vLLM.
 
-| Architecture | Models | Inputs | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) |
-|--------------|--------|--------|-------------------|----------------------|---------------------------|
-| `CLIPModel` | CLIP | T / I | `openai/clip-vit-base-patch32`, `openai/clip-vit-large-patch14`, etc. | | |
-| `LlavaNextForConditionalGeneration`C | LLaVA-NeXT-based | T / I | `royokong/e5-v` | | ✅︎ |
-| `Phi3VForCausalLM`C | Phi-3-Vision-based | T + I | `TIGER-Lab/VLM2Vec-Full` | | ✅︎ |
-| `*ForConditionalGeneration`C, `*ForCausalLM`C, etc. | Generative models | \* | N/A | \* | \* |
+| Architecture | Models | Inputs | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | [V1](gh-issue:8779) |
+|--------------|--------|--------|-------------------|----------------------|---------------------------|---------------------|
+| `CLIPModel` | CLIP | T / I | `openai/clip-vit-base-patch32`, `openai/clip-vit-large-patch14`, etc. | | | ✅︎ |
+| `LlavaNextForConditionalGeneration`C | LLaVA-NeXT-based | T / I | `royokong/e5-v` | | ✅︎ | ✅︎ |
+| `Phi3VForCausalLM`C | Phi-3-Vision-based | T + I | `TIGER-Lab/VLM2Vec-Full` | | ✅︎ | ✅︎ |
+| `SiglipModel` | SigLIP | T / I | `google/siglip-base-patch16-224` | | | ✅︎ |
+| `*ForConditionalGeneration`C, `*ForCausalLM`C, etc. | Generative models | \* | N/A | \* | \* | \* |
 
 C Automatically converted into an embedding model via `--convert embed`. ([details](./pooling_models.md#model-conversion))
 \* Feature support is the same as that of the original model.
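As a quick orientation for the new `SiglipModel` row (the fuller, supported version lives in the offline example added below), here is a minimal sketch of driving the model through `LLM.embed`. It assumes the `google/siglip-base-patch16-224` checkpoint from the table and vLLM's usual multimodal prompt dict; the image path is illustrative only, and text and images must be sent as separate requests because the processor added in this PR rejects mixed inputs.

```python
# Minimal sketch (not part of this PR's example scripts): offline SigLIP
# embeddings for one text prompt and one image-only prompt.
from PIL import Image

from vllm import LLM

llm = LLM(
    model="google/siglip-base-patch16-224",
    runner="pooling",
    limit_mm_per_prompt={"image": 1},
)

# Text-only request.
text_outputs = llm.embed("a photo of a cat")

# Image-only request: the prompt text must be empty; passing text and an
# image together raises a ValueError in the SigLIP processor.
image = Image.open("cat.jpg")  # illustrative path
image_outputs = llm.embed({"prompt": "", "multi_modal_data": {"image": image}})

print(len(text_outputs[0].outputs.embedding))
print(len(image_outputs[0].outputs.embedding))
```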
diff --git a/examples/offline_inference/vision_language_pooling.py b/examples/offline_inference/vision_language_pooling.py index 1ce2cdc436d6a..cf4695c2545fb 100644 --- a/examples/offline_inference/vision_language_pooling.py +++ b/examples/offline_inference/vision_language_pooling.py @@ -110,6 +110,53 @@ def run_e5_v(query: Query) -> ModelRequestData: ) +def run_jinavl_reranker(query: Query) -> ModelRequestData: + if query["modality"] != "text+images": + raise ValueError(f"Unsupported query modality: '{query['modality']}'") + + engine_args = EngineArgs( + model="jinaai/jina-reranker-m0", + runner="pooling", + max_model_len=32768, + trust_remote_code=True, + mm_processor_kwargs={ + "min_pixels": 3136, + "max_pixels": 602112, + }, + limit_mm_per_prompt={"image": 1}, + ) + + return ModelRequestData( + engine_args=engine_args, + query=query["text"], + documents=query["image"], + ) + + +def run_siglip(query: Query) -> ModelRequestData: + if query["modality"] == "text": + prompt = query["text"] + image = None + elif query["modality"] == "image": + prompt = "" # For image input, make sure that the prompt text is empty + image = query["image"] + else: + modality = query["modality"] + raise ValueError(f"Unsupported query modality: '{modality}'") + + engine_args = EngineArgs( + model="google/siglip-base-patch16-224", + runner="pooling", + limit_mm_per_prompt={"image": 1}, + ) + + return ModelRequestData( + engine_args=engine_args, + prompt=prompt, + image=image, + ) + + def _get_vlm2vec_prompt_image(query: Query, image_token: str): if query["modality"] == "text": text = query["text"] @@ -211,29 +258,6 @@ def run_vlm2vec_qwen2vl(query: Query) -> ModelRequestData: ) -def run_jinavl_reranker(query: Query) -> ModelRequestData: - if query["modality"] != "text+images": - raise ValueError(f"Unsupported query modality: '{query['modality']}'") - - engine_args = EngineArgs( - model="jinaai/jina-reranker-m0", - runner="pooling", - max_model_len=32768, - trust_remote_code=True, - mm_processor_kwargs={ - "min_pixels": 3136, - "max_pixels": 602112, - }, - limit_mm_per_prompt={"image": 1}, - ) - - return ModelRequestData( - engine_args=engine_args, - query=query["text"], - documents=query["image"], - ) - - def get_query(modality: QueryModality): if modality == "text": return TextQuery(modality="text", text="A dog sitting in the grass") @@ -328,9 +352,10 @@ def run_score(model: str, modality: QueryModality, seed: int | None): model_example_map = { "clip": run_clip, "e5_v": run_e5_v, + "jinavl_reranker": run_jinavl_reranker, + "siglip": run_siglip, "vlm2vec_phi3v": run_vlm2vec_phi3v, "vlm2vec_qwen2vl": run_vlm2vec_qwen2vl, - "jinavl_reranker": run_jinavl_reranker, } diff --git a/examples/online_serving/pooling/openai_chat_embedding_client_for_multimodal.py b/examples/online_serving/pooling/openai_chat_embedding_client_for_multimodal.py index 25ab865a4ee43..261b810ce5d03 100644 --- a/examples/online_serving/pooling/openai_chat_embedding_client_for_multimodal.py +++ b/examples/online_serving/pooling/openai_chat_embedding_client_for_multimodal.py @@ -83,6 +83,109 @@ def run_clip(client: OpenAI, model: str): print("Text embedding output:", response.data[0].embedding) +def run_dse_qwen2_vl(client: OpenAI, model: str): + """ + Start the server using: + + vllm serve MrLight/dse-qwen2-2b-mrl-v1 \ + --runner pooling \ + --trust-remote-code \ + --max-model-len 8192 \ + --chat-template examples/template_dse_qwen2_vl.jinja + """ + response = create_chat_embeddings( + client, + messages=[ + { + "role": "user", + "content": [ + { 
+ "type": "image_url", + "image_url": { + "url": image_url, + }, + }, + {"type": "text", "text": "What is shown in this image?"}, + ], + } + ], + model=model, + encoding_format="float", + ) + + print("Image embedding output:", response.data[0].embedding) + + # MrLight/dse-qwen2-2b-mrl-v1 requires a placeholder image + # of the minimum input size + buffer = io.BytesIO() + image_placeholder = Image.new("RGB", (56, 56)) + image_placeholder.save(buffer, "png") + buffer.seek(0) + image_placeholder = base64.b64encode(buffer.read()).decode("utf-8") + response = create_chat_embeddings( + client, + messages=[ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{image_placeholder}", + }, + }, + {"type": "text", "text": "Query: What is the weather like today?"}, + ], + } + ], + model=model, + encoding_format="float", + ) + + print("Text embedding output:", response.data[0].embedding) + + +def run_siglip(client: OpenAI, model: str): + """ + Start the server using: + + vllm serve google/siglip-base-patch16-224 \ + --runner pooling + """ + + response = create_chat_embeddings( + client, + messages=[ + { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": image_url}}, + ], + } + ], + model=model, + encoding_format="float", + ) + + print("Image embedding output:", response.data[0].embedding) + + response = create_chat_embeddings( + client, + messages=[ + { + "role": "user", + "content": [ + {"type": "text", "text": "a photo of a cat"}, + ], + } + ], + model=model, + encoding_format="float", + ) + + print("Text embedding output:", response.data[0].embedding) + + def run_vlm2vec(client: OpenAI, model: str): """ Start the server using: @@ -148,72 +251,11 @@ def run_vlm2vec(client: OpenAI, model: str): print("Text embedding output:", response.data[0].embedding) -def run_dse_qwen2_vl(client: OpenAI, model: str): - """ - Start the server using: - - vllm serve MrLight/dse-qwen2-2b-mrl-v1 \ - --runner pooling \ - --trust-remote-code \ - --max-model-len 8192 \ - --chat-template examples/template_dse_qwen2_vl.jinja - """ - response = create_chat_embeddings( - client, - messages=[ - { - "role": "user", - "content": [ - { - "type": "image_url", - "image_url": { - "url": image_url, - }, - }, - {"type": "text", "text": "What is shown in this image?"}, - ], - } - ], - model=model, - encoding_format="float", - ) - - print("Image embedding output:", response.data[0].embedding) - - # MrLight/dse-qwen2-2b-mrl-v1 requires a placeholder image - # of the minimum input size - buffer = io.BytesIO() - image_placeholder = Image.new("RGB", (56, 56)) - image_placeholder.save(buffer, "png") - buffer.seek(0) - image_placeholder = base64.b64encode(buffer.read()).decode("utf-8") - response = create_chat_embeddings( - client, - messages=[ - { - "role": "user", - "content": [ - { - "type": "image_url", - "image_url": { - "url": f"data:image/jpeg;base64,{image_placeholder}", - }, - }, - {"type": "text", "text": "Query: What is the weather like today?"}, - ], - } - ], - model=model, - encoding_format="float", - ) - - print("Text embedding output:", response.data[0].embedding) - - model_example_map = { "clip": run_clip, - "vlm2vec": run_vlm2vec, "dse_qwen2_vl": run_dse_qwen2_vl, + "siglip": run_siglip, + "vlm2vec": run_vlm2vec, } diff --git a/tests/models/multimodal/pooling/test_siglip.py b/tests/models/multimodal/pooling/test_siglip.py new file mode 100644 index 0000000000000..f681b4787b697 --- /dev/null +++ 
b/tests/models/multimodal/pooling/test_siglip.py @@ -0,0 +1,137 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +from transformers import SiglipModel + +from ....conftest import IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner +from ...utils import check_embeddings_close + +HF_TEXT_PROMPTS = [ + "a photo of a stop sign", + "a photo of a cherry blossom", +] + +HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts( + { + "stop_sign": "", + "cherry_blossom": "", + } +) + +MODELS = ["google/siglip-base-patch16-224"] + + +def _run_test( + hf_runner: type[HfRunner], + vllm_runner: type[VllmRunner], + input_texts: list[str], + input_images: PromptImageInput, + model: str, + *, + dtype: str, +) -> None: + with vllm_runner( + model, runner="pooling", dtype=dtype, enforce_eager=True, max_model_len=64 + ) as vllm_model: + vllm_outputs = vllm_model.embed(input_texts, images=input_images) + + with hf_runner(model, dtype=dtype, auto_cls=SiglipModel) as hf_model: + all_inputs = hf_model.get_inputs(input_texts, images=input_images) + + all_outputs = [] + for inputs in all_inputs: + inputs = hf_model.wrap_device(inputs) + + if "pixel_values" in inputs: + pooled_output = hf_model.model.get_image_features( + pixel_values=inputs.pixel_values, + ).squeeze(0) + else: + pooled_output = hf_model.model.get_text_features( + input_ids=inputs.input_ids, + ).squeeze(0) + + all_outputs.append(pooled_output.tolist()) + + hf_outputs = all_outputs + + check_embeddings_close( + embeddings_0_lst=hf_outputs, + embeddings_1_lst=vllm_outputs, + name_0="hf", + name_1="vllm", + ) + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +def test_models_text( + hf_runner, + vllm_runner, + image_assets, + model: str, + dtype: str, +) -> None: + input_texts_images = [(text, None) for text in HF_TEXT_PROMPTS] + input_texts = [text for text, _ in input_texts_images] + input_images = [image for _, image in input_texts_images] + + _run_test( + hf_runner, + vllm_runner, + input_texts, + input_images, # type: ignore + model, + dtype=dtype, + ) + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +def test_models_image( + hf_runner, + vllm_runner, + image_assets, + model: str, + dtype: str, +) -> None: + input_texts_images = [ + (text, asset.pil_image) for text, asset in zip(HF_IMAGE_PROMPTS, image_assets) + ] + input_texts = [text for text, _ in input_texts_images] + input_images = [image for _, image in input_texts_images] + + _run_test( + hf_runner, + vllm_runner, + input_texts, + input_images, + model, + dtype=dtype, + ) + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +def test_models_text_image_no_crash( + vllm_runner, + image_assets, + model: str, + dtype: str, +) -> None: + texts = [HF_TEXT_PROMPTS[0]] + images = [image_assets[0].pil_image] + + with vllm_runner( + model, + runner="pooling", + dtype=dtype, + enforce_eager=True, + max_model_len=64, + ) as vllm_model: + with pytest.raises(ValueError, match="not both"): + vllm_model.embed(texts, images=images) + + vllm_model.embed(texts) + vllm_model.embed([""], images=images) diff --git a/tests/models/registry.py b/tests/models/registry.py index f5072ef7cbadc..8e11ee755bf7b 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -471,6 +471,7 @@ _EMBEDDING_EXAMPLE_MODELS = { "TIGER-Lab/VLM2Vec-Full", trust_remote_code=True ), "Qwen2VLForConditionalGeneration": 
_HfExamplesInfo("MrLight/dse-qwen2-2b-mrl-v1"), + "SiglipModel": _HfExamplesInfo("google/siglip-base-patch16-224"), "PrithviGeoSpatialMAE": _HfExamplesInfo( "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11", dtype="float16", diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 9879c5ba58015..81d4a6bc5f3a7 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -209,6 +209,7 @@ _EMBEDDING_MODELS = { ), "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"), "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"), # noqa: E501 + "SiglipModel": ("siglip", "SiglipEmbeddingModel"), # Technically Terratorch models work on images, both in # input and output. I am adding it here because it piggy-backs on embedding # models for the time being. diff --git a/vllm/model_executor/models/siglip.py b/vllm/model_executor/models/siglip.py index b79dc31cfe3d4..694e06f9fc811 100644 --- a/vllm/model_executor/models/siglip.py +++ b/vllm/model_executor/models/siglip.py @@ -4,13 +4,23 @@ within a vision language model.""" import math -from collections.abc import Iterable +from collections.abc import Iterable, Mapping +from functools import cached_property +from typing import Annotated, Literal import torch from torch import nn -from transformers import SiglipVisionConfig +from transformers import ( + BatchFeature, + SiglipConfig, + SiglipProcessor, + SiglipTextConfig, + SiglipVisionConfig, +) from vllm.attention.layer import MultiHeadAttention +from vllm.config import VllmConfig +from vllm.config.multimodal import BaseDummyOptions from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import ( @@ -18,20 +28,232 @@ from vllm.model_executor.layers.linear import ( QKVParallelLinear, RowParallelLinear, ) +from vllm.model_executor.layers.pooler import DispatchPooler, Pooler from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name, ) +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalInputs, + MultiModalKwargsItems, + MultiModalUUIDDict, +) +from vllm.multimodal.parse import ImageProcessorItems, ImageSize, MultiModalDataItems +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + PromptIndexTargets, + PromptReplacement, + PromptUpdate, +) +from vllm.multimodal.profiling import BaseDummyInputsBuilder +from vllm.sequence import IntermediateTensors +from vllm.utils.tensor_schema import TensorSchema, TensorShape +from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsQuant +from .interfaces_base import default_pooling_type +from .utils import AutoWeightsLoader, maybe_prefix from .vision import ( VisionEncoderInfo, VisionFeatureSelectStrategy, + VisionFeatureSelectStrategyStr, + get_num_selected_vision_tokens, resolve_visual_encoder_outputs, ) +class SiglipImagePixelInputs(TensorSchema): + """ + Dimensions: + - bn: Batch size * number of images + - c: Number of channels (3) + - h: Height of each image + - w: Width of each image + """ + + type: Literal["pixel_values"] + data: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")] + + 
+_POOLING_TYPE_TO_STRATEGY: dict[str, VisionFeatureSelectStrategyStr] = { + "MEAN": "full", + "ALL": "full", + "CLS": "class", +} + + +def _get_vision_feature_select_strategy( + pooling_type: str, +) -> VisionFeatureSelectStrategyStr: + try: + return _POOLING_TYPE_TO_STRATEGY[pooling_type] + except KeyError: + raise ValueError( + f"No feature selection strategy is defined for " + f"pooling_type: {pooling_type!r}" + ) from None + + +class SiglipProcessingInfo(BaseProcessingInfo): + def get_hf_config(self): + return self.ctx.get_hf_config(SiglipConfig) + + def get_vision_encoder_info(self): + return SiglipEncoderInfo(self.get_hf_config()) + + def get_hf_processor(self, **kwargs: object): + return self.ctx.get_hf_processor(SiglipProcessor, **kwargs) + + def get_supported_mm_limits(self) -> Mapping[str, int | None]: + return {"image": 1} + + def get_num_image_tokens( + self, + *, + image_width: int, + image_height: int, + ) -> int: + vision_encoder_info = self.get_vision_encoder_info() + + pooler_config = self.ctx.model_config.pooler_config + assert pooler_config is not None + + return get_num_selected_vision_tokens( + vision_encoder_info.get_num_image_tokens( + image_width=image_width, + image_height=image_height, + ), + _get_vision_feature_select_strategy(pooler_config.pooling_type), + ) + + def get_image_size_with_most_features(self) -> ImageSize: + vision_encoder_info = self.get_vision_encoder_info() + width = height = vision_encoder_info.get_image_size() + return ImageSize(width=width, height=height) + + def get_max_image_tokens(self) -> int: + target_width, target_height = self.get_image_size_with_most_features() + + return self.get_num_image_tokens( + image_width=target_width, image_height=target_height + ) + + +class SiglipDummyInputsBuilder(BaseDummyInputsBuilder[SiglipProcessingInfo]): + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + return "" + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + mm_options: Mapping[str, BaseDummyOptions] | None = None, + ) -> MultiModalDataDict: + num_images = mm_counts.get("image", 0) + + target_width, target_height = self.info.get_image_size_with_most_features() + + image_overrides = mm_options.get("image") if mm_options else None + + return { + "image": self._get_dummy_images( + width=target_width, + height=target_height, + num_images=num_images, + overrides=image_overrides, + ) + } + + +class SiglipMultiModalProcessor(BaseMultiModalProcessor[SiglipProcessingInfo]): + @cached_property + def image_token_id(self) -> int: + tokenizer = self.info.get_tokenizer() + dummy_token_id = 0 + + assert dummy_token_id not in tokenizer.all_special_ids + + return dummy_token_id + + def apply( + self, + prompt: str | list[int], + mm_data: MultiModalDataDict, + hf_processor_mm_kwargs: Mapping[str, object], + tokenization_kwargs: Mapping[str, object] | None = None, + *, + mm_uuids: MultiModalUUIDDict | None = None, + ) -> MultiModalInputs: + if prompt and mm_data: + raise ValueError( + "Siglip accepts text-only or image-only inputs, not both! " + "Image-only inputs means passing an image with an empty text " + "prompt." 
+ ) + + if mm_data: + # For multi-modal data, the prompt after processing should + # only contain the image token + tokenization_kwargs = { + **(tokenization_kwargs or {}), + "add_special_tokens": False, + } + + return super().apply( + prompt=prompt, + mm_data=mm_data, + hf_processor_mm_kwargs=hf_processor_mm_kwargs, + tokenization_kwargs=tokenization_kwargs, + mm_uuids=mm_uuids, + ) + + def _hf_processor_applies_updates( + self, + prompt_text: str, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + tokenization_kwargs: Mapping[str, object], + ) -> bool: + return False + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict(pixel_values=MultiModalFieldConfig.batched("image")) + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargsItems, + ) -> list[PromptUpdate]: + image_token_id = self.image_token_id + + def get_replacement(item_idx: int): + images = mm_items.get_items("image", ImageProcessorItems) + image_size = images.get_image_size(item_idx) + + num_image_tokens = self.info.get_num_image_tokens( + image_width=image_size.width, image_height=image_size.height + ) + return [image_token_id] * num_image_tokens + + return [ + PromptReplacement( + modality="image", + target=PromptIndexTargets.start(), + replacement=get_replacement, + ), + ] + + class SiglipEncoderInfo(VisionEncoderInfo[SiglipVisionConfig]): def get_num_image_tokens( self, @@ -151,8 +373,9 @@ class SiglipVisionEmbeddings(nn.Module): class SiglipAttention(nn.Module): def __init__( self, - config: SiglipVisionConfig, + config: SiglipVisionConfig | SiglipTextConfig, quant_config: QuantizationConfig | None = None, + *, prefix: str = "", ) -> None: super().__init__() @@ -195,12 +418,29 @@ class SiglipAttention(nn.Module): def forward( self, hidden_states: torch.Tensor, - ) -> torch.Tensor: + ) -> tuple[torch.Tensor, None]: """Input shape: Batch x Time x Channel""" qkv_states, _ = self.qkv_proj(hidden_states) query_states, key_states, value_states = qkv_states.chunk(3, dim=-1) + needs_unsqueeze = query_states.ndim == 2 + if needs_unsqueeze: + query_states, key_states, value_states = ( + query_states.unsqueeze(0), + key_states.unsqueeze(0), + value_states.unsqueeze(0), + ) + out = self.attn(query_states, key_states, value_states) + + if needs_unsqueeze: + out, query_states, key_states, value_states = ( + out.squeeze(0), + query_states.squeeze(0), + key_states.squeeze(0), + value_states.squeeze(0), + ) + attn_output, _ = self.out_proj(out) return attn_output, None @@ -209,7 +449,7 @@ class SiglipAttention(nn.Module): class SiglipMLP(nn.Module): def __init__( self, - config: SiglipVisionConfig, + config: SiglipVisionConfig | SiglipTextConfig, quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: @@ -249,8 +489,9 @@ class SiglipMLP(nn.Module): class SiglipEncoderLayer(nn.Module): def __init__( self, - config: SiglipVisionConfig, + config: SiglipVisionConfig | SiglipTextConfig, quant_config: QuantizationConfig | None = None, + *, prefix: str = "", ) -> None: super().__init__() @@ -291,9 +532,10 @@ class SiglipEncoderLayer(nn.Module): class SiglipEncoder(nn.Module): def __init__( self, - config: SiglipVisionConfig, + config: SiglipVisionConfig | SiglipTextConfig, quant_config: QuantizationConfig | None = None, num_hidden_layers_override: int | None = None, + *, prefix: str = 
"", ) -> None: super().__init__() @@ -335,6 +577,76 @@ class SiglipEncoder(nn.Module): return hidden_states +class SiglipTextTransformer(nn.Module): + def __init__( + self, + config: SiglipTextConfig, + quant_config: QuantizationConfig | None = None, + *, + prefix: str = "", + ) -> None: + super().__init__() + + self.config = config + embed_dim = config.hidden_size + + self.embeddings = SiglipTextEmbeddings(config) + + self.encoder = SiglipEncoder( + config=config, + quant_config=quant_config, + prefix=f"{prefix}.encoder", + ) + + self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + self.head = nn.Linear(embed_dim, config.projection_size) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embeddings.token_embedding(input_ids) + + def forward( + self, + input_ids: torch.Tensor | None, + position_ids: torch.Tensor, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor: + hidden_states = self.embeddings(input_ids, position_ids, inputs_embeds) + + last_hidden_state = self.encoder( + inputs_embeds=hidden_states, return_all_hidden_states=False + ) + + last_hidden_state = self.final_layer_norm(last_hidden_state) + + return last_hidden_state + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + + for name, loaded_weight in weights: + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + class SiglipMultiheadAttentionPoolingHead(nn.Module): """Multihead Attention Pooling.""" @@ -357,8 +669,9 @@ class SiglipMultiheadAttentionPoolingHead(nn.Module): ) def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: - batch_size = hidden_state.shape[0] - probe = self.probe.repeat(batch_size, 1, 1) + batch_size = hidden_state.size(0) + + probe = self.probe.expand(batch_size, -1, -1) hidden_state = self.attention(probe, hidden_state, hidden_state)[0] @@ -367,7 +680,9 @@ class SiglipMultiheadAttentionPoolingHead(nn.Module): hidden_state = self.mlp(hidden_state) hidden_state += residual - return hidden_state[:, 0] + pooled = hidden_state[:, 0] + + return pooled.unsqueeze(1) class SiglipVisionTransformer(nn.Module): @@ -420,6 +735,14 @@ class SiglipVisionTransformer(nn.Module): prefix=f"{prefix}.head", ) + @property + def dtype(self): + return next(self.parameters()).dtype + + @property + def device(self): + return next(self.parameters()).device + def forward( self, pixel_values: torch.Tensor, @@ -432,7 +755,6 @@ class SiglipVisionTransformer(nn.Module): pixel_values, interpolate_pos_encoding=interpolate_pos_encoding, ) - # Produces either the last layer output or all of the hidden states, # depending on if we have select_layers or not encoder_outputs = self.encoder( @@ -440,21 +762,60 @@ class SiglipVisionTransformer(nn.Module): return_all_hidden_states=select_layers is not None, ) - # Handle post-norm (if applicable) and stacks feature layers if 
needed + if self.post_layernorm is not None: + encoder_outputs = self.post_layernorm(encoder_outputs) + + if self.use_head: + encoder_outputs = self.head(encoder_outputs) + + # stacks feature layers if needed encoder_outputs = resolve_visual_encoder_outputs( encoder_outputs, - self.post_layernorm, + None, select_layers=select_layers, max_possible_layers=self.config.num_hidden_layers, feature_select_strategy=feature_select_strategy, ) - # TODO: add this back when pooled_output is used in inference. - # if self.use_head: - # pooled_output = self.head(encoder_outputs) - return encoder_outputs + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + layer_count = len(self.encoder.layers) + + for name, loaded_weight in weights: + # post_layernorm is not needed in SiglipVisionTransformer + if name.startswith("post_layernorm") and self.post_layernorm is None: + continue + + # omit layers when num_hidden_layers_override is set + if name.startswith("encoder.layers"): + layer_idx = int(name.split(".")[2]) + if layer_idx >= layer_count: + continue + + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + class SiglipVisionModel(nn.Module): config_class = SiglipVisionConfig @@ -484,7 +845,11 @@ class SiglipVisionModel(nn.Module): @property def dtype(self): - return self.get_input_embeddings().weight.dtype + return self.vision_model.dtype + + @property + def device(self): + return self.vision_model.device def forward( self, @@ -555,3 +920,214 @@ class SiglipVisionModel(nn.Module): weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params + + +# Adapted from: https://github.com/huggingface/transformers/blob/v4.54.1/src/transformers/models/siglip/modeling_siglip.py#L200 +class SiglipTextEmbeddings(nn.Module): + def __init__(self, config: SiglipTextConfig): + super().__init__() + self.config = config + + self.token_embedding = VocabParallelEmbedding( + config.vocab_size, config.hidden_size + ) + + self.position_embedding = VocabParallelEmbedding( + config.max_position_embeddings, config.hidden_size + ) + + self.register_buffer( + "position_ids", + torch.arange(config.max_position_embeddings).expand((1, -1)), + persistent=False, + ) + + def forward( + self, + input_ids: torch.Tensor | None, + position_ids: torch.Tensor, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor: + if inputs_embeds is None: + inputs_embeds = self.token_embedding(input_ids) + + position_embeddings = self.position_embedding(position_ids) + embeddings = inputs_embeds + position_embeddings + return embeddings + + +# Assume EOS token corresponds to CLS token in text model +@default_pooling_type("CLS") +@MULTIMODAL_REGISTRY.register_processor( + SiglipMultiModalProcessor, + info=SiglipProcessingInfo, + dummy_inputs=SiglipDummyInputsBuilder, +) +class SiglipEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant): + 
is_pooling_model = True + + packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]} + merge_by_field_config = True + + @classmethod + def get_placeholder_str(cls, modality: str, i: int) -> str | None: + if modality.startswith("image"): + return None + + raise ValueError("Only image modality is supported") + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config: SiglipConfig = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config + self.config = config + self.multimodal_config = multimodal_config + + if hasattr(config, "num_labels"): + config.num_labels = 0 + + text_config = config.text_config + vision_config = config.vision_config + + self.text_embed_dim = text_config.hidden_size + self.vision_embed_dim = vision_config.hidden_size + + self.text_model = SiglipTextTransformer( + text_config, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "text_model"), + ) + self.vision_model = SiglipVisionTransformer( + vision_config, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "vision_model"), + ) + + self.text_projection_size = text_config.projection_size + + pooler_config = vllm_config.model_config.pooler_config + assert pooler_config is not None + self.pooler_config = pooler_config + + self.pooler = DispatchPooler( + { + "token_embed": Pooler.for_token_embed(pooler_config), + "embed": Pooler.for_embed(pooler_config), + } + ) + + self._is_text_input = True + + def get_text_features( + self, + input_ids: torch.Tensor | None, + position_ids: torch.Tensor, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor: + last_hidden_state = self.text_model( + input_ids=input_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + ) + text_features = self.text_model.head(last_hidden_state) + # Flip to extract CLS token (first token after reversal) for pooling + text_features = text_features.flip(0) + return text_features + + def get_image_features( + self, + pixel_values: torch.Tensor, + feature_select_strategy: VisionFeatureSelectStrategy | None = None, + ) -> torch.Tensor: + if feature_select_strategy is None: + feature_select_strategy = _get_vision_feature_select_strategy( + self.pooler_config.pooling_type + ) + + pooled_output = self.vision_model( + pixel_values=pixel_values, + select_layers=None, + feature_select_strategy=feature_select_strategy, + ) + + return pooled_output + + def _parse_and_validate_image_input( + self, **kwargs: object + ) -> SiglipImagePixelInputs | None: + pixel_values = kwargs.pop("pixel_values", None) + if pixel_values is None: + return None + + expected_h = expected_w = self.config.vision_config.image_size + return SiglipImagePixelInputs( + type="pixel_values", + data=pixel_values, + resolve_bindings={"h": expected_h, "w": expected_w}, + ) + + def _process_image_inputs(self, inputs: SiglipImagePixelInputs) -> torch.Tensor: + pixel_values = inputs["data"] + + return self.get_image_features(pixel_values) + + def get_language_model(self) -> torch.nn.Module: + return self.text_model + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: MultiModalEmbeddings | None = None, + *, + is_multimodal: torch.Tensor | None = None, + handle_oov_mm_token: bool = False, + ) -> torch.Tensor: + self._is_text_input = ( + multimodal_embeddings is None or len(multimodal_embeddings) == 0 + ) + + if multimodal_embeddings is None or is_multimodal is None: + return 
super().get_input_embeddings(input_ids) + + return super().get_input_embeddings( + input_ids, + multimodal_embeddings=multimodal_embeddings, + is_multimodal=is_multimodal, + handle_oov_mm_token=handle_oov_mm_token, + ) + + def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is None: + return [] + + vision_embeddings = self._process_image_inputs(image_input) + return vision_embeddings + + def forward( + self, + input_ids: torch.Tensor | None, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + **kwargs: object, + ) -> torch.Tensor: + if intermediate_tensors is not None: + raise RuntimeError("PP is not supported for this model") + + # Multimodal inputs (image embeddings) + if not self._is_text_input: + return inputs_embeds + + return self.get_text_features(input_ids, positions, inputs_embeds) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): + loader = AutoWeightsLoader( + self, + skip_substrs=[".position_ids"], + ignore_unexpected_prefixes=["logit_scale.", "logit_bias."], + ) + + return loader.load_weights(weights) diff --git a/vllm/transformers_utils/chat_templates/registry.py b/vllm/transformers_utils/chat_templates/registry.py index dbb4ffb675b8b..3bdbe1d0a67b6 100644 --- a/vllm/transformers_utils/chat_templates/registry.py +++ b/vllm/transformers_utils/chat_templates/registry.py @@ -31,14 +31,15 @@ def _get_minicpmv_chat_template_fallback(tokenizer_name_or_path: str) -> Path | _MODEL_TYPE_TO_CHAT_TEMPLATE_FALLBACK: dict[str, ChatTemplatePath] = { "blip-2": CHAT_TEMPLATES_DIR / "template_blip2.jinja", - "clip": CHAT_TEMPLATES_DIR / "template_basic.jinja", "chameleon": CHAT_TEMPLATES_DIR / "template_basic.jinja", - "deepseek_vl_v2": CHAT_TEMPLATES_DIR / "template_deepseek_vl2.jinja", + "clip": CHAT_TEMPLATES_DIR / "template_basic.jinja", "deepseek_ocr": CHAT_TEMPLATES_DIR / "template_deepseek_ocr.jinja", + "deepseek_vl_v2": CHAT_TEMPLATES_DIR / "template_deepseek_vl2.jinja", "fuyu": CHAT_TEMPLATES_DIR / "template_fuyu.jinja", "minicpmv": _get_minicpmv_chat_template_fallback, "paligemma": CHAT_TEMPLATES_DIR / "template_basic.jinja", "qwen": _get_qwen_chat_template_fallback, + "siglip": CHAT_TEMPLATES_DIR / "template_basic.jinja", }
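With the `siglip` entry above falling back to `template_basic.jinja`, the chat-embeddings route used by `run_siglip` in the online client example works without a custom template. As a hypothetical follow-up (not part of this PR), the text and image vectors it prints can be compared with plain cosine similarity; calibrated SigLIP probabilities are not available here, since `SiglipEmbeddingModel` deliberately skips the checkpoint's `logit_scale`/`logit_bias` weights.

```python
# Hypothetical helper for consuming the embeddings returned by the server
# (e.g. the vectors printed by run_siglip above); names are illustrative.
import math


def cosine_similarity(a: list[float], b: list[float]) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(x * x for x in b))
    return dot / (norm_a * norm_b)


def rank_captions(
    image_embedding: list[float],
    caption_embeddings: dict[str, list[float]],
) -> list[tuple[float, str]]:
    """Return (score, caption) pairs for one image, best match first."""
    scored = [
        (cosine_similarity(image_embedding, embedding), caption)
        for caption, embedding in caption_embeddings.items()
    ]
    return sorted(scored, reverse=True)
```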