[Model] CLIP Embedding Support (#26010)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent 2a6dc67eb5
commit 4570535ec4
@@ -829,6 +829,7 @@ The following table lists those that are tested in vLLM.

| Architecture | Models | Inputs | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | [V1](gh-issue:8779) |
|--------------|--------|--------|-------------------|----------------------|---------------------------|---------------------|
| `CLIPModel` | CLIP | T / I | `openai/clip-vit-base-patch32`, `openai/clip-vit-large-patch14`, etc. | | | ✅︎ |
| `LlavaNextForConditionalGeneration`<sup>C</sup> | LLaVA-NeXT-based | T / I | `royokong/e5-v` | | ✅︎ | ✅︎ |
| `Phi3VForCausalLM`<sup>C</sup> | Phi-3-Vision-based | T + I | `TIGER-Lab/VLM2Vec-Full` | | ✅︎ | ✅︎ |
| `*ForConditionalGeneration`<sup>C</sup>, `*ForCausalLM`<sup>C</sup>, etc. | Generative models | \* | N/A | \* | \* | \* |
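For quick reference, a minimal offline usage sketch for the new CLIP embedding support, based on the `run_clip` example added in this commit. The `LLM.embed` call pattern and the `cat.jpg` path are illustrative assumptions, not part of this diff:

```python
from PIL import Image

from vllm import LLM

# Sketch only: CLIP accepts text-only or image-only prompts, not both at once.
llm = LLM(
    model="openai/clip-vit-base-patch32",
    runner="pooling",
    limit_mm_per_prompt={"image": 1},
)

# Text embedding
text_output = llm.embed({"prompt": "a photo of a cat"})

# Image embedding: keep the prompt text empty and pass the image separately.
image = Image.open("cat.jpg")  # hypothetical local file
image_output = llm.embed({"prompt": "", "multi_modal_data": {"image": image}})
```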
@@ -58,6 +58,30 @@ class ModelRequestData(NamedTuple):
    documents: Optional[ScoreMultiModalParam] = None


def run_clip(query: Query) -> ModelRequestData:
    if query["modality"] == "text":
        prompt = query["text"]
        image = None
    elif query["modality"] == "image":
        prompt = ""  # For image input, make sure that the prompt text is empty
        image = query["image"]
    else:
        modality = query["modality"]
        raise ValueError(f"Unsupported query modality: '{modality}'")

    engine_args = EngineArgs(
        model="openai/clip-vit-base-patch32",
        runner="pooling",
        limit_mm_per_prompt={"image": 1},
    )

    return ModelRequestData(
        engine_args=engine_args,
        prompt=prompt,
        image=image,
    )


def run_e5_v(query: Query) -> ModelRequestData:
    llama3_template = "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n \n"  # noqa: E501

@@ -146,7 +170,8 @@ def run_vlm2vec_qwen2vl(query: Query) -> ModelRequestData:
    processor = AutoProcessor.from_pretrained(
        model_id,
        # `min_pixels` and `max_pixels` are deprecated
        # `min_pixels` and `max_pixels` are deprecated for
        # transformers `preprocessor_config.json`
        size={"shortest_edge": 3136, "longest_edge": 12845056},
    )
    processor.chat_template = load_chat_template(

@@ -172,8 +197,10 @@ def run_vlm2vec_qwen2vl(query: Query) -> ModelRequestData:
        model=merged_path,
        runner="pooling",
        max_model_len=4096,
        trust_remote_code=True,
        mm_processor_kwargs={"num_crops": 4},
        mm_processor_kwargs={
            "min_pixels": 3136,
            "max_pixels": 12845056,
        },
        limit_mm_per_prompt={"image": 1},
    )

@@ -299,6 +326,7 @@ def run_score(model: str, modality: QueryModality, seed: Optional[int]):


model_example_map = {
    "clip": run_clip,
    "e5_v": run_e5_v,
    "vlm2vec_phi3v": run_vlm2vec_phi3v,
    "vlm2vec_qwen2vl": run_vlm2vec_qwen2vl,
@@ -1,14 +1,9 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501
"""Example Python client for multimodal embedding API using vLLM API server
NOTE:
    start a supported multimodal embeddings model server with `vllm serve`, e.g.
    vllm serve TIGER-Lab/VLM2Vec-Full \
      --runner pooling \
      --trust-remote-code \
      --max-model-len 4096 \
      --chat-template examples/template_vlm2vec_phi3v.jinja
"""Example Python client for multimodal embedding API using vLLM API server.

Refer to each `run_*` function for the command to run the server for that model.
"""

import argparse

@@ -47,7 +42,58 @@ def create_chat_embeddings(
    )


def run_clip(client: OpenAI, model: str):
    """
    Start the server using:

    vllm serve openai/clip-vit-base-patch32 \
      --runner pooling
    """

    response = create_chat_embeddings(
        client,
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "image_url", "image_url": {"url": image_url}},
                ],
            }
        ],
        model=model,
        encoding_format="float",
    )

    print("Image embedding output:", response.data[0].embedding)

    response = create_chat_embeddings(
        client,
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "a photo of a cat"},
                ],
            }
        ],
        model=model,
        encoding_format="float",
    )

    print("Text embedding output:", response.data[0].embedding)


def run_vlm2vec(client: OpenAI, model: str):
    """
    Start the server using:

    vllm serve TIGER-Lab/VLM2Vec-Full \
      --runner pooling \
      --trust-remote-code \
      --max-model-len 4096 \
      --chat-template examples/template_vlm2vec_phi3v.jinja
    """

    response = create_chat_embeddings(
        client,
        messages=[

@@ -103,6 +149,15 @@ def run_vlm2vec(client: OpenAI, model: str):


def run_dse_qwen2_vl(client: OpenAI, model: str):
    """
    Start the server using:

    vllm serve MrLight/dse-qwen2-2b-mrl-v1 \
      --runner pooling \
      --trust-remote-code \
      --max-model-len 8192 \
      --chat-template examples/template_dse_qwen2_vl.jinja
    """
    response = create_chat_embeddings(
        client,
        messages=[

@@ -156,6 +211,7 @@ def run_dse_qwen2_vl(client: OpenAI, model: str):


model_example_map = {
    "clip": run_clip,
    "vlm2vec": run_vlm2vec,
    "dse_qwen2_vl": run_dse_qwen2_vl,
}
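For quick reference, the new `run_clip` client function above can be exercised with a standard OpenAI client pointed at the local vLLM server. This is a sketch only; it assumes the default server address and uses a placeholder API key:

```python
from openai import OpenAI

# Assumes `vllm serve openai/clip-vit-base-patch32 --runner pooling`
# is listening on the default address.
client = OpenAI(api_key="EMPTY", base_url="http://localhost:8000/v1")

run_clip(client, model="openai/clip-vit-base-patch32")
```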
tests/models/multimodal/pooling/test_clip.py (new file, 138 lines added)
@@ -0,0 +1,138 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
from transformers import CLIPModel

from ....conftest import IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner
from ...utils import check_embeddings_close

HF_TEXT_PROMPTS = [
    "a photo of a stop sign",
    "a photo of a cherry blossom",
]

HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
    "stop_sign": "",
    "cherry_blossom": "",
})

MODELS = ["openai/clip-vit-base-patch32"]


def _run_test(
    hf_runner: type[HfRunner],
    vllm_runner: type[VllmRunner],
    input_texts: list[str],
    input_images: PromptImageInput,
    model: str,
    *,
    dtype: str,
) -> None:
    # NOTE: take care of the order. run vLLM first, and then run HF.
    # vLLM needs a fresh new process without cuda initialization.
    # if we run HF first, the cuda initialization will be done and it
    # will hurt multiprocessing backend with fork method (the default method).
    with vllm_runner(model,
                     runner="pooling",
                     dtype=dtype,
                     enforce_eager=True,
                     max_model_len=77) as vllm_model:
        vllm_outputs = vllm_model.embed(input_texts, images=input_images)

    with hf_runner(model, dtype=dtype, auto_cls=CLIPModel) as hf_model:
        all_inputs = hf_model.get_inputs(input_texts, images=input_images)

        all_outputs = []
        for inputs in all_inputs:
            if "pixel_values" in inputs:
                inputs.pop("input_ids")
                pooled_output = hf_model.model.get_image_features(
                    **hf_model.wrap_device(inputs)).squeeze(0)
            else:
                pooled_output = hf_model.model.get_text_features(
                    **hf_model.wrap_device(inputs)).squeeze(0)

            all_outputs.append(pooled_output.tolist())

        hf_outputs = all_outputs

    check_embeddings_close(
        embeddings_0_lst=hf_outputs,
        embeddings_1_lst=vllm_outputs,
        name_0="hf",
        name_1="vllm",
    )


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
def test_models_text(
    hf_runner,
    vllm_runner,
    image_assets,
    model: str,
    dtype: str,
) -> None:
    input_texts_images = [(text, None) for text in HF_TEXT_PROMPTS]
    input_texts = [text for text, _ in input_texts_images]
    input_images = [image for _, image in input_texts_images]

    _run_test(
        hf_runner,
        vllm_runner,
        input_texts,
        input_images,  # type: ignore
        model,
        dtype=dtype,
    )


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
def test_models_image(
    hf_runner,
    vllm_runner,
    image_assets,
    model: str,
    dtype: str,
) -> None:
    input_texts_images = [
        (text, asset.pil_image)
        for text, asset in zip(HF_IMAGE_PROMPTS, image_assets)
    ]
    input_texts = [text for text, _ in input_texts_images]
    input_images = [image for _, image in input_texts_images]

    _run_test(
        hf_runner,
        vllm_runner,
        input_texts,
        input_images,
        model,
        dtype=dtype,
    )


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
def test_models_text_image_no_crash(
    vllm_runner,
    image_assets,
    model: str,
    dtype: str,
) -> None:
    texts = [HF_TEXT_PROMPTS[0]]
    images = [image_assets[0].pil_image]

    with vllm_runner(model,
                     runner="pooling",
                     dtype=dtype,
                     enforce_eager=True,
                     max_model_len=77) as vllm_model:
        with pytest.raises(ValueError, match="not both"):
            vllm_model.embed(texts, images=images)

        # Should still be able to run subsequent requests
        vllm_model.embed(texts)
        vllm_model.embed([""], images=images)
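The new tests can be run locally with pytest, assuming a working vLLM development environment with a GPU:

```bash
pytest tests/models/multimodal/pooling/test_clip.py -v
```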
@@ -389,6 +389,7 @@ _EMBEDDING_EXAMPLE_MODELS = {
    "RobertaForMaskedLM": _HfExamplesInfo("sentence-transformers/all-roberta-large-v1"),  # noqa: E501
    "XLMRobertaModel": _HfExamplesInfo("intfloat/multilingual-e5-small"),  # noqa: E501
    # [Multimodal]
    "CLIPModel": _HfExamplesInfo("openai/clip-vit-base-patch32"),
    "LlavaNextForConditionalGeneration": _HfExamplesInfo("royokong/e5-v"),
    "Phi3VForCausalLM": _HfExamplesInfo("TIGER-Lab/VLM2Vec-Full",
                                        trust_remote_code=True),

@@ -687,7 +688,11 @@ class HfExampleModels:
        return self.hf_models.keys()

    def get_hf_info(self, model_arch: str) -> _HfExamplesInfo:
        return self.hf_models[model_arch]
        try:
            return self.hf_models[model_arch]
        except KeyError:
            raise ValueError(f"No example model defined for {model_arch}; "
                             f"please update this file.") from None

    def find_hf_info(self, model_id: str) -> _HfExamplesInfo:
        for info in self.hf_models.values():

@@ -699,7 +704,8 @@ class HfExampleModels:
            if any(extra == model_id for extra in info.extras.values()):
                return info

        raise ValueError(f"No example model defined for {model_id}")
        raise ValueError(f"No example model defined for {model_id}; "
                         f"please update this file.")


HF_EXAMPLE_MODELS = HfExampleModels(_EXAMPLE_MODELS)
@@ -417,12 +417,16 @@ class MultiHeadAttention(nn.Module):
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
    ):
        # This has no effect, it is only here to make it easier to swap
        # between Attention and MultiHeadAttention
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = scale
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        self.layer_name = prefix

        assert self.num_heads % self.num_kv_heads == 0, \
            f"num_heads ({self.num_heads}) is not " \
@@ -351,7 +351,7 @@ class BertModel(nn.Module, SupportsQuant):
                                   prefix=f"{prefix}.encoder")

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embeddings(input_ids)
        return self.embeddings.word_embeddings(input_ids)

    def forward(
        self,
@@ -1,28 +1,63 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Minimal implementation of CLIPVisionModel intended to be only used
within a vision language model."""
from collections.abc import Iterable
from typing import Optional, Union
from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property
from typing import Annotated, Literal, Optional, Union

import torch
import torch.nn as nn
from transformers import CLIPVisionConfig
from transformers import (BatchFeature, CLIPConfig, CLIPProcessor,
                          CLIPTextConfig, CLIPVisionConfig)

from vllm.attention import Attention
from vllm.attention.layer import MultiHeadAttention
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.pooler import DispatchPooler, Pooler
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import SupportsQuant
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                    MultiModalInputs, MultiModalKwargsItems,
                                    MultiModalUUIDDict)
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
                                   MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                        BaseProcessingInfo, PromptIndexTargets,
                                        PromptReplacement, PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape

from .interfaces import MultiModalEmbeddings, SupportsMultiModal
from .interfaces_base import default_pooling_type
from .utils import AutoWeightsLoader, maybe_prefix
from .vision import (VisionEncoderInfo, VisionFeatureSelectStrategy,
                     VisionFeatureSelectStrategyStr,
                     get_num_selected_vision_tokens,
                     resolve_visual_encoder_outputs)


class CLIPImagePixelInputs(TensorSchema):
    """
    Dimensions:
        - bn: Batch size * number of images
        - c: Number of channels (3)
        - h: Height of each image
        - w: Width of each image
    """
    type: Literal["pixel_values"]
    data: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]


class CLIPEncoderInfo(VisionEncoderInfo[CLIPVisionConfig]):

    def get_num_image_tokens(
@@ -45,7 +80,214 @@ class CLIPEncoderInfo(VisionEncoderInfo[CLIPVisionConfig]):
        return image_size // patch_size


# Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/clip/modeling_clip.py#L164 # noqa
_POOLING_TYPE_TO_STRATEGY: dict[str, VisionFeatureSelectStrategyStr] = {
    "MEAN": "full",
    "ALL": "full",
    "CLS": "class",
    # This lets us use the same pooling type for both text and image
    "LAST": "class",
}


def _get_vision_feature_select_strategy(pooling_type: str):
    try:
        return _POOLING_TYPE_TO_STRATEGY[pooling_type]
    except KeyError:
        raise ValueError(f"No feature selection strategy is defined for "
                         f"pooling_type: {pooling_type!r}") from None


class CLIPProcessingInfo(BaseProcessingInfo):

    def get_hf_config(self):
        return self.ctx.get_hf_config(CLIPConfig)

    def get_vision_encoder_info(self):
        return CLIPEncoderInfo(self.get_hf_config())

    def get_hf_processor(self, **kwargs: object):
        return self.ctx.get_hf_processor(CLIPProcessor, **kwargs)

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        return {"image": 1}

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        vision_encoder_info = self.get_vision_encoder_info()

        pooler_config = self.ctx.model_config.pooler_config
        assert pooler_config is not None

        return get_num_selected_vision_tokens(
            vision_encoder_info.get_num_image_tokens(
                image_width=image_width,
                image_height=image_height,
            ),
            _get_vision_feature_select_strategy(pooler_config.pooling_type),
        )

    def get_image_size_with_most_features(self) -> ImageSize:
        vision_encoder_info = self.get_vision_encoder_info()
        width = height = vision_encoder_info.get_image_size()
        return ImageSize(width=width, height=height)

    def get_max_image_tokens(self) -> int:
        target_width, target_height = self.get_image_size_with_most_features()

        return self.get_num_image_tokens(
            image_width=target_width,
            image_height=target_height,
        )


class CLIPDummyInputsBuilder(BaseDummyInputsBuilder[CLIPProcessingInfo]):

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        return ""

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
    ) -> MultiModalDataDict:
        num_images = mm_counts.get("image", 0)

        target_width, target_height = \
            self.info.get_image_size_with_most_features()

        image_overrides = mm_options.get("image") if mm_options else None

        return {
            "image":
            self._get_dummy_images(width=target_width,
                                   height=target_height,
                                   num_images=num_images,
                                   overrides=image_overrides)
        }


class CLIPMultiModalProcessor(BaseMultiModalProcessor[CLIPProcessingInfo]):

    @cached_property
    def image_token_id(self) -> int:
        tokenizer = self.info.get_tokenizer()
        dummy_token_id = 0

        assert dummy_token_id not in tokenizer.all_special_ids

        return dummy_token_id

    def apply(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Optional[Mapping[str, object]] = None,
        *,
        mm_uuids: Optional[MultiModalUUIDDict] = None,
    ) -> MultiModalInputs:
        if prompt and mm_data:
            raise ValueError(
                "CLIP accepts text-only or image-only inputs, not both! "
                "Image-only inputs means passing an image with an empty text "
                "prompt.")

        if mm_data:
            # For multi-modal data, the prompt after processing should
            # only contain the dummy image tokens
            tokenization_kwargs = {
                **(tokenization_kwargs or {}),
                "add_special_tokens": False,
            }

        return super().apply(
            prompt=prompt,
            mm_data=mm_data,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
            tokenization_kwargs=tokenization_kwargs,
            mm_uuids=mm_uuids,
        )

    def _hf_processor_applies_updates(
        self,
        prompt_text: str,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object],
    ) -> bool:
        return False

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        return dict(pixel_values=MultiModalFieldConfig.batched("image"))

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        image_token_id = self.image_token_id

        def get_replacement(item_idx: int):
            images = mm_items.get_items("image", ImageProcessorItems)
            image_size = images.get_image_size(item_idx)

            num_image_tokens = self.info.get_num_image_tokens(
                image_width=image_size.width,
                image_height=image_size.height,
            )
            return [image_token_id] * num_image_tokens

        return [
            PromptReplacement(
                modality="image",
                target=PromptIndexTargets.start(),
                replacement=get_replacement,
            ),
        ]


# Adapted from: https://github.com/huggingface/transformers/blob/v4.56.2/src/transformers/models/clip/modeling_clip.py
class CLIPTextEmbeddings(nn.Module):

    def __init__(self, config: CLIPTextConfig):
        super().__init__()

        embed_dim = config.hidden_size

        self.token_embedding = VocabParallelEmbedding(config.vocab_size,
                                                      embed_dim)
        self.position_embedding = VocabParallelEmbedding(
            config.max_position_embeddings, embed_dim)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        position_ids: torch.Tensor,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if inputs_embeds is None:
            if input_ids is None:
                raise ValueError(
                    "Either `input_ids` or `input_embeds` must be provided")

            inputs_embeds = self.token_embedding(input_ids)

        position_embeddings = self.position_embedding(position_ids)
        embeddings = inputs_embeds + position_embeddings

        return embeddings


class CLIPVisionEmbeddings(nn.Module):

    def __init__(self, config: CLIPVisionConfig):
@@ -89,15 +331,17 @@ class CLIPVisionEmbeddings(nn.Module):


class CLIPAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(
        self,
        config: CLIPVisionConfig,
        config: Union[CLIPTextConfig, CLIPVisionConfig],
        quant_config: Optional[QuantizationConfig] = None,
        *,
        prefix: str = "",
    ):
        attn_cls: Union[type[Attention], type[MultiHeadAttention]],
    ) -> None:
        super().__init__()

        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads

@@ -127,8 +371,12 @@
        self.tp_size = get_tensor_model_parallel_world_size()
        self.num_heads_per_partition = divide(self.num_heads, self.tp_size)

        self.attn = MultiHeadAttention(self.num_heads_per_partition,
                                       self.head_dim, self.scale)
        self.attn = attn_cls(
            self.num_heads_per_partition,
            self.head_dim,
            self.scale,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,

@@ -148,7 +396,7 @@ class CLIPMLP(nn.Module):

    def __init__(
        self,
        config: CLIPVisionConfig,
        config: Union[CLIPTextConfig, CLIPVisionConfig],
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:

@@ -178,15 +426,18 @@ class CLIPEncoderLayer(nn.Module):

    def __init__(
        self,
        config: CLIPVisionConfig,
        config: Union[CLIPTextConfig, CLIPVisionConfig],
        quant_config: Optional[QuantizationConfig] = None,
        *,
        prefix: str = "",
        attn_cls: Union[type[Attention], type[MultiHeadAttention]],
    ) -> None:
        super().__init__()
        self.self_attn = CLIPAttention(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
            attn_cls=attn_cls,
        )
        self.layer_norm1 = nn.LayerNorm(config.hidden_size,
                                        eps=config.layer_norm_eps)

@@ -223,10 +474,12 @@ class CLIPEncoder(nn.Module):

    def __init__(
        self,
        config: CLIPVisionConfig,
        config: Union[CLIPTextConfig, CLIPVisionConfig],
        quant_config: Optional[QuantizationConfig] = None,
        num_hidden_layers_override: Optional[int] = None,
        *,
        prefix: str = "",
        attn_cls: Union[type[Attention], type[MultiHeadAttention]],
    ) -> None:
        super().__init__()

@@ -239,12 +492,15 @@
        self.layers = nn.ModuleList([
            CLIPEncoderLayer(config=config,
                             quant_config=quant_config,
                             prefix=f"{prefix}.layers.{layer_idx}")
                             prefix=f"{prefix}.layers.{layer_idx}",
                             attn_cls=attn_cls)
            for layer_idx in range(num_hidden_layers)
        ])

    def forward(
        self, inputs_embeds: torch.Tensor, return_all_hidden_states: bool
        self,
        inputs_embeds: torch.Tensor,
        return_all_hidden_states: bool,
    ) -> Union[torch.Tensor, list[torch.Tensor]]:
        hidden_states_pool = [inputs_embeds]
        hidden_states = inputs_embeds

@@ -260,6 +516,87 @@ class CLIPEncoder(nn.Module):
        return hidden_states


class CLIPTextTransformer(nn.Module):

    def __init__(
        self,
        config: CLIPTextConfig,
        quant_config: Optional[QuantizationConfig] = None,
        *,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config
        embed_dim = config.hidden_size

        self.embeddings = CLIPTextEmbeddings(config)

        self.encoder = CLIPEncoder(
            config=config,
            quant_config=quant_config,
            prefix=f"{prefix}.encoder",
            attn_cls=Attention,
        )

        self.final_layer_norm = nn.LayerNorm(
            embed_dim,
            eps=config.layer_norm_eps,
        )

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embeddings.token_embedding(input_ids)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        position_ids: torch.Tensor,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        hidden_states = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
        )

        last_hidden_state = self.encoder(
            inputs_embeds=hidden_states,
            return_all_hidden_states=False,
        )
        last_hidden_state = self.final_layer_norm(last_hidden_state)

        return last_hidden_state

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params


class CLIPVisionTransformer(nn.Module):

    def __init__(
@@ -287,6 +624,7 @@ class CLIPVisionTransformer(nn.Module):
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers_override,
            prefix=f"{prefix}.encoder",
            attn_cls=MultiHeadAttention,
        )

        num_hidden_layers = config.num_hidden_layers

@@ -306,6 +644,14 @@
        else:
            self.post_layernorm = None

    @property
    def dtype(self):
        return next(self.parameters()).dtype

    @property
    def device(self):
        return next(self.parameters()).device

    def forward(
        self,
        pixel_values: torch.Tensor,

@@ -335,47 +681,6 @@
        return encoder_outputs


class CLIPVisionModel(nn.Module, SupportsQuant):
    config_class = CLIPVisionConfig
    main_input_name = "pixel_values"
    packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}

    def __init__(
        self,
        config: CLIPVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        *,
        num_hidden_layers_override: Optional[int] = None,
        require_post_norm: Optional[bool] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.vision_model = CLIPVisionTransformer(
            config=config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers_override,
            require_post_norm=require_post_norm,
            prefix=f"{prefix}.vision_model")

    def forward(
        self,
        pixel_values: torch.Tensor,
        select_layers: Optional[list[int]] = None,
        feature_select_strategy: Optional[VisionFeatureSelectStrategy] = None,
    ) -> torch.Tensor:
        return self.vision_model(
            pixel_values,
            select_layers=select_layers,
            feature_select_strategy=feature_select_strategy,
        )

    @property
    def device(self):
        return next(self.parameters()).device

    # (TODO) Add prefix argument for filtering out weights to be loaded
    # ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986
    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        stacked_params_mapping = [

@@ -386,17 +691,17 @@ class CLIPVisionModel(nn.Module, SupportsQuant):
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        layer_count = len(self.vision_model.encoder.layers)
        layer_count = len(self.encoder.layers)

        for name, loaded_weight in weights:
            # post_layernorm is not needed in CLIPVisionModel
            if (name.startswith("vision_model.post_layernorm")
                    and self.vision_model.post_layernorm is None):
            if (name.startswith("post_layernorm")
                    and self.post_layernorm is None):
                continue

            # omit layers when num_hidden_layers_override is set
            if name.startswith("vision_model.encoder.layers"):
                layer_idx = int(name.split(".")[3])
            if name.startswith("encoder.layers"):
                layer_idx = int(name.split(".")[2])
                if layer_idx >= layer_count:
                    continue
@@ -416,3 +721,233 @@
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params


class CLIPVisionModel(nn.Module):

    def __init__(
        self,
        config: CLIPVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        *,
        num_hidden_layers_override: Optional[int] = None,
        require_post_norm: Optional[bool] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.vision_model = CLIPVisionTransformer(
            config=config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers_override,
            require_post_norm=require_post_norm,
            prefix=f"{prefix}.vision_model",
        )

    def forward(
        self,
        pixel_values: torch.Tensor,
        select_layers: Optional[list[int]] = None,
        feature_select_strategy: Optional[VisionFeatureSelectStrategy] = None,
    ) -> torch.Tensor:
        return self.vision_model(
            pixel_values,
            select_layers=select_layers,
            feature_select_strategy=feature_select_strategy,
        )

    @property
    def dtype(self):
        return self.vision_model.dtype

    @property
    def device(self):
        return self.vision_model.device


# Assume EOS token corresponds to LAST token in text model
@default_pooling_type("LAST")
@MULTIMODAL_REGISTRY.register_processor(CLIPMultiModalProcessor,
                                        info=CLIPProcessingInfo,
                                        dummy_inputs=CLIPDummyInputsBuilder)
class CLIPEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):

    is_pooling_model = True

    packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
    merge_by_field_config = True

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
        if modality.startswith("image"):
            return None

        raise ValueError("Only image modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config: CLIPConfig = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config
        self.config = config
        self.multimodal_config = multimodal_config

        text_config = config.text_config
        vision_config = config.vision_config

        self.projection_dim = config.projection_dim
        self.text_embed_dim = text_config.hidden_size
        self.vision_embed_dim = vision_config.hidden_size

        self.text_model = CLIPTextTransformer(
            text_config,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "text_model"),
        )
        self.vision_model = CLIPVisionTransformer(
            vision_config,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "vision_model"),
        )

        self.visual_projection = nn.Linear(
            self.vision_embed_dim,
            self.projection_dim,
            bias=False,
        )
        self.text_projection = nn.Linear(
            self.text_embed_dim,
            self.projection_dim,
            bias=False,
        )

        pooler_config = vllm_config.model_config.pooler_config
        assert pooler_config is not None
        self.pooler_config = pooler_config

        self.pooler = DispatchPooler({
            "encode": Pooler.for_encode(pooler_config),
            "embed": Pooler.for_embed(pooler_config),
        })

        # Assumes that self.forward is called after self.get_input_embeddings
        self._is_text_input = True

    def get_text_features(
        self,
        input_ids: Optional[torch.Tensor],
        position_ids: torch.Tensor,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        pooled_output = self.text_model(
            input_ids=input_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
        )

        text_features = self.text_projection(pooled_output)

        return text_features

    def get_image_features(
        self,
        pixel_values: torch.Tensor,
        feature_select_strategy: Optional[VisionFeatureSelectStrategy] = None,
    ) -> torch.Tensor:
        if feature_select_strategy is None:
            feature_select_strategy = _get_vision_feature_select_strategy(
                self.pooler_config.pooling_type)

        pooled_output = self.vision_model(
            pixel_values=pixel_values,
            select_layers=None,
            feature_select_strategy=feature_select_strategy,
        )

        image_features = self.visual_projection(pooled_output)

        return image_features

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[CLIPImagePixelInputs]:
        pixel_values = kwargs.pop("pixel_values", None)
        if pixel_values is None:
            return None

        expected_h = expected_w = self.config.vision_config.image_size
        return CLIPImagePixelInputs(type="pixel_values",
                                    data=pixel_values,
                                    resolve_bindings={
                                        "h": expected_h,
                                        "w": expected_w
                                    })

    def _process_image_inputs(self,
                              inputs: CLIPImagePixelInputs) -> torch.Tensor:
        pixel_values = inputs["data"]

        return self.get_image_features(pixel_values)

    def get_language_model(self) -> torch.nn.Module:
        return self.text_model

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
        *,
        is_multimodal: Optional[torch.Tensor] = None,
        handle_oov_mm_token: bool = False,
    ) -> torch.Tensor:
        self._is_text_input = (multimodal_embeddings is None
                               or len(multimodal_embeddings) == 0)

        # This is to satisfy the type checker for each overload
        if multimodal_embeddings is None or is_multimodal is None:
            return super().get_input_embeddings(input_ids)

        return super().get_input_embeddings(
            input_ids,
            multimodal_embeddings=multimodal_embeddings,
            is_multimodal=is_multimodal,
            handle_oov_mm_token=handle_oov_mm_token,
        )

    def get_multimodal_embeddings(self,
                                  **kwargs: object) -> MultiModalEmbeddings:
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []

        vision_embeddings = self._process_image_inputs(image_input)
        return vision_embeddings

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> torch.Tensor:
        if intermediate_tensors is not None:
            raise RuntimeError("PP is not supported for this model")

        # Multimodal inputs
        if not self._is_text_input:
            return inputs_embeds

        # Text inputs
        return self.get_text_features(input_ids=input_ids,
                                      position_ids=positions,
                                      inputs_embeds=inputs_embeds)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        loader = AutoWeightsLoader(
            self,
            skip_substrs=[".position_ids"],
            ignore_unexpected_prefixes=["logit_scale."],
        )

        return loader.load_weights(weights)
@@ -187,6 +187,7 @@ _EMBEDDING_MODELS = {
    "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
    "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
    # [Multimodal]
    "CLIPModel": ("clip", "CLIPEmbeddingModel"),
    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
@@ -92,8 +92,10 @@ def get_vit_attn_backend(head_size: int, dtype: torch.dtype) -> _Backend:
    return current_platform.get_vit_attn_backend(head_size, dtype)


VisionFeatureSelectStrategyStr = Literal["class", "default", "full"]

VisionFeatureSelectStrategy = Union[
    Literal["class", "default", "full"],
    VisionFeatureSelectStrategyStr,
    Callable[[torch.Tensor], torch.Tensor],
]

@@ -106,7 +108,7 @@ def _get_vision_feature_selector(

    # https://github.com/huggingface/transformers/blob/cd74917ffc3e8f84e4a886052c5ab32b7ac623cc/src/transformers/models/clip/modeling_clip.py#L762
    if strategy == "class":
        return lambda feats: feats[:, 0, :]
        return lambda feats: feats[:, :1, :]

    # https://github.com/huggingface/transformers/blob/4a02bc7004285bdb12cc033e87ad2578ce2fa900/src/transformers/models/llava/modeling_llava.py#L196
    if strategy == "default":
@@ -33,6 +33,7 @@ def _get_minicpmv_chat_template_fallback(
# yapf: disable
_MODEL_TYPE_TO_CHAT_TEMPLATE_FALLBACK: dict[str, ChatTemplatePath] = {
    "blip-2": CHAT_TEMPLATES_DIR / "template_blip2.jinja",
    "clip": CHAT_TEMPLATES_DIR / "template_basic.jinja",
    "chameleon": CHAT_TEMPLATES_DIR / "template_basic.jinja",
    "deepseek_vl_v2": CHAT_TEMPLATES_DIR / "template_deepseek_vl2.jinja",
    "fuyu": CHAT_TEMPLATES_DIR / "template_fuyu.jinja",