diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index fdfcf89d9ab34..60fe5b887952f 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -829,6 +829,7 @@ The following table lists those that are tested in vLLM. | Architecture | Models | Inputs | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | [V1](gh-issue:8779) | |--------------|--------|--------|-------------------|----------------------|---------------------------|---------------------| +| `CLIPModel` | CLIP | T / I | `openai/clip-vit-base-patch32`, `openai/clip-vit-large-patch14`, etc. | | | ✅︎ | | `LlavaNextForConditionalGeneration`C | LLaVA-NeXT-based | T / I | `royokong/e5-v` | | ✅︎ | ✅︎ | | `Phi3VForCausalLM`C | Phi-3-Vision-based | T + I | `TIGER-Lab/VLM2Vec-Full` | | ✅︎ | ✅︎ | | `*ForConditionalGeneration`C, `*ForCausalLM`C, etc. | Generative models | \* | N/A | \* | \* | \* | diff --git a/examples/offline_inference/vision_language_pooling.py b/examples/offline_inference/vision_language_pooling.py index 3d1daf4d19ff8..6f8679918c272 100644 --- a/examples/offline_inference/vision_language_pooling.py +++ b/examples/offline_inference/vision_language_pooling.py @@ -58,6 +58,30 @@ class ModelRequestData(NamedTuple): documents: Optional[ScoreMultiModalParam] = None +def run_clip(query: Query) -> ModelRequestData: + if query["modality"] == "text": + prompt = query["text"] + image = None + elif query["modality"] == "image": + prompt = "" # For image input, make sure that the prompt text is empty + image = query["image"] + else: + modality = query["modality"] + raise ValueError(f"Unsupported query modality: '{modality}'") + + engine_args = EngineArgs( + model="openai/clip-vit-base-patch32", + runner="pooling", + limit_mm_per_prompt={"image": 1}, + ) + + return ModelRequestData( + engine_args=engine_args, + prompt=prompt, + image=image, + ) + + def run_e5_v(query: Query) -> ModelRequestData: llama3_template = 
"<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n \n" # noqa: E501 @@ -146,7 +170,8 @@ def run_vlm2vec_qwen2vl(query: Query) -> ModelRequestData: processor = AutoProcessor.from_pretrained( model_id, - # `min_pixels` and `max_pixels` are deprecated + # `min_pixels` and `max_pixels` are deprecated for + # transformers `preprocessor_config.json` size={"shortest_edge": 3136, "longest_edge": 12845056}, ) processor.chat_template = load_chat_template( @@ -172,8 +197,10 @@ def run_vlm2vec_qwen2vl(query: Query) -> ModelRequestData: model=merged_path, runner="pooling", max_model_len=4096, - trust_remote_code=True, - mm_processor_kwargs={"num_crops": 4}, + mm_processor_kwargs={ + "min_pixels": 3136, + "max_pixels": 12845056, + }, limit_mm_per_prompt={"image": 1}, ) @@ -299,6 +326,7 @@ def run_score(model: str, modality: QueryModality, seed: Optional[int]): model_example_map = { + "clip": run_clip, "e5_v": run_e5_v, "vlm2vec_phi3v": run_vlm2vec_phi3v, "vlm2vec_qwen2vl": run_vlm2vec_qwen2vl, diff --git a/examples/online_serving/pooling/openai_chat_embedding_client_for_multimodal.py b/examples/online_serving/pooling/openai_chat_embedding_client_for_multimodal.py index 6e31c3836806f..16ac4378c6863 100644 --- a/examples/online_serving/pooling/openai_chat_embedding_client_for_multimodal.py +++ b/examples/online_serving/pooling/openai_chat_embedding_client_for_multimodal.py @@ -1,14 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project # ruff: noqa: E501 -"""Example Python client for multimodal embedding API using vLLM API server -NOTE: - start a supported multimodal embeddings model server with `vllm serve`, e.g. - vllm serve TIGER-Lab/VLM2Vec-Full \ - --runner pooling \ - --trust-remote-code \ - --max-model-len 4096 \ - --chat-template examples/template_vlm2vec_phi3v.jinja +"""Example Python client for multimodal embedding API using vLLM API server. 
+ +Refer to each `run_*` function for the command to run the server for that model. """ import argparse @@ -47,7 +42,58 @@ def create_chat_embeddings( ) +def run_clip(client: OpenAI, model: str): + """ + Start the server using: + + vllm serve openai/clip-vit-base-patch32 \ + --runner pooling + """ + + response = create_chat_embeddings( + client, + messages=[ + { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": image_url}}, + ], + } + ], + model=model, + encoding_format="float", + ) + + print("Image embedding output:", response.data[0].embedding) + + response = create_chat_embeddings( + client, + messages=[ + { + "role": "user", + "content": [ + {"type": "text", "text": "a photo of a cat"}, + ], + } + ], + model=model, + encoding_format="float", + ) + + print("Text embedding output:", response.data[0].embedding) + + def run_vlm2vec(client: OpenAI, model: str): + """ + Start the server using: + + vllm serve TIGER-Lab/VLM2Vec-Full \ + --runner pooling \ + --trust-remote-code \ + --max-model-len 4096 \ + --chat-template examples/template_vlm2vec_phi3v.jinja + """ + response = create_chat_embeddings( client, messages=[ @@ -103,6 +149,15 @@ def run_vlm2vec(client: OpenAI, model: str): def run_dse_qwen2_vl(client: OpenAI, model: str): + """ + Start the server using: + + vllm serve MrLight/dse-qwen2-2b-mrl-v1 \ + --runner pooling \ + --trust-remote-code \ + --max-model-len 8192 \ + --chat-template examples/template_dse_qwen2_vl.jinja + """ response = create_chat_embeddings( client, messages=[ @@ -156,6 +211,7 @@ def run_dse_qwen2_vl(client: OpenAI, model: str): model_example_map = { + "clip": run_clip, "vlm2vec": run_vlm2vec, "dse_qwen2_vl": run_dse_qwen2_vl, } diff --git a/tests/models/multimodal/pooling/test_clip.py b/tests/models/multimodal/pooling/test_clip.py new file mode 100644 index 0000000000000..0aaf6877c2a6f --- /dev/null +++ b/tests/models/multimodal/pooling/test_clip.py @@ -0,0 +1,138 @@ +# SPDX-License-Identifier: Apache-2.0 +# 
SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +from transformers import CLIPModel + +from ....conftest import IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner +from ...utils import check_embeddings_close + +HF_TEXT_PROMPTS = [ + "a photo of a stop sign", + "a photo of a cherry blossom", +] + +HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ + "stop_sign": "", + "cherry_blossom": "", +}) + +MODELS = ["openai/clip-vit-base-patch32"] + + +def _run_test( + hf_runner: type[HfRunner], + vllm_runner: type[VllmRunner], + input_texts: list[str], + input_images: PromptImageInput, + model: str, + *, + dtype: str, +) -> None: + # NOTE: take care of the order. run vLLM first, and then run HF. + # vLLM needs a fresh new process without cuda initialization. + # if we run HF first, the cuda initialization will be done and it + # will hurt multiprocessing backend with fork method (the default method). + with vllm_runner(model, + runner="pooling", + dtype=dtype, + enforce_eager=True, + max_model_len=77) as vllm_model: + vllm_outputs = vllm_model.embed(input_texts, images=input_images) + + with hf_runner(model, dtype=dtype, auto_cls=CLIPModel) as hf_model: + all_inputs = hf_model.get_inputs(input_texts, images=input_images) + + all_outputs = [] + for inputs in all_inputs: + if "pixel_values" in inputs: + inputs.pop("input_ids") + pooled_output = hf_model.model.get_image_features( + **hf_model.wrap_device(inputs)).squeeze(0) + else: + pooled_output = hf_model.model.get_text_features( + **hf_model.wrap_device(inputs)).squeeze(0) + + all_outputs.append(pooled_output.tolist()) + + hf_outputs = all_outputs + + check_embeddings_close( + embeddings_0_lst=hf_outputs, + embeddings_1_lst=vllm_outputs, + name_0="hf", + name_1="vllm", + ) + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +def test_models_text( + hf_runner, + vllm_runner, + image_assets, + model: str, + dtype: str, +) -> None: + input_texts_images = 
[(text, None) for text in HF_TEXT_PROMPTS] + input_texts = [text for text, _ in input_texts_images] + input_images = [image for _, image in input_texts_images] + + _run_test( + hf_runner, + vllm_runner, + input_texts, + input_images, # type: ignore + model, + dtype=dtype, + ) + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +def test_models_image( + hf_runner, + vllm_runner, + image_assets, + model: str, + dtype: str, +) -> None: + input_texts_images = [ + (text, asset.pil_image) + for text, asset in zip(HF_IMAGE_PROMPTS, image_assets) + ] + input_texts = [text for text, _ in input_texts_images] + input_images = [image for _, image in input_texts_images] + + _run_test( + hf_runner, + vllm_runner, + input_texts, + input_images, + model, + dtype=dtype, + ) + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +def test_models_text_image_no_crash( + vllm_runner, + image_assets, + model: str, + dtype: str, +) -> None: + texts = [HF_TEXT_PROMPTS[0]] + images = [image_assets[0].pil_image] + + with vllm_runner(model, + runner="pooling", + dtype=dtype, + enforce_eager=True, + max_model_len=77) as vllm_model: + with pytest.raises(ValueError, match="not both"): + vllm_model.embed(texts, images=images) + + # Should still be able to run subsequent requests + vllm_model.embed(texts) + vllm_model.embed([""], images=images) diff --git a/tests/models/registry.py b/tests/models/registry.py index 86a8359752278..182654cdf3c7b 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -389,6 +389,7 @@ _EMBEDDING_EXAMPLE_MODELS = { "RobertaForMaskedLM": _HfExamplesInfo("sentence-transformers/all-roberta-large-v1"), # noqa: E501 "XLMRobertaModel": _HfExamplesInfo("intfloat/multilingual-e5-small"), # noqa: E501 # [Multimodal] + "CLIPModel": _HfExamplesInfo("openai/clip-vit-base-patch32"), "LlavaNextForConditionalGeneration": _HfExamplesInfo("royokong/e5-v"), "Phi3VForCausalLM": 
_HfExamplesInfo("TIGER-Lab/VLM2Vec-Full", trust_remote_code=True), @@ -687,7 +688,11 @@ class HfExampleModels: return self.hf_models.keys() def get_hf_info(self, model_arch: str) -> _HfExamplesInfo: - return self.hf_models[model_arch] + try: + return self.hf_models[model_arch] + except KeyError: + raise ValueError(f"No example model defined for {model_arch}; " + f"please update this file.") from None def find_hf_info(self, model_id: str) -> _HfExamplesInfo: for info in self.hf_models.values(): @@ -699,7 +704,8 @@ class HfExampleModels: if any(extra == model_id for extra in info.extras.values()): return info - raise ValueError(f"No example model defined for {model_id}") + raise ValueError(f"No example model defined for {model_id}; " + f"please update this file.") HF_EXAMPLE_MODELS = HfExampleModels(_EXAMPLE_MODELS) diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index ac34f279d0b57..6632ee6b0dc35 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -417,12 +417,16 @@ class MultiHeadAttention(nn.Module): head_size: int, scale: float, num_kv_heads: Optional[int] = None, - ): + # This has no effect, it is only here to make it easier to swap + # between Attention and MultiHeadAttention + prefix: str = "", + ) -> None: super().__init__() self.num_heads = num_heads self.head_size = head_size self.scale = scale self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads + self.layer_name = prefix assert self.num_heads % self.num_kv_heads == 0, \ f"num_heads ({self.num_heads}) is not " \ diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py index 2ec3edc5a0a7a..10e7186671220 100644 --- a/vllm/model_executor/models/bert.py +++ b/vllm/model_executor/models/bert.py @@ -351,7 +351,7 @@ class BertModel(nn.Module, SupportsQuant): prefix=f"{prefix}.encoder") def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.embeddings(input_ids) + return 
self.embeddings.word_embeddings(input_ids) def forward( self, diff --git a/vllm/model_executor/models/clip.py b/vllm/model_executor/models/clip.py index 451da21200488..7ec366a2e4aa9 100644 --- a/vllm/model_executor/models/clip.py +++ b/vllm/model_executor/models/clip.py @@ -1,28 +1,63 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Minimal implementation of CLIPVisionModel intended to be only used -within a vision language model.""" -from collections.abc import Iterable -from typing import Optional, Union +from collections.abc import Iterable, Mapping, Sequence +from functools import cached_property +from typing import Annotated, Literal, Optional, Union import torch import torch.nn as nn -from transformers import CLIPVisionConfig +from transformers import (BatchFeature, CLIPConfig, CLIPProcessor, + CLIPTextConfig, CLIPVisionConfig) +from vllm.attention import Attention from vllm.attention.layer import MultiHeadAttention +from vllm.config import VllmConfig +from vllm.config.multimodal import BaseDummyOptions from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, QKVParallelLinear, RowParallelLinear) +from vllm.model_executor.layers.pooler import DispatchPooler, Pooler from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.interfaces import SupportsQuant +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, + MultiModalInputs, MultiModalKwargsItems, + MultiModalUUIDDict) +from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, + MultiModalDataItems) 
+from vllm.multimodal.processing import (BaseMultiModalProcessor, + BaseProcessingInfo, PromptIndexTargets, + PromptReplacement, PromptUpdate) +from vllm.multimodal.profiling import BaseDummyInputsBuilder +from vllm.sequence import IntermediateTensors +from vllm.utils.tensor_schema import TensorSchema, TensorShape +from .interfaces import MultiModalEmbeddings, SupportsMultiModal +from .interfaces_base import default_pooling_type +from .utils import AutoWeightsLoader, maybe_prefix from .vision import (VisionEncoderInfo, VisionFeatureSelectStrategy, + VisionFeatureSelectStrategyStr, + get_num_selected_vision_tokens, resolve_visual_encoder_outputs) +class CLIPImagePixelInputs(TensorSchema): + """ + Dimensions: + - bn: Batch size * number of images + - c: Number of channels (3) + - h: Height of each image + - w: Width of each image + """ + type: Literal["pixel_values"] + data: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")] + + class CLIPEncoderInfo(VisionEncoderInfo[CLIPVisionConfig]): def get_num_image_tokens( @@ -45,7 +80,214 @@ class CLIPEncoderInfo(VisionEncoderInfo[CLIPVisionConfig]): return image_size // patch_size -# Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/clip/modeling_clip.py#L164 # noqa +_POOLING_TYPE_TO_STRATEGY: dict[str, VisionFeatureSelectStrategyStr] = { + "MEAN": "full", + "ALL": "full", + "CLS": "class", + # This lets us use the same pooling type for both text and image + "LAST": "class", +} + + +def _get_vision_feature_select_strategy(pooling_type: str): + try: + return _POOLING_TYPE_TO_STRATEGY[pooling_type] + except KeyError: + raise ValueError(f"No feature selection strategy is defined for " + f"pooling_type: {pooling_type!r}") from None + + +class CLIPProcessingInfo(BaseProcessingInfo): + + def get_hf_config(self): + return self.ctx.get_hf_config(CLIPConfig) + + def get_vision_encoder_info(self): + return CLIPEncoderInfo(self.get_hf_config()) + + def get_hf_processor(self, **kwargs: 
object): + return self.ctx.get_hf_processor(CLIPProcessor, **kwargs) + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": 1} + + def get_num_image_tokens( + self, + *, + image_width: int, + image_height: int, + ) -> int: + vision_encoder_info = self.get_vision_encoder_info() + + pooler_config = self.ctx.model_config.pooler_config + assert pooler_config is not None + + return get_num_selected_vision_tokens( + vision_encoder_info.get_num_image_tokens( + image_width=image_width, + image_height=image_height, + ), + _get_vision_feature_select_strategy(pooler_config.pooling_type), + ) + + def get_image_size_with_most_features(self) -> ImageSize: + vision_encoder_info = self.get_vision_encoder_info() + width = height = vision_encoder_info.get_image_size() + return ImageSize(width=width, height=height) + + def get_max_image_tokens(self) -> int: + target_width, target_height = self.get_image_size_with_most_features() + + return self.get_num_image_tokens( + image_width=target_width, + image_height=target_height, + ) + + +class CLIPDummyInputsBuilder(BaseDummyInputsBuilder[CLIPProcessingInfo]): + + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + return "" + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, + ) -> MultiModalDataDict: + num_images = mm_counts.get("image", 0) + + target_width, target_height = \ + self.info.get_image_size_with_most_features() + + image_overrides = mm_options.get("image") if mm_options else None + + return { + "image": + self._get_dummy_images(width=target_width, + height=target_height, + num_images=num_images, + overrides=image_overrides) + } + + +class CLIPMultiModalProcessor(BaseMultiModalProcessor[CLIPProcessingInfo]): + + @cached_property + def image_token_id(self) -> int: + tokenizer = self.info.get_tokenizer() + dummy_token_id = 0 + + assert dummy_token_id not in tokenizer.all_special_ids + + 
return dummy_token_id + + def apply( + self, + prompt: Union[str, list[int]], + mm_data: MultiModalDataDict, + hf_processor_mm_kwargs: Mapping[str, object], + tokenization_kwargs: Optional[Mapping[str, object]] = None, + *, + mm_uuids: Optional[MultiModalUUIDDict] = None, + ) -> MultiModalInputs: + if prompt and mm_data: + raise ValueError( + "CLIP accepts text-only or image-only inputs, not both! " + "Image-only inputs means passing an image with an empty text " + "prompt.") + + if mm_data: + # For multi-modal data, the prompt after processing should + # only contain the dummy image tokens + tokenization_kwargs = { + **(tokenization_kwargs or {}), + "add_special_tokens": False, + } + + return super().apply( + prompt=prompt, + mm_data=mm_data, + hf_processor_mm_kwargs=hf_processor_mm_kwargs, + tokenization_kwargs=tokenization_kwargs, + mm_uuids=mm_uuids, + ) + + def _hf_processor_applies_updates( + self, + prompt_text: str, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + tokenization_kwargs: Mapping[str, object], + ) -> bool: + return False + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict(pixel_values=MultiModalFieldConfig.batched("image")) + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargsItems, + ) -> Sequence[PromptUpdate]: + image_token_id = self.image_token_id + + def get_replacement(item_idx: int): + images = mm_items.get_items("image", ImageProcessorItems) + image_size = images.get_image_size(item_idx) + + num_image_tokens = self.info.get_num_image_tokens( + image_width=image_size.width, + image_height=image_size.height, + ) + return [image_token_id] * num_image_tokens + + return [ + PromptReplacement( + modality="image", + target=PromptIndexTargets.start(), + replacement=get_replacement, + ), + ] + + +# 
Adapted from: https://github.com/huggingface/transformers/blob/v4.56.2/src/transformers/models/clip/modeling_clip.py +class CLIPTextEmbeddings(nn.Module): + + def __init__(self, config: CLIPTextConfig): + super().__init__() + + embed_dim = config.hidden_size + + self.token_embedding = VocabParallelEmbedding(config.vocab_size, + embed_dim) + self.position_embedding = VocabParallelEmbedding( + config.max_position_embeddings, embed_dim) + + def forward( + self, + input_ids: Optional[torch.Tensor], + position_ids: torch.Tensor, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if inputs_embeds is None: + if input_ids is None: + raise ValueError( + "Either `input_ids` or `input_embeds` must be provided") + + inputs_embeds = self.token_embedding(input_ids) + + position_embeddings = self.position_embedding(position_ids) + embeddings = inputs_embeds + position_embeddings + + return embeddings + + class CLIPVisionEmbeddings(nn.Module): def __init__(self, config: CLIPVisionConfig): @@ -89,15 +331,17 @@ class CLIPVisionEmbeddings(nn.Module): class CLIPAttention(nn.Module): - """Multi-headed attention from 'Attention Is All You Need' paper""" def __init__( self, - config: CLIPVisionConfig, + config: Union[CLIPTextConfig, CLIPVisionConfig], quant_config: Optional[QuantizationConfig] = None, + *, prefix: str = "", - ): + attn_cls: Union[type[Attention], type[MultiHeadAttention]], + ) -> None: super().__init__() + self.config = config self.embed_dim = config.hidden_size self.num_heads = config.num_attention_heads @@ -127,8 +371,12 @@ class CLIPAttention(nn.Module): self.tp_size = get_tensor_model_parallel_world_size() self.num_heads_per_partition = divide(self.num_heads, self.tp_size) - self.attn = MultiHeadAttention(self.num_heads_per_partition, - self.head_dim, self.scale) + self.attn = attn_cls( + self.num_heads_per_partition, + self.head_dim, + self.scale, + prefix=f"{prefix}.attn", + ) def forward( self, @@ -148,7 +396,7 @@ class CLIPMLP(nn.Module): def 
__init__( self, - config: CLIPVisionConfig, + config: Union[CLIPTextConfig, CLIPVisionConfig], quant_config: Optional[QuantizationConfig] = None, prefix: str = "", ) -> None: @@ -178,15 +426,18 @@ class CLIPEncoderLayer(nn.Module): def __init__( self, - config: CLIPVisionConfig, + config: Union[CLIPTextConfig, CLIPVisionConfig], quant_config: Optional[QuantizationConfig] = None, + *, prefix: str = "", + attn_cls: Union[type[Attention], type[MultiHeadAttention]], ) -> None: super().__init__() self.self_attn = CLIPAttention( config, quant_config=quant_config, prefix=f"{prefix}.self_attn", + attn_cls=attn_cls, ) self.layer_norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) @@ -223,10 +474,12 @@ class CLIPEncoder(nn.Module): def __init__( self, - config: CLIPVisionConfig, + config: Union[CLIPTextConfig, CLIPVisionConfig], quant_config: Optional[QuantizationConfig] = None, num_hidden_layers_override: Optional[int] = None, + *, prefix: str = "", + attn_cls: Union[type[Attention], type[MultiHeadAttention]], ) -> None: super().__init__() @@ -239,12 +492,15 @@ class CLIPEncoder(nn.Module): self.layers = nn.ModuleList([ CLIPEncoderLayer(config=config, quant_config=quant_config, - prefix=f"{prefix}.layers.{layer_idx}") + prefix=f"{prefix}.layers.{layer_idx}", + attn_cls=attn_cls) for layer_idx in range(num_hidden_layers) ]) def forward( - self, inputs_embeds: torch.Tensor, return_all_hidden_states: bool + self, + inputs_embeds: torch.Tensor, + return_all_hidden_states: bool, ) -> Union[torch.Tensor, list[torch.Tensor]]: hidden_states_pool = [inputs_embeds] hidden_states = inputs_embeds @@ -260,6 +516,87 @@ class CLIPEncoder(nn.Module): return hidden_states +class CLIPTextTransformer(nn.Module): + + def __init__( + self, + config: CLIPTextConfig, + quant_config: Optional[QuantizationConfig] = None, + *, + prefix: str = "", + ) -> None: + super().__init__() + + self.config = config + embed_dim = config.hidden_size + + self.embeddings = 
CLIPTextEmbeddings(config) + + self.encoder = CLIPEncoder( + config=config, + quant_config=quant_config, + prefix=f"{prefix}.encoder", + attn_cls=Attention, + ) + + self.final_layer_norm = nn.LayerNorm( + embed_dim, + eps=config.layer_norm_eps, + ) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embeddings.token_embedding(input_ids) + + def forward( + self, + input_ids: Optional[torch.Tensor], + position_ids: torch.Tensor, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + hidden_states = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + ) + + last_hidden_state = self.encoder( + inputs_embeds=hidden_states, + return_all_hidden_states=False, + ) + last_hidden_state = self.final_layer_norm(last_hidden_state) + + return last_hidden_state + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + + for name, loaded_weight in weights: + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + class CLIPVisionTransformer(nn.Module): def __init__( @@ -287,6 +624,7 @@ class CLIPVisionTransformer(nn.Module): quant_config=quant_config, num_hidden_layers_override=num_hidden_layers_override, prefix=f"{prefix}.encoder", + attn_cls=MultiHeadAttention, ) num_hidden_layers = 
config.num_hidden_layers @@ -306,6 +644,14 @@ class CLIPVisionTransformer(nn.Module): else: self.post_layernorm = None + @property + def dtype(self): + return next(self.parameters()).dtype + + @property + def device(self): + return next(self.parameters()).device + def forward( self, pixel_values: torch.Tensor, @@ -335,47 +681,6 @@ class CLIPVisionTransformer(nn.Module): return encoder_outputs - -class CLIPVisionModel(nn.Module, SupportsQuant): - config_class = CLIPVisionConfig - main_input_name = "pixel_values" - packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]} - - def __init__( - self, - config: CLIPVisionConfig, - quant_config: Optional[QuantizationConfig] = None, - *, - num_hidden_layers_override: Optional[int] = None, - require_post_norm: Optional[bool] = None, - prefix: str = "", - ) -> None: - super().__init__() - self.vision_model = CLIPVisionTransformer( - config=config, - quant_config=quant_config, - num_hidden_layers_override=num_hidden_layers_override, - require_post_norm=require_post_norm, - prefix=f"{prefix}.vision_model") - - def forward( - self, - pixel_values: torch.Tensor, - select_layers: Optional[list[int]] = None, - feature_select_strategy: Optional[VisionFeatureSelectStrategy] = None, - ) -> torch.Tensor: - return self.vision_model( - pixel_values, - select_layers=select_layers, - feature_select_strategy=feature_select_strategy, - ) - - @property - def device(self): - return next(self.parameters()).device - - # (TODO) Add prefix argument for filtering out weights to be loaded - # ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986 def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ @@ -386,17 +691,17 @@ class CLIPVisionModel(nn.Module, SupportsQuant): ] params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() - layer_count = len(self.vision_model.encoder.layers) + layer_count = len(self.encoder.layers) for name, loaded_weight 
in weights: # post_layernorm is not needed in CLIPVisionModel - if (name.startswith("vision_model.post_layernorm") - and self.vision_model.post_layernorm is None): + if (name.startswith("post_layernorm") + and self.post_layernorm is None): continue # omit layers when num_hidden_layers_override is set - if name.startswith("vision_model.encoder.layers"): - layer_idx = int(name.split(".")[3]) + if name.startswith("encoder.layers"): + layer_idx = int(name.split(".")[2]) if layer_idx >= layer_count: continue @@ -416,3 +721,233 @@ class CLIPVisionModel(nn.Module, SupportsQuant): weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params + + +class CLIPVisionModel(nn.Module): + + def __init__( + self, + config: CLIPVisionConfig, + quant_config: Optional[QuantizationConfig] = None, + *, + num_hidden_layers_override: Optional[int] = None, + require_post_norm: Optional[bool] = None, + prefix: str = "", + ) -> None: + super().__init__() + + self.vision_model = CLIPVisionTransformer( + config=config, + quant_config=quant_config, + num_hidden_layers_override=num_hidden_layers_override, + require_post_norm=require_post_norm, + prefix=f"{prefix}.vision_model", + ) + + def forward( + self, + pixel_values: torch.Tensor, + select_layers: Optional[list[int]] = None, + feature_select_strategy: Optional[VisionFeatureSelectStrategy] = None, + ) -> torch.Tensor: + return self.vision_model( + pixel_values, + select_layers=select_layers, + feature_select_strategy=feature_select_strategy, + ) + + @property + def dtype(self): + return self.vision_model.dtype + + @property + def device(self): + return self.vision_model.device + + +# Assume EOS token corresponds to LAST token in text model +@default_pooling_type("LAST") +@MULTIMODAL_REGISTRY.register_processor(CLIPMultiModalProcessor, + info=CLIPProcessingInfo, + dummy_inputs=CLIPDummyInputsBuilder) +class CLIPEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant): + + is_pooling_model = True + + 
# Members of the CLIP embedding model class. The `class` statement itself sits
# before this span; the registry hunk preserved at the end of this block maps
# "CLIPModel" to ("clip", "CLIPEmbeddingModel"), so these are presumably that
# class's members (they rely on `self`/`cls` and zero-argument `super()`).

# Class-level attributes, values unchanged.
packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
merge_by_field_config = True


@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
    """Return the placeholder string for item ``i`` of ``modality``.

    Any modality whose name starts with ``"image"`` yields ``None``; every
    other modality is rejected. ``i`` is not consulted.

    Raises:
        ValueError: If ``modality`` does not start with ``"image"``.
    """
    if not modality.startswith("image"):
        raise ValueError("Only image modality is supported")

    return None


def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
    """Build the text and vision transformers, their projections, the pooler.

    Args:
        vllm_config: Supplies ``model_config.hf_config`` (with its
            ``text_config``, ``vision_config`` and ``projection_dim``),
            ``quant_config``, ``model_config.multimodal_config`` and
            ``model_config.pooler_config``.
        prefix: Combined with ``"text_model"`` / ``"vision_model"`` through
            ``maybe_prefix`` to name the two transformer submodules.

    Raises:
        AssertionError: If ``model_config.pooler_config`` is ``None``.
    """
    super().__init__()

    model_config = vllm_config.model_config
    hf_config: CLIPConfig = model_config.hf_config
    quant_cfg = vllm_config.quant_config
    text_cfg = hf_config.text_config
    vision_cfg = hf_config.vision_config

    self.config = hf_config
    self.multimodal_config = model_config.multimodal_config

    self.projection_dim = hf_config.projection_dim
    self.text_embed_dim = text_cfg.hidden_size
    self.vision_embed_dim = vision_cfg.hidden_size

    # Submodules are created in the same order as before: text transformer,
    # vision transformer, visual projection, text projection.
    self.text_model = CLIPTextTransformer(
        text_cfg,
        quant_config=quant_cfg,
        prefix=maybe_prefix(prefix, "text_model"),
    )
    self.vision_model = CLIPVisionTransformer(
        vision_cfg,
        quant_config=quant_cfg,
        prefix=maybe_prefix(prefix, "vision_model"),
    )

    # Both projections are bias-free and map into `projection_dim`.
    self.visual_projection = nn.Linear(self.vision_embed_dim,
                                       self.projection_dim,
                                       bias=False)
    self.text_projection = nn.Linear(self.text_embed_dim,
                                     self.projection_dim,
                                     bias=False)

    pooler_cfg = model_config.pooler_config
    assert pooler_cfg is not None
    self.pooler_config = pooler_cfg

    poolers_by_task = {
        "encode": Pooler.for_encode(pooler_cfg),
        "embed": Pooler.for_embed(pooler_cfg),
    }
    self.pooler = DispatchPooler(poolers_by_task)

    # `forward` reads this flag and `get_input_embeddings` writes it, so the
    # code assumes that self.forward is called after
    # self.get_input_embeddings.
    self._is_text_input = True


def get_text_features(
    self,
    input_ids: Optional[torch.Tensor],
    position_ids: torch.Tensor,
    inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Run ``text_model`` on the inputs and apply ``text_projection``.

    All three arguments are forwarded by keyword to ``self.text_model``.
    """
    text_output = self.text_model(
        input_ids=input_ids,
        position_ids=position_ids,
        inputs_embeds=inputs_embeds,
    )
    return self.text_projection(text_output)


def get_image_features(
    self,
    pixel_values: torch.Tensor,
    feature_select_strategy: Optional[VisionFeatureSelectStrategy] = None,
) -> torch.Tensor:
    """Run ``vision_model`` on ``pixel_values`` and apply ``visual_projection``.

    Args:
        pixel_values: Forwarded to ``self.vision_model``.
        feature_select_strategy: Passed to the vision model. When ``None``,
            it is obtained from ``_get_vision_feature_select_strategy``
            using ``self.pooler_config.pooling_type``.
    """
    strategy = feature_select_strategy
    if strategy is None:
        strategy = _get_vision_feature_select_strategy(
            self.pooler_config.pooling_type)

    vision_output = self.vision_model(
        pixel_values=pixel_values,
        select_layers=None,
        feature_select_strategy=strategy,
    )
    return self.visual_projection(vision_output)


def _parse_and_validate_image_input(
        self, **kwargs: object) -> Optional[CLIPImagePixelInputs]:
    """Wrap the ``pixel_values`` keyword argument in ``CLIPImagePixelInputs``.

    Returns ``None`` when ``pixel_values`` is absent or ``None``. Otherwise
    the inputs object is built with both ``"h"`` and ``"w"`` bound to
    ``self.config.vision_config.image_size``.
    """
    pixels = kwargs.pop("pixel_values", None)
    if pixels is None:
        return None

    side = self.config.vision_config.image_size
    return CLIPImagePixelInputs(
        type="pixel_values",
        data=pixels,
        resolve_bindings={
            "h": side,
            "w": side
        },
    )


def _process_image_inputs(self,
                          inputs: CLIPImagePixelInputs) -> torch.Tensor:
    """Return ``get_image_features`` applied to ``inputs["data"]``."""
    return self.get_image_features(inputs["data"])


def get_language_model(self) -> torch.nn.Module:
    """Return the text transformer (``self.text_model``)."""
    return self.text_model


def get_input_embeddings(
    self,
    input_ids: torch.Tensor,
    multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
    *,
    is_multimodal: Optional[torch.Tensor] = None,
    handle_oov_mm_token: bool = False,
) -> torch.Tensor:
    """Record the input kind, then delegate to the parent implementation.

    ``self._is_text_input`` is set to ``True`` when ``multimodal_embeddings``
    is ``None`` or has length zero, and to ``False`` otherwise; ``forward``
    later branches on that flag.
    """
    has_mm_embeddings = (multimodal_embeddings is not None
                         and len(multimodal_embeddings) != 0)
    self._is_text_input = not has_mm_embeddings

    # Two separate call shapes: this is to satisfy the type checker for each
    # overload of the parent method.
    if multimodal_embeddings is not None and is_multimodal is not None:
        return super().get_input_embeddings(
            input_ids,
            multimodal_embeddings=multimodal_embeddings,
            is_multimodal=is_multimodal,
            handle_oov_mm_token=handle_oov_mm_token,
        )

    return super().get_input_embeddings(input_ids)


def get_multimodal_embeddings(self,
                              **kwargs: object) -> MultiModalEmbeddings:
    """Return image embeddings for the given keyword inputs.

    Yields an empty list when ``_parse_and_validate_image_input`` finds no
    image input; otherwise returns the result of ``_process_image_inputs``.
    """
    parsed = self._parse_and_validate_image_input(**kwargs)
    if parsed is None:
        return []

    return self._process_image_inputs(parsed)


def forward(
    self,
    input_ids: Optional[torch.Tensor],
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
    **kwargs: object,
) -> torch.Tensor:
    """Return text features, or pass multimodal embeddings straight through.

    When ``self._is_text_input`` is true the result is
    ``get_text_features(input_ids, positions, inputs_embeds)``; otherwise
    ``inputs_embeds`` is returned unchanged.

    Raises:
        RuntimeError: If ``intermediate_tensors`` is not ``None``.
    """
    if intermediate_tensors is not None:
        raise RuntimeError("PP is not supported for this model")

    # Text inputs
    if self._is_text_input:
        return self.get_text_features(input_ids=input_ids,
                                      position_ids=positions,
                                      inputs_embeds=inputs_embeds)

    # Multimodal inputs
    return inputs_embeds


def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
    """Load ``weights`` through ``AutoWeightsLoader`` and return its result.

    The loader is configured with ``skip_substrs=[".position_ids"]`` and
    ``ignore_unexpected_prefixes=["logit_scale."]``.
    """
    return AutoWeightsLoader(
        self,
        skip_substrs=[".position_ids"],
        ignore_unexpected_prefixes=["logit_scale."],
    ).load_weights(weights)


# ---------------------------------------------------------------------------
# Patch text that followed `load_weights` on the same physical line. It belongs
# to other files in this diff, not to the class above, and is preserved here
# with line breaks restored (inter-token whitespace kept as it appeared).
#
# diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
# index 94744fe558bd9..bc2dc697d1c5f 100644
# --- a/vllm/model_executor/models/registry.py
# +++ b/vllm/model_executor/models/registry.py
# @@ -187,6 +187,7 @@ _EMBEDDING_MODELS = {
#   "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
#   "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
#   # [Multimodal]
# + "CLIPModel": ("clip", "CLIPEmbeddingModel"),
#   "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"), # noqa: E501
#   "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
#   "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"), # noqa: E501
# diff --git a/vllm/model_executor/models/vision.py b/vllm/model_executor/models/vision.py
# index 2636942580fab..b4007ff2e1cf1 100644
# --- a/vllm/model_executor/models/vision.py
# +++ b/vllm/model_executor/models/vision.py
# @@ -92,8 +92,10 @@ def get_vit_attn_backend(head_size: int, dtype: torch.dtype) ->
_Backend: return current_platform.get_vit_attn_backend(head_size, dtype) +VisionFeatureSelectStrategyStr = Literal["class", "default", "full"] + VisionFeatureSelectStrategy = Union[ - Literal["class", "default", "full"], + VisionFeatureSelectStrategyStr, Callable[[torch.Tensor], torch.Tensor], ] @@ -106,7 +108,7 @@ def _get_vision_feature_selector( # https://github.com/huggingface/transformers/blob/cd74917ffc3e8f84e4a886052c5ab32b7ac623cc/src/transformers/models/clip/modeling_clip.py#L762 if strategy == "class": - return lambda feats: feats[:, 0, :] + return lambda feats: feats[:, :1, :] # https://github.com/huggingface/transformers/blob/4a02bc7004285bdb12cc033e87ad2578ce2fa900/src/transformers/models/llava/modeling_llava.py#L196 if strategy == "default": diff --git a/vllm/transformers_utils/chat_templates/registry.py b/vllm/transformers_utils/chat_templates/registry.py index 3a97f2c056181..d24a0946bdde0 100644 --- a/vllm/transformers_utils/chat_templates/registry.py +++ b/vllm/transformers_utils/chat_templates/registry.py @@ -33,6 +33,7 @@ def _get_minicpmv_chat_template_fallback( # yapf: disable _MODEL_TYPE_TO_CHAT_TEMPLATE_FALLBACK: dict[str, ChatTemplatePath] = { "blip-2": CHAT_TEMPLATES_DIR / "template_blip2.jinja", + "clip": CHAT_TEMPLATES_DIR / "template_basic.jinja", "chameleon": CHAT_TEMPLATES_DIR / "template_basic.jinja", "deepseek_vl_v2": CHAT_TEMPLATES_DIR / "template_deepseek_vl2.jinja", "fuyu": CHAT_TEMPLATES_DIR / "template_fuyu.jinja",