mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 09:06:02 +08:00
[Model] Siglip Embedding Support (#27324)
Signed-off-by: piood <2477084691@qq.com>
This commit is contained in:
parent
51dd14ac2b
commit
0552cfb195
@ -800,12 +800,13 @@ These models primarily support the [`LLM.embed`](./pooling_models.md#llmembed) A
|
||||
|
||||
The following table lists those that are tested in vLLM.
|
||||
|
||||
| Architecture | Models | Inputs | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) |
|
||||
|--------------|--------|--------|-------------------|----------------------|---------------------------|
|
||||
| `CLIPModel` | CLIP | T / I | `openai/clip-vit-base-patch32`, `openai/clip-vit-large-patch14`, etc. | | |
|
||||
| `LlavaNextForConditionalGeneration`<sup>C</sup> | LLaVA-NeXT-based | T / I | `royokong/e5-v` | | ✅︎ |
|
||||
| `Phi3VForCausalLM`<sup>C</sup> | Phi-3-Vision-based | T + I | `TIGER-Lab/VLM2Vec-Full` | | ✅︎ |
|
||||
| `*ForConditionalGeneration`<sup>C</sup>, `*ForCausalLM`<sup>C</sup>, etc. | Generative models | \* | N/A | \* | \* |
|
||||
| Architecture | Models | Inputs | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | [V1](gh-issue:8779) |
|
||||
|--------------|--------|--------|-------------------|----------------------|---------------------------|---------------------|
|
||||
| `CLIPModel` | CLIP | T / I | `openai/clip-vit-base-patch32`, `openai/clip-vit-large-patch14`, etc. | | | ✅︎ |
|
||||
| `LlavaNextForConditionalGeneration`<sup>C</sup> | LLaVA-NeXT-based | T / I | `royokong/e5-v` | | ✅︎ | ✅︎ |
|
||||
| `Phi3VForCausalLM`<sup>C</sup> | Phi-3-Vision-based | T + I | `TIGER-Lab/VLM2Vec-Full` | | ✅︎ | ✅︎ |
|
||||
| `SiglipModel` | SigLIP | T / I | `google/siglip-base-patch16-224` | | | ✅︎ |
|
||||
| `*ForConditionalGeneration`<sup>C</sup>, `*ForCausalLM`<sup>C</sup>, etc. | Generative models | \* | N/A | \* | \* | \* |
|
||||
|
||||
<sup>C</sup> Automatically converted into an embedding model via `--convert embed`. ([details](./pooling_models.md#model-conversion))
|
||||
\* Feature support is the same as that of the original model.
|
||||
|
||||
@ -110,6 +110,53 @@ def run_e5_v(query: Query) -> ModelRequestData:
|
||||
)
|
||||
|
||||
|
||||
def run_jinavl_reranker(query: Query) -> ModelRequestData:
|
||||
if query["modality"] != "text+images":
|
||||
raise ValueError(f"Unsupported query modality: '{query['modality']}'")
|
||||
|
||||
engine_args = EngineArgs(
|
||||
model="jinaai/jina-reranker-m0",
|
||||
runner="pooling",
|
||||
max_model_len=32768,
|
||||
trust_remote_code=True,
|
||||
mm_processor_kwargs={
|
||||
"min_pixels": 3136,
|
||||
"max_pixels": 602112,
|
||||
},
|
||||
limit_mm_per_prompt={"image": 1},
|
||||
)
|
||||
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
query=query["text"],
|
||||
documents=query["image"],
|
||||
)
|
||||
|
||||
|
||||
def run_siglip(query: Query) -> ModelRequestData:
|
||||
if query["modality"] == "text":
|
||||
prompt = query["text"]
|
||||
image = None
|
||||
elif query["modality"] == "image":
|
||||
prompt = "" # For image input, make sure that the prompt text is empty
|
||||
image = query["image"]
|
||||
else:
|
||||
modality = query["modality"]
|
||||
raise ValueError(f"Unsupported query modality: '{modality}'")
|
||||
|
||||
engine_args = EngineArgs(
|
||||
model="google/siglip-base-patch16-224",
|
||||
runner="pooling",
|
||||
limit_mm_per_prompt={"image": 1},
|
||||
)
|
||||
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
prompt=prompt,
|
||||
image=image,
|
||||
)
|
||||
|
||||
|
||||
def _get_vlm2vec_prompt_image(query: Query, image_token: str):
|
||||
if query["modality"] == "text":
|
||||
text = query["text"]
|
||||
@ -211,29 +258,6 @@ def run_vlm2vec_qwen2vl(query: Query) -> ModelRequestData:
|
||||
)
|
||||
|
||||
|
||||
def run_jinavl_reranker(query: Query) -> ModelRequestData:
|
||||
if query["modality"] != "text+images":
|
||||
raise ValueError(f"Unsupported query modality: '{query['modality']}'")
|
||||
|
||||
engine_args = EngineArgs(
|
||||
model="jinaai/jina-reranker-m0",
|
||||
runner="pooling",
|
||||
max_model_len=32768,
|
||||
trust_remote_code=True,
|
||||
mm_processor_kwargs={
|
||||
"min_pixels": 3136,
|
||||
"max_pixels": 602112,
|
||||
},
|
||||
limit_mm_per_prompt={"image": 1},
|
||||
)
|
||||
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
query=query["text"],
|
||||
documents=query["image"],
|
||||
)
|
||||
|
||||
|
||||
def get_query(modality: QueryModality):
|
||||
if modality == "text":
|
||||
return TextQuery(modality="text", text="A dog sitting in the grass")
|
||||
@ -328,9 +352,10 @@ def run_score(model: str, modality: QueryModality, seed: int | None):
|
||||
model_example_map = {
|
||||
"clip": run_clip,
|
||||
"e5_v": run_e5_v,
|
||||
"jinavl_reranker": run_jinavl_reranker,
|
||||
"siglip": run_siglip,
|
||||
"vlm2vec_phi3v": run_vlm2vec_phi3v,
|
||||
"vlm2vec_qwen2vl": run_vlm2vec_qwen2vl,
|
||||
"jinavl_reranker": run_jinavl_reranker,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@ -83,6 +83,109 @@ def run_clip(client: OpenAI, model: str):
|
||||
print("Text embedding output:", response.data[0].embedding)
|
||||
|
||||
|
||||
def run_dse_qwen2_vl(client: OpenAI, model: str):
|
||||
"""
|
||||
Start the server using:
|
||||
|
||||
vllm serve MrLight/dse-qwen2-2b-mrl-v1 \
|
||||
--runner pooling \
|
||||
--trust-remote-code \
|
||||
--max-model-len 8192 \
|
||||
--chat-template examples/template_dse_qwen2_vl.jinja
|
||||
"""
|
||||
response = create_chat_embeddings(
|
||||
client,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": image_url,
|
||||
},
|
||||
},
|
||||
{"type": "text", "text": "What is shown in this image?"},
|
||||
],
|
||||
}
|
||||
],
|
||||
model=model,
|
||||
encoding_format="float",
|
||||
)
|
||||
|
||||
print("Image embedding output:", response.data[0].embedding)
|
||||
|
||||
# MrLight/dse-qwen2-2b-mrl-v1 requires a placeholder image
|
||||
# of the minimum input size
|
||||
buffer = io.BytesIO()
|
||||
image_placeholder = Image.new("RGB", (56, 56))
|
||||
image_placeholder.save(buffer, "png")
|
||||
buffer.seek(0)
|
||||
image_placeholder = base64.b64encode(buffer.read()).decode("utf-8")
|
||||
response = create_chat_embeddings(
|
||||
client,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/jpeg;base64,{image_placeholder}",
|
||||
},
|
||||
},
|
||||
{"type": "text", "text": "Query: What is the weather like today?"},
|
||||
],
|
||||
}
|
||||
],
|
||||
model=model,
|
||||
encoding_format="float",
|
||||
)
|
||||
|
||||
print("Text embedding output:", response.data[0].embedding)
|
||||
|
||||
|
||||
def run_siglip(client: OpenAI, model: str):
|
||||
"""
|
||||
Start the server using:
|
||||
|
||||
vllm serve google/siglip-base-patch16-224 \
|
||||
--runner pooling
|
||||
"""
|
||||
|
||||
response = create_chat_embeddings(
|
||||
client,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image_url", "image_url": {"url": image_url}},
|
||||
],
|
||||
}
|
||||
],
|
||||
model=model,
|
||||
encoding_format="float",
|
||||
)
|
||||
|
||||
print("Image embedding output:", response.data[0].embedding)
|
||||
|
||||
response = create_chat_embeddings(
|
||||
client,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "a photo of a cat"},
|
||||
],
|
||||
}
|
||||
],
|
||||
model=model,
|
||||
encoding_format="float",
|
||||
)
|
||||
|
||||
print("Text embedding output:", response.data[0].embedding)
|
||||
|
||||
|
||||
def run_vlm2vec(client: OpenAI, model: str):
|
||||
"""
|
||||
Start the server using:
|
||||
@ -148,72 +251,11 @@ def run_vlm2vec(client: OpenAI, model: str):
|
||||
print("Text embedding output:", response.data[0].embedding)
|
||||
|
||||
|
||||
def run_dse_qwen2_vl(client: OpenAI, model: str):
|
||||
"""
|
||||
Start the server using:
|
||||
|
||||
vllm serve MrLight/dse-qwen2-2b-mrl-v1 \
|
||||
--runner pooling \
|
||||
--trust-remote-code \
|
||||
--max-model-len 8192 \
|
||||
--chat-template examples/template_dse_qwen2_vl.jinja
|
||||
"""
|
||||
response = create_chat_embeddings(
|
||||
client,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": image_url,
|
||||
},
|
||||
},
|
||||
{"type": "text", "text": "What is shown in this image?"},
|
||||
],
|
||||
}
|
||||
],
|
||||
model=model,
|
||||
encoding_format="float",
|
||||
)
|
||||
|
||||
print("Image embedding output:", response.data[0].embedding)
|
||||
|
||||
# MrLight/dse-qwen2-2b-mrl-v1 requires a placeholder image
|
||||
# of the minimum input size
|
||||
buffer = io.BytesIO()
|
||||
image_placeholder = Image.new("RGB", (56, 56))
|
||||
image_placeholder.save(buffer, "png")
|
||||
buffer.seek(0)
|
||||
image_placeholder = base64.b64encode(buffer.read()).decode("utf-8")
|
||||
response = create_chat_embeddings(
|
||||
client,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/jpeg;base64,{image_placeholder}",
|
||||
},
|
||||
},
|
||||
{"type": "text", "text": "Query: What is the weather like today?"},
|
||||
],
|
||||
}
|
||||
],
|
||||
model=model,
|
||||
encoding_format="float",
|
||||
)
|
||||
|
||||
print("Text embedding output:", response.data[0].embedding)
|
||||
|
||||
|
||||
model_example_map = {
|
||||
"clip": run_clip,
|
||||
"vlm2vec": run_vlm2vec,
|
||||
"dse_qwen2_vl": run_dse_qwen2_vl,
|
||||
"siglip": run_siglip,
|
||||
"vlm2vec": run_vlm2vec,
|
||||
}
|
||||
|
||||
|
||||
|
||||
137
tests/models/multimodal/pooling/test_siglip.py
Normal file
137
tests/models/multimodal/pooling/test_siglip.py
Normal file
@ -0,0 +1,137 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
from transformers import SiglipModel
|
||||
|
||||
from ....conftest import IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner
|
||||
from ...utils import check_embeddings_close
|
||||
|
||||
HF_TEXT_PROMPTS = [
|
||||
"a photo of a stop sign",
|
||||
"a photo of a cherry blossom",
|
||||
]
|
||||
|
||||
HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts(
|
||||
{
|
||||
"stop_sign": "",
|
||||
"cherry_blossom": "",
|
||||
}
|
||||
)
|
||||
|
||||
MODELS = ["google/siglip-base-patch16-224"]
|
||||
|
||||
|
||||
def _run_test(
|
||||
hf_runner: type[HfRunner],
|
||||
vllm_runner: type[VllmRunner],
|
||||
input_texts: list[str],
|
||||
input_images: PromptImageInput,
|
||||
model: str,
|
||||
*,
|
||||
dtype: str,
|
||||
) -> None:
|
||||
with vllm_runner(
|
||||
model, runner="pooling", dtype=dtype, enforce_eager=True, max_model_len=64
|
||||
) as vllm_model:
|
||||
vllm_outputs = vllm_model.embed(input_texts, images=input_images)
|
||||
|
||||
with hf_runner(model, dtype=dtype, auto_cls=SiglipModel) as hf_model:
|
||||
all_inputs = hf_model.get_inputs(input_texts, images=input_images)
|
||||
|
||||
all_outputs = []
|
||||
for inputs in all_inputs:
|
||||
inputs = hf_model.wrap_device(inputs)
|
||||
|
||||
if "pixel_values" in inputs:
|
||||
pooled_output = hf_model.model.get_image_features(
|
||||
pixel_values=inputs.pixel_values,
|
||||
).squeeze(0)
|
||||
else:
|
||||
pooled_output = hf_model.model.get_text_features(
|
||||
input_ids=inputs.input_ids,
|
||||
).squeeze(0)
|
||||
|
||||
all_outputs.append(pooled_output.tolist())
|
||||
|
||||
hf_outputs = all_outputs
|
||||
|
||||
check_embeddings_close(
|
||||
embeddings_0_lst=hf_outputs,
|
||||
embeddings_1_lst=vllm_outputs,
|
||||
name_0="hf",
|
||||
name_1="vllm",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["float"])
|
||||
def test_models_text(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
image_assets,
|
||||
model: str,
|
||||
dtype: str,
|
||||
) -> None:
|
||||
input_texts_images = [(text, None) for text in HF_TEXT_PROMPTS]
|
||||
input_texts = [text for text, _ in input_texts_images]
|
||||
input_images = [image for _, image in input_texts_images]
|
||||
|
||||
_run_test(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
input_texts,
|
||||
input_images, # type: ignore
|
||||
model,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["float"])
|
||||
def test_models_image(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
image_assets,
|
||||
model: str,
|
||||
dtype: str,
|
||||
) -> None:
|
||||
input_texts_images = [
|
||||
(text, asset.pil_image) for text, asset in zip(HF_IMAGE_PROMPTS, image_assets)
|
||||
]
|
||||
input_texts = [text for text, _ in input_texts_images]
|
||||
input_images = [image for _, image in input_texts_images]
|
||||
|
||||
_run_test(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
input_texts,
|
||||
input_images,
|
||||
model,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["float"])
|
||||
def test_models_text_image_no_crash(
|
||||
vllm_runner,
|
||||
image_assets,
|
||||
model: str,
|
||||
dtype: str,
|
||||
) -> None:
|
||||
texts = [HF_TEXT_PROMPTS[0]]
|
||||
images = [image_assets[0].pil_image]
|
||||
|
||||
with vllm_runner(
|
||||
model,
|
||||
runner="pooling",
|
||||
dtype=dtype,
|
||||
enforce_eager=True,
|
||||
max_model_len=64,
|
||||
) as vllm_model:
|
||||
with pytest.raises(ValueError, match="not both"):
|
||||
vllm_model.embed(texts, images=images)
|
||||
|
||||
vllm_model.embed(texts)
|
||||
vllm_model.embed([""], images=images)
|
||||
@ -471,6 +471,7 @@ _EMBEDDING_EXAMPLE_MODELS = {
|
||||
"TIGER-Lab/VLM2Vec-Full", trust_remote_code=True
|
||||
),
|
||||
"Qwen2VLForConditionalGeneration": _HfExamplesInfo("MrLight/dse-qwen2-2b-mrl-v1"),
|
||||
"SiglipModel": _HfExamplesInfo("google/siglip-base-patch16-224"),
|
||||
"PrithviGeoSpatialMAE": _HfExamplesInfo(
|
||||
"ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11",
|
||||
dtype="float16",
|
||||
|
||||
@ -209,6 +209,7 @@ _EMBEDDING_MODELS = {
|
||||
),
|
||||
"Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
|
||||
"Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"), # noqa: E501
|
||||
"SiglipModel": ("siglip", "SiglipEmbeddingModel"),
|
||||
# Technically Terratorch models work on images, both in
|
||||
# input and output. I am adding it here because it piggy-backs on embedding
|
||||
# models for the time being.
|
||||
|
||||
@ -4,13 +4,23 @@
|
||||
within a vision language model."""
|
||||
|
||||
import math
|
||||
from collections.abc import Iterable
|
||||
from collections.abc import Iterable, Mapping
|
||||
from functools import cached_property
|
||||
from typing import Annotated, Literal
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import SiglipVisionConfig
|
||||
from transformers import (
|
||||
BatchFeature,
|
||||
SiglipConfig,
|
||||
SiglipProcessor,
|
||||
SiglipTextConfig,
|
||||
SiglipVisionConfig,
|
||||
)
|
||||
|
||||
from vllm.attention.layer import MultiHeadAttention
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.multimodal import BaseDummyOptions
|
||||
from vllm.distributed import divide, get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.linear import (
|
||||
@ -18,20 +28,232 @@ from vllm.model_executor.layers.linear import (
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
from vllm.model_executor.layers.pooler import DispatchPooler, Pooler
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
default_weight_loader,
|
||||
maybe_remap_kv_scale_name,
|
||||
)
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import (
|
||||
MultiModalDataDict,
|
||||
MultiModalFieldConfig,
|
||||
MultiModalInputs,
|
||||
MultiModalKwargsItems,
|
||||
MultiModalUUIDDict,
|
||||
)
|
||||
from vllm.multimodal.parse import ImageProcessorItems, ImageSize, MultiModalDataItems
|
||||
from vllm.multimodal.processing import (
|
||||
BaseMultiModalProcessor,
|
||||
BaseProcessingInfo,
|
||||
PromptIndexTargets,
|
||||
PromptReplacement,
|
||||
PromptUpdate,
|
||||
)
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
|
||||
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsQuant
|
||||
from .interfaces_base import default_pooling_type
|
||||
from .utils import AutoWeightsLoader, maybe_prefix
|
||||
from .vision import (
|
||||
VisionEncoderInfo,
|
||||
VisionFeatureSelectStrategy,
|
||||
VisionFeatureSelectStrategyStr,
|
||||
get_num_selected_vision_tokens,
|
||||
resolve_visual_encoder_outputs,
|
||||
)
|
||||
|
||||
|
||||
class SiglipImagePixelInputs(TensorSchema):
|
||||
"""
|
||||
Dimensions:
|
||||
- bn: Batch size * number of images
|
||||
- c: Number of channels (3)
|
||||
- h: Height of each image
|
||||
- w: Width of each image
|
||||
"""
|
||||
|
||||
type: Literal["pixel_values"]
|
||||
data: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]
|
||||
|
||||
|
||||
_POOLING_TYPE_TO_STRATEGY: dict[str, VisionFeatureSelectStrategyStr] = {
|
||||
"MEAN": "full",
|
||||
"ALL": "full",
|
||||
"CLS": "class",
|
||||
}
|
||||
|
||||
|
||||
def _get_vision_feature_select_strategy(
|
||||
pooling_type: str,
|
||||
) -> VisionFeatureSelectStrategyStr:
|
||||
try:
|
||||
return _POOLING_TYPE_TO_STRATEGY[pooling_type]
|
||||
except KeyError:
|
||||
raise ValueError(
|
||||
f"No feature selection strategy is defined for "
|
||||
f"pooling_type: {pooling_type!r}"
|
||||
) from None
|
||||
|
||||
|
||||
class SiglipProcessingInfo(BaseProcessingInfo):
|
||||
def get_hf_config(self):
|
||||
return self.ctx.get_hf_config(SiglipConfig)
|
||||
|
||||
def get_vision_encoder_info(self):
|
||||
return SiglipEncoderInfo(self.get_hf_config())
|
||||
|
||||
def get_hf_processor(self, **kwargs: object):
|
||||
return self.ctx.get_hf_processor(SiglipProcessor, **kwargs)
|
||||
|
||||
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
|
||||
return {"image": 1}
|
||||
|
||||
def get_num_image_tokens(
|
||||
self,
|
||||
*,
|
||||
image_width: int,
|
||||
image_height: int,
|
||||
) -> int:
|
||||
vision_encoder_info = self.get_vision_encoder_info()
|
||||
|
||||
pooler_config = self.ctx.model_config.pooler_config
|
||||
assert pooler_config is not None
|
||||
|
||||
return get_num_selected_vision_tokens(
|
||||
vision_encoder_info.get_num_image_tokens(
|
||||
image_width=image_width,
|
||||
image_height=image_height,
|
||||
),
|
||||
_get_vision_feature_select_strategy(pooler_config.pooling_type),
|
||||
)
|
||||
|
||||
def get_image_size_with_most_features(self) -> ImageSize:
|
||||
vision_encoder_info = self.get_vision_encoder_info()
|
||||
width = height = vision_encoder_info.get_image_size()
|
||||
return ImageSize(width=width, height=height)
|
||||
|
||||
def get_max_image_tokens(self) -> int:
|
||||
target_width, target_height = self.get_image_size_with_most_features()
|
||||
|
||||
return self.get_num_image_tokens(
|
||||
image_width=target_width, image_height=target_height
|
||||
)
|
||||
|
||||
|
||||
class SiglipDummyInputsBuilder(BaseDummyInputsBuilder[SiglipProcessingInfo]):
|
||||
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
|
||||
return ""
|
||||
|
||||
def get_dummy_mm_data(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
mm_options: Mapping[str, BaseDummyOptions] | None = None,
|
||||
) -> MultiModalDataDict:
|
||||
num_images = mm_counts.get("image", 0)
|
||||
|
||||
target_width, target_height = self.info.get_image_size_with_most_features()
|
||||
|
||||
image_overrides = mm_options.get("image") if mm_options else None
|
||||
|
||||
return {
|
||||
"image": self._get_dummy_images(
|
||||
width=target_width,
|
||||
height=target_height,
|
||||
num_images=num_images,
|
||||
overrides=image_overrides,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
class SiglipMultiModalProcessor(BaseMultiModalProcessor[SiglipProcessingInfo]):
|
||||
@cached_property
|
||||
def image_token_id(self) -> int:
|
||||
tokenizer = self.info.get_tokenizer()
|
||||
dummy_token_id = 0
|
||||
|
||||
assert dummy_token_id not in tokenizer.all_special_ids
|
||||
|
||||
return dummy_token_id
|
||||
|
||||
def apply(
|
||||
self,
|
||||
prompt: str | list[int],
|
||||
mm_data: MultiModalDataDict,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
tokenization_kwargs: Mapping[str, object] | None = None,
|
||||
*,
|
||||
mm_uuids: MultiModalUUIDDict | None = None,
|
||||
) -> MultiModalInputs:
|
||||
if prompt and mm_data:
|
||||
raise ValueError(
|
||||
"Siglip accepts text-only or image-only inputs, not both! "
|
||||
"Image-only inputs means passing an image with an empty text "
|
||||
"prompt."
|
||||
)
|
||||
|
||||
if mm_data:
|
||||
# For multi-modal data, the prompt after processing should
|
||||
# only contain the image token
|
||||
tokenization_kwargs = {
|
||||
**(tokenization_kwargs or {}),
|
||||
"add_special_tokens": False,
|
||||
}
|
||||
|
||||
return super().apply(
|
||||
prompt=prompt,
|
||||
mm_data=mm_data,
|
||||
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
mm_uuids=mm_uuids,
|
||||
)
|
||||
|
||||
def _hf_processor_applies_updates(
|
||||
self,
|
||||
prompt_text: str,
|
||||
mm_items: MultiModalDataItems,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
tokenization_kwargs: Mapping[str, object],
|
||||
) -> bool:
|
||||
return False
|
||||
|
||||
def _get_mm_fields_config(
|
||||
self,
|
||||
hf_inputs: BatchFeature,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
) -> Mapping[str, MultiModalFieldConfig]:
|
||||
return dict(pixel_values=MultiModalFieldConfig.batched("image"))
|
||||
|
||||
def _get_prompt_updates(
|
||||
self,
|
||||
mm_items: MultiModalDataItems,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
out_mm_kwargs: MultiModalKwargsItems,
|
||||
) -> list[PromptUpdate]:
|
||||
image_token_id = self.image_token_id
|
||||
|
||||
def get_replacement(item_idx: int):
|
||||
images = mm_items.get_items("image", ImageProcessorItems)
|
||||
image_size = images.get_image_size(item_idx)
|
||||
|
||||
num_image_tokens = self.info.get_num_image_tokens(
|
||||
image_width=image_size.width, image_height=image_size.height
|
||||
)
|
||||
return [image_token_id] * num_image_tokens
|
||||
|
||||
return [
|
||||
PromptReplacement(
|
||||
modality="image",
|
||||
target=PromptIndexTargets.start(),
|
||||
replacement=get_replacement,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
class SiglipEncoderInfo(VisionEncoderInfo[SiglipVisionConfig]):
|
||||
def get_num_image_tokens(
|
||||
self,
|
||||
@ -151,8 +373,9 @@ class SiglipVisionEmbeddings(nn.Module):
|
||||
class SiglipAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: SiglipVisionConfig,
|
||||
config: SiglipVisionConfig | SiglipTextConfig,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
*,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
@ -195,12 +418,29 @@ class SiglipAttention(nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
) -> tuple[torch.Tensor, None]:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
qkv_states, _ = self.qkv_proj(hidden_states)
|
||||
query_states, key_states, value_states = qkv_states.chunk(3, dim=-1)
|
||||
|
||||
needs_unsqueeze = query_states.ndim == 2
|
||||
if needs_unsqueeze:
|
||||
query_states, key_states, value_states = (
|
||||
query_states.unsqueeze(0),
|
||||
key_states.unsqueeze(0),
|
||||
value_states.unsqueeze(0),
|
||||
)
|
||||
|
||||
out = self.attn(query_states, key_states, value_states)
|
||||
|
||||
if needs_unsqueeze:
|
||||
out, query_states, key_states, value_states = (
|
||||
out.squeeze(0),
|
||||
query_states.squeeze(0),
|
||||
key_states.squeeze(0),
|
||||
value_states.squeeze(0),
|
||||
)
|
||||
|
||||
attn_output, _ = self.out_proj(out)
|
||||
|
||||
return attn_output, None
|
||||
@ -209,7 +449,7 @@ class SiglipAttention(nn.Module):
|
||||
class SiglipMLP(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: SiglipVisionConfig,
|
||||
config: SiglipVisionConfig | SiglipTextConfig,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
@ -249,8 +489,9 @@ class SiglipMLP(nn.Module):
|
||||
class SiglipEncoderLayer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: SiglipVisionConfig,
|
||||
config: SiglipVisionConfig | SiglipTextConfig,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
*,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
@ -291,9 +532,10 @@ class SiglipEncoderLayer(nn.Module):
|
||||
class SiglipEncoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: SiglipVisionConfig,
|
||||
config: SiglipVisionConfig | SiglipTextConfig,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
num_hidden_layers_override: int | None = None,
|
||||
*,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
@ -335,6 +577,76 @@ class SiglipEncoder(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class SiglipTextTransformer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: SiglipTextConfig,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
*,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
embed_dim = config.hidden_size
|
||||
|
||||
self.embeddings = SiglipTextEmbeddings(config)
|
||||
|
||||
self.encoder = SiglipEncoder(
|
||||
config=config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.encoder",
|
||||
)
|
||||
|
||||
self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
|
||||
self.head = nn.Linear(embed_dim, config.projection_size)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embeddings.token_embedding(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor | None,
|
||||
position_ids: torch.Tensor,
|
||||
inputs_embeds: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.embeddings(input_ids, position_ids, inputs_embeds)
|
||||
|
||||
last_hidden_state = self.encoder(
|
||||
inputs_embeds=hidden_states, return_all_hidden_states=False
|
||||
)
|
||||
|
||||
last_hidden_state = self.final_layer_norm(last_hidden_state)
|
||||
|
||||
return last_hidden_state
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
("qkv_proj", "k_proj", "k"),
|
||||
("qkv_proj", "v_proj", "v"),
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: set[str] = set()
|
||||
|
||||
for name, loaded_weight in weights:
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
|
||||
class SiglipMultiheadAttentionPoolingHead(nn.Module):
|
||||
"""Multihead Attention Pooling."""
|
||||
|
||||
@ -357,8 +669,9 @@ class SiglipMultiheadAttentionPoolingHead(nn.Module):
|
||||
)
|
||||
|
||||
def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
|
||||
batch_size = hidden_state.shape[0]
|
||||
probe = self.probe.repeat(batch_size, 1, 1)
|
||||
batch_size = hidden_state.size(0)
|
||||
|
||||
probe = self.probe.expand(batch_size, -1, -1)
|
||||
|
||||
hidden_state = self.attention(probe, hidden_state, hidden_state)[0]
|
||||
|
||||
@ -367,7 +680,9 @@ class SiglipMultiheadAttentionPoolingHead(nn.Module):
|
||||
hidden_state = self.mlp(hidden_state)
|
||||
hidden_state += residual
|
||||
|
||||
return hidden_state[:, 0]
|
||||
pooled = hidden_state[:, 0]
|
||||
|
||||
return pooled.unsqueeze(1)
|
||||
|
||||
|
||||
class SiglipVisionTransformer(nn.Module):
|
||||
@ -420,6 +735,14 @@ class SiglipVisionTransformer(nn.Module):
|
||||
prefix=f"{prefix}.head",
|
||||
)
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return next(self.parameters()).dtype
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return next(self.parameters()).device
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pixel_values: torch.Tensor,
|
||||
@ -432,7 +755,6 @@ class SiglipVisionTransformer(nn.Module):
|
||||
pixel_values,
|
||||
interpolate_pos_encoding=interpolate_pos_encoding,
|
||||
)
|
||||
|
||||
# Produces either the last layer output or all of the hidden states,
|
||||
# depending on if we have select_layers or not
|
||||
encoder_outputs = self.encoder(
|
||||
@ -440,21 +762,60 @@ class SiglipVisionTransformer(nn.Module):
|
||||
return_all_hidden_states=select_layers is not None,
|
||||
)
|
||||
|
||||
# Handle post-norm (if applicable) and stacks feature layers if needed
|
||||
if self.post_layernorm is not None:
|
||||
encoder_outputs = self.post_layernorm(encoder_outputs)
|
||||
|
||||
if self.use_head:
|
||||
encoder_outputs = self.head(encoder_outputs)
|
||||
|
||||
# stacks feature layers if needed
|
||||
encoder_outputs = resolve_visual_encoder_outputs(
|
||||
encoder_outputs,
|
||||
self.post_layernorm,
|
||||
None,
|
||||
select_layers=select_layers,
|
||||
max_possible_layers=self.config.num_hidden_layers,
|
||||
feature_select_strategy=feature_select_strategy,
|
||||
)
|
||||
|
||||
# TODO: add this back when pooled_output is used in inference.
|
||||
# if self.use_head:
|
||||
# pooled_output = self.head(encoder_outputs)
|
||||
|
||||
return encoder_outputs
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
("qkv_proj", "k_proj", "k"),
|
||||
("qkv_proj", "v_proj", "v"),
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: set[str] = set()
|
||||
layer_count = len(self.encoder.layers)
|
||||
|
||||
for name, loaded_weight in weights:
|
||||
# post_layernorm is not needed in SiglipVisionTransformer
|
||||
if name.startswith("post_layernorm") and self.post_layernorm is None:
|
||||
continue
|
||||
|
||||
# omit layers when num_hidden_layers_override is set
|
||||
if name.startswith("encoder.layers"):
|
||||
layer_idx = int(name.split(".")[2])
|
||||
if layer_idx >= layer_count:
|
||||
continue
|
||||
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
|
||||
class SiglipVisionModel(nn.Module):
|
||||
config_class = SiglipVisionConfig
|
||||
@ -484,7 +845,11 @@ class SiglipVisionModel(nn.Module):
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return self.get_input_embeddings().weight.dtype
|
||||
return self.vision_model.dtype
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return self.vision_model.device
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -555,3 +920,214 @@ class SiglipVisionModel(nn.Module):
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
|
||||
# Adapted from: https://github.com/huggingface/transformers/blob/v4.54.1/src/transformers/models/siglip/modeling_siglip.py#L200
|
||||
class SiglipTextEmbeddings(nn.Module):
|
||||
def __init__(self, config: SiglipTextConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
self.token_embedding = VocabParallelEmbedding(
|
||||
config.vocab_size, config.hidden_size
|
||||
)
|
||||
|
||||
self.position_embedding = VocabParallelEmbedding(
|
||||
config.max_position_embeddings, config.hidden_size
|
||||
)
|
||||
|
||||
self.register_buffer(
|
||||
"position_ids",
|
||||
torch.arange(config.max_position_embeddings).expand((1, -1)),
|
||||
persistent=False,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor | None,
|
||||
position_ids: torch.Tensor,
|
||||
inputs_embeds: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.token_embedding(input_ids)
|
||||
|
||||
position_embeddings = self.position_embedding(position_ids)
|
||||
embeddings = inputs_embeds + position_embeddings
|
||||
return embeddings
|
||||
|
||||
|
||||
# Assume EOS token corresponds to CLS token in text model
|
||||
@default_pooling_type("CLS")
|
||||
@MULTIMODAL_REGISTRY.register_processor(
|
||||
SiglipMultiModalProcessor,
|
||||
info=SiglipProcessingInfo,
|
||||
dummy_inputs=SiglipDummyInputsBuilder,
|
||||
)
|
||||
class SiglipEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
|
||||
is_pooling_model = True
|
||||
|
||||
packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
|
||||
merge_by_field_config = True
|
||||
|
||||
@classmethod
|
||||
def get_placeholder_str(cls, modality: str, i: int) -> str | None:
|
||||
if modality.startswith("image"):
|
||||
return None
|
||||
|
||||
raise ValueError("Only image modality is supported")
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
|
||||
config: SiglipConfig = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
multimodal_config = vllm_config.model_config.multimodal_config
|
||||
self.config = config
|
||||
self.multimodal_config = multimodal_config
|
||||
|
||||
if hasattr(config, "num_labels"):
|
||||
config.num_labels = 0
|
||||
|
||||
text_config = config.text_config
|
||||
vision_config = config.vision_config
|
||||
|
||||
self.text_embed_dim = text_config.hidden_size
|
||||
self.vision_embed_dim = vision_config.hidden_size
|
||||
|
||||
self.text_model = SiglipTextTransformer(
|
||||
text_config,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "text_model"),
|
||||
)
|
||||
self.vision_model = SiglipVisionTransformer(
|
||||
vision_config,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "vision_model"),
|
||||
)
|
||||
|
||||
self.text_projection_size = text_config.projection_size
|
||||
|
||||
pooler_config = vllm_config.model_config.pooler_config
|
||||
assert pooler_config is not None
|
||||
self.pooler_config = pooler_config
|
||||
|
||||
self.pooler = DispatchPooler(
|
||||
{
|
||||
"token_embed": Pooler.for_token_embed(pooler_config),
|
||||
"embed": Pooler.for_embed(pooler_config),
|
||||
}
|
||||
)
|
||||
|
||||
self._is_text_input = True
|
||||
|
||||
def get_text_features(
|
||||
self,
|
||||
input_ids: torch.Tensor | None,
|
||||
position_ids: torch.Tensor,
|
||||
inputs_embeds: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
last_hidden_state = self.text_model(
|
||||
input_ids=input_ids,
|
||||
position_ids=position_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
text_features = self.text_model.head(last_hidden_state)
|
||||
# Flip to extract CLS token (first token after reversal) for pooling
|
||||
text_features = text_features.flip(0)
|
||||
return text_features
|
||||
|
||||
def get_image_features(
|
||||
self,
|
||||
pixel_values: torch.Tensor,
|
||||
feature_select_strategy: VisionFeatureSelectStrategy | None = None,
|
||||
) -> torch.Tensor:
|
||||
if feature_select_strategy is None:
|
||||
feature_select_strategy = _get_vision_feature_select_strategy(
|
||||
self.pooler_config.pooling_type
|
||||
)
|
||||
|
||||
pooled_output = self.vision_model(
|
||||
pixel_values=pixel_values,
|
||||
select_layers=None,
|
||||
feature_select_strategy=feature_select_strategy,
|
||||
)
|
||||
|
||||
return pooled_output
|
||||
|
||||
def _parse_and_validate_image_input(
|
||||
self, **kwargs: object
|
||||
) -> SiglipImagePixelInputs | None:
|
||||
pixel_values = kwargs.pop("pixel_values", None)
|
||||
if pixel_values is None:
|
||||
return None
|
||||
|
||||
expected_h = expected_w = self.config.vision_config.image_size
|
||||
return SiglipImagePixelInputs(
|
||||
type="pixel_values",
|
||||
data=pixel_values,
|
||||
resolve_bindings={"h": expected_h, "w": expected_w},
|
||||
)
|
||||
|
||||
def _process_image_inputs(self, inputs: SiglipImagePixelInputs) -> torch.Tensor:
|
||||
pixel_values = inputs["data"]
|
||||
|
||||
return self.get_image_features(pixel_values)
|
||||
|
||||
def get_language_model(self) -> torch.nn.Module:
|
||||
return self.text_model
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
multimodal_embeddings: MultiModalEmbeddings | None = None,
|
||||
*,
|
||||
is_multimodal: torch.Tensor | None = None,
|
||||
handle_oov_mm_token: bool = False,
|
||||
) -> torch.Tensor:
|
||||
self._is_text_input = (
|
||||
multimodal_embeddings is None or len(multimodal_embeddings) == 0
|
||||
)
|
||||
|
||||
if multimodal_embeddings is None or is_multimodal is None:
|
||||
return super().get_input_embeddings(input_ids)
|
||||
|
||||
return super().get_input_embeddings(
|
||||
input_ids,
|
||||
multimodal_embeddings=multimodal_embeddings,
|
||||
is_multimodal=is_multimodal,
|
||||
handle_oov_mm_token=handle_oov_mm_token,
|
||||
)
|
||||
|
||||
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
|
||||
image_input = self._parse_and_validate_image_input(**kwargs)
|
||||
if image_input is None:
|
||||
return []
|
||||
|
||||
vision_embeddings = self._process_image_inputs(image_input)
|
||||
return vision_embeddings
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor | None,
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: IntermediateTensors | None = None,
|
||||
inputs_embeds: torch.Tensor | None = None,
|
||||
**kwargs: object,
|
||||
) -> torch.Tensor:
|
||||
if intermediate_tensors is not None:
|
||||
raise RuntimeError("PP is not supported for this model")
|
||||
|
||||
# Multimodal inputs (image embeddings)
|
||||
if not self._is_text_input:
|
||||
return inputs_embeds
|
||||
|
||||
return self.get_text_features(input_ids, positions, inputs_embeds)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
|
||||
loader = AutoWeightsLoader(
|
||||
self,
|
||||
skip_substrs=[".position_ids"],
|
||||
ignore_unexpected_prefixes=["logit_scale.", "logit_bias."],
|
||||
)
|
||||
|
||||
return loader.load_weights(weights)
|
||||
|
||||
@ -31,14 +31,15 @@ def _get_minicpmv_chat_template_fallback(tokenizer_name_or_path: str) -> Path |
|
||||
|
||||
_MODEL_TYPE_TO_CHAT_TEMPLATE_FALLBACK: dict[str, ChatTemplatePath] = {
|
||||
"blip-2": CHAT_TEMPLATES_DIR / "template_blip2.jinja",
|
||||
"clip": CHAT_TEMPLATES_DIR / "template_basic.jinja",
|
||||
"chameleon": CHAT_TEMPLATES_DIR / "template_basic.jinja",
|
||||
"deepseek_vl_v2": CHAT_TEMPLATES_DIR / "template_deepseek_vl2.jinja",
|
||||
"clip": CHAT_TEMPLATES_DIR / "template_basic.jinja",
|
||||
"deepseek_ocr": CHAT_TEMPLATES_DIR / "template_deepseek_ocr.jinja",
|
||||
"deepseek_vl_v2": CHAT_TEMPLATES_DIR / "template_deepseek_vl2.jinja",
|
||||
"fuyu": CHAT_TEMPLATES_DIR / "template_fuyu.jinja",
|
||||
"minicpmv": _get_minicpmv_chat_template_fallback,
|
||||
"paligemma": CHAT_TEMPLATES_DIR / "template_basic.jinja",
|
||||
"qwen": _get_qwen_chat_template_fallback,
|
||||
"siglip": CHAT_TEMPLATES_DIR / "template_basic.jinja",
|
||||
}
|
||||
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user