diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md
index 9959f7233e86..4b1f3e180ed5 100644
--- a/docs/source/models/supported_models.md
+++ b/docs/source/models/supported_models.md
@@ -715,6 +715,13 @@ See [this page](#generative-models) for more information on how to use generativ
*
* ✅︎
* ✅︎
+- * `Florence2ForConditionalGeneration`
+ * Florence-2
+ * T + I
+  * `microsoft/Florence-2-base`, `microsoft/Florence-2-large`, etc.
+ *
+ *
+ *
- * `FuyuForCausalLM`
* Fuyu
* T + I
diff --git a/examples/offline_inference/florence2_inference.py b/examples/offline_inference/florence2_inference.py
index 58610b0fd2a5..27aceee43cbf 100644
--- a/examples/offline_inference/florence2_inference.py
+++ b/examples/offline_inference/florence2_inference.py
@@ -1,34 +1,45 @@
# SPDX-License-Identifier: Apache-2.0
-'''
+"""
Demonstrate prompting of text-to-text
encoder/decoder models, specifically Florence-2
-'''
+"""
# TODO(Isotr0py):
# Move to offline_inference/vision_language.py
# after porting vision backbone
from vllm import LLM, SamplingParams
-
-dtype = "float"
+from vllm.assets.image import ImageAsset
# Create a Florence-2 encoder/decoder model instance
llm = LLM(
- model="microsoft/Florence-2-base",
- tokenizer="facebook/bart-base",
- dtype=dtype,
+ model="microsoft/Florence-2-large",
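+    # Florence-2 ships a BartFastTokenizer that AutoTokenizer cannot load, so
+    # we borrow the original BART tokenizer instead.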
+ tokenizer="facebook/bart-large",
+ max_num_seqs=8,
trust_remote_code=True,
)
prompts = [
-    "<CAPTION>", "<DETAILED_CAPTION>", "<MORE_DETAILED_CAPTION>",
-    "<CAPTION_TO_PHRASE_GROUNDING>", "<OD>", "<DENSE_REGION_CAPTION>",
-    "<REGION_PROPOSAL>", "<OCR>", "<OCR_WITH_REGION>"
+ { # implicit prompt with task token
+        "prompt": "<DETAILED_CAPTION>",
+ "multi_modal_data": {
+ "image": ImageAsset("stop_sign").pil_image
+ },
+ },
+ { # explicit encoder/decoder prompt
+ "encoder_prompt": {
+ "prompt": "Describe in detail what is shown in the image.",
+ "multi_modal_data": {
+ "image": ImageAsset("cherry_blossom").pil_image
+ },
+ },
+ "decoder_prompt": "",
+ },
]
# Create a sampling params object.
sampling_params = SamplingParams(
temperature=0,
top_p=1.0,
min_tokens=0,
- max_tokens=20,
+ max_tokens=128,
)
# Generate output tokens from the prompts. The output is a list of
@@ -38,9 +49,5 @@ outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
- prompt = output.prompt
- encoder_prompt = output.encoder_prompt
generated_text = output.outputs[0].text
- print(f"Encoder prompt: {encoder_prompt!r}, "
- f"Decoder prompt: {prompt!r}, "
- f"Generated text: {generated_text!r}")
+ print(f"Generated text: {generated_text!r}")
diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py
index 5f05389faf80..e2ec36211b86 100644
--- a/examples/offline_inference/vision_language.py
+++ b/examples/offline_inference/vision_language.py
@@ -82,6 +82,22 @@ def run_deepseek_vl2(question: str, modality: str):
return llm, prompt, stop_token_ids
+# Florence2
+def run_florence2(question: str, modality: str):
+ assert modality == "image"
+
+ llm = LLM(model="microsoft/Florence-2-large",
+ tokenizer="facebook/bart-large",
+ max_num_seqs=8,
+ trust_remote_code=True,
+ dtype="bfloat16",
+ disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
+
+    prompt = "<MORE_DETAILED_CAPTION>"
+ stop_token_ids = None
+ return llm, prompt, stop_token_ids
+
+
# Fuyu
def run_fuyu(question: str, modality: str):
assert modality == "image"
@@ -571,6 +587,7 @@ model_example_map = {
"blip-2": run_blip2,
"chameleon": run_chameleon,
"deepseek_vl_v2": run_deepseek_vl2,
+ "florence2": run_florence2,
"fuyu": run_fuyu,
"glm4v": run_glm4v,
"h2ovl_chat": run_h2ovl,
diff --git a/tests/conftest.py b/tests/conftest.py
index dd339030e5e4..871f0b62c532 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -600,8 +600,8 @@ class HfRunner:
if images is not None and images[i] is not None:
processor_kwargs["images"] = images[i]
- encoder_input_ids = self.wrap_device(
- self.processor(**processor_kwargs).input_ids,
+ encoder_inputs = self.wrap_device(
+ self.processor(**processor_kwargs),
device=self.model.device.type,
)
@@ -615,13 +615,13 @@ class HfRunner:
)
output = self.model.generate(
- encoder_input_ids,
decoder_input_ids=decoder_input_ids,
use_cache=True,
do_sample=False,
max_new_tokens=max_tokens,
output_hidden_states=True,
return_dict_in_generate=True,
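+                # Forward the full processor output (input_ids and, when
+                # present, pixel_values) to the HF generate call.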
+ **encoder_inputs,
**kwargs,
)
diff --git a/tests/models/decoder_only/audio_language/test_ultravox.py b/tests/models/decoder_only/audio_language/test_ultravox.py
index d1f643a8fdb7..0ea17247028f 100644
--- a/tests/models/decoder_only/audio_language/test_ultravox.py
+++ b/tests/models/decoder_only/audio_language/test_ultravox.py
@@ -15,7 +15,7 @@ from ....conftest import HfRunner, VllmRunner
from ....utils import RemoteOpenAIServer
from ...utils import check_logprobs_close
-MODEL_NAME = "fixie-ai/ultravox-v0_5-llama-3_2-1b"
+MODEL_NAME = "fixie-ai/ultravox-v0_4"
AudioTuple = Tuple[np.ndarray, int]
@@ -187,7 +187,7 @@ def run_multi_audio_test(
@pytest.mark.core_model
-@pytest.mark.parametrize("dtype", ["half"])
+@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("vllm_kwargs", [
diff --git a/tests/models/encoder_decoder/vision_language/test_florence2.py b/tests/models/encoder_decoder/vision_language/test_florence2.py
index a1d15679918b..de18deab11f6 100644
--- a/tests/models/encoder_decoder/vision_language/test_florence2.py
+++ b/tests/models/encoder_decoder/vision_language/test_florence2.py
@@ -1,52 +1,59 @@
# SPDX-License-Identifier: Apache-2.0
-from functools import partial
-from typing import List, Optional, Tuple, Type
+from typing import Optional, Type
import pytest
from PIL import Image
-from vllm.inputs.data import ExplicitEncoderDecoderPrompt
+from vllm.inputs.data import ExplicitEncoderDecoderPrompt, TextPrompt
+from vllm.multimodal.image import rescale_image_size
from vllm.sequence import SampleLogprobs
-from ....conftest import HfRunner, VllmRunner
+from ....conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
from ...utils import check_logprobs_close
-Florence2Prompt = partial(ExplicitEncoderDecoderPrompt,
- decoder_prompt=None,
- mm_processor_kwargs=None)
-
MODELS = ["microsoft/Florence-2-base"]
# Florence-2 uses BartFastTokenizer which can't be loaded from AutoTokenizer
# Therefore, we borrow the BartTokenizer from the original Bart model
TOKENIZER = "facebook/bart-base"
-PROMPTS = [
-    Florence2Prompt(encoder_prompt="<CAPTION>"),
-    Florence2Prompt(encoder_prompt="<DETAILED_CAPTION>"),
-    Florence2Prompt(encoder_prompt="<MORE_DETAILED_CAPTION>"),
-    Florence2Prompt(encoder_prompt="<CAPTION_TO_PHRASE_GROUNDING>"),
-    Florence2Prompt(encoder_prompt="<OD>"),
-    Florence2Prompt(encoder_prompt="<DENSE_REGION_CAPTION>"),
-    Florence2Prompt(encoder_prompt="<REGION_PROPOSAL>"),
-    Florence2Prompt(encoder_prompt="<OCR>"),
-    Florence2Prompt(encoder_prompt="<OCR_WITH_REGION>"),
-]
+HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
+ "stop_sign":
+    "<CAPTION>",  # special task token
+ "cherry_blossom":
+ "Describe in detail what is shown in the image.",
+})
-def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
- Optional[SampleLogprobs]], ):
- """Sanitize vllm output to be comparable with hf output."""
- output_ids, output_str, out_logprobs = vllm_output
+def get_hf_images_prompts(
+ prompts_: list[ExplicitEncoderDecoderPrompt[str, TextPrompt]],
+) -> tuple[list[ExplicitEncoderDecoderPrompt[str, str]], list[Image.Image]]:
+ prompts, images = [], []
+ for prompt in prompts_:
+ encoder_prompt = prompt["encoder_prompt"]
+ prompts.append(
+ ExplicitEncoderDecoderPrompt(
+ encoder_prompt=encoder_prompt["prompt"],
+ decoder_prompt=None,
+ ))
+ images.append(encoder_prompt["multi_modal_data"]["image"])
+ return prompts, images
-    hf_output_str = "</s><s>" + output_str + "</s>"
- return output_ids, hf_output_str, out_logprobs
+def hf_to_vllm_output(hf_output: tuple[list[int], str,
+ Optional[SampleLogprobs]]):
+ """Sanitize hf output to be comparable with vllm output."""
+ output_ids, output_str, out_logprobs = hf_output
+
+    output_str = output_str.replace("<s>", "").replace("</s>", "")
+ output_ids = [ids for ids in output_ids if ids not in [0, 2]]
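+    # 0 and 2 are BOS (<s>) and EOS (</s>) in the BART tokenizer.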
+
+ return output_ids, output_str, out_logprobs
def run_test(
hf_runner: Type[HfRunner],
vllm_runner: Type[VllmRunner],
- prompts: List[ExplicitEncoderDecoderPrompt],
+ inputs: list[list[ExplicitEncoderDecoderPrompt]],
model: str,
*,
dtype: str,
@@ -56,46 +63,76 @@ def run_test(
distributed_executor_backend: Optional[str] = None,
) -> None:
with vllm_runner(model,
+ max_num_seqs=8,
tokenizer_name=TOKENIZER,
dtype=dtype,
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend,
enforce_eager=True) as vllm_model:
- vllm_outputs = vllm_model.generate_encoder_decoder_greedy_logprobs(
- prompts, max_tokens, num_logprobs)
+ vllm_outputs_per_case = [
+ vllm_model.generate_encoder_decoder_greedy_logprobs(
+ prompts, max_tokens, num_logprobs=num_logprobs)
+ for prompts in inputs
+ ]
+
+ hf_inputs = [get_hf_images_prompts(prompts) for prompts in inputs]
- # Florence-2 processors require image inputs
- dummy_image = Image.new(mode="RGB", size=(2, 2))
with hf_runner(model, dtype=dtype, skip_tokenizer_init=True) as hf_model:
hf_model.model.get_output_embeddings = lambda: \
hf_model.model.language_model.lm_head
- hf_outputs = (hf_model.generate_encoder_decoder_greedy_logprobs_limit(
- prompts,
- max_tokens,
- num_logprobs,
- images=[dummy_image] * len(prompts),
- ))
+ hf_outputs_per_case = [
+ hf_model.generate_encoder_decoder_greedy_logprobs_limit(
+ prompts, max_tokens, num_logprobs=num_logprobs, images=images)
+ for prompts, images in hf_inputs
+ ]
- check_logprobs_close(
- outputs_0_lst=hf_outputs,
- outputs_1_lst=[
- vllm_to_hf_output(vllm_output) for vllm_output in vllm_outputs
- ],
- name_0="hf",
- name_1="vllm",
- )
+ for hf_outputs, vllm_outputs in zip(hf_outputs_per_case,
+ vllm_outputs_per_case):
+ check_logprobs_close(
+ outputs_0_lst=[hf_to_vllm_output(output) for output in hf_outputs],
+ outputs_1_lst=vllm_outputs,
+ name_0="hf",
+ name_1="vllm",
+ )
+@pytest.mark.core_model
@pytest.mark.parametrize("model", MODELS)
-@pytest.mark.parametrize("dtype", ["float", "bfloat16"])
+@pytest.mark.parametrize(
+ "size_factors",
+ [
+ # No image
+ [],
+ # Single-scale
+ [1.0],
+ # Single-scale, batched
+ [1.0, 1.0, 1.0],
+ # Multi-scale
+ [0.25, 0.5, 1.0],
+ ],
+)
+@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
-def test_models(hf_runner, vllm_runner, model, dtype, max_tokens,
- num_logprobs) -> None:
+def test_models(hf_runner: Type[HfRunner], vllm_runner: Type[VllmRunner],
+ image_assets: _ImageAssets, model: str,
+                size_factors: list[float], dtype: str, max_tokens: int,
+ num_logprobs: int) -> None:
+ images = [asset.pil_image for asset in image_assets]
+
+ inputs_per_image = [[
+ ExplicitEncoderDecoderPrompt(
+ encoder_prompt=TextPrompt(
+ prompt=prompt,
+ multi_modal_data={"image": rescale_image_size(image, factor)}),
+ decoder_prompt=None,
+ ) for factor in size_factors
+ ] for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
+
run_test(
hf_runner,
vllm_runner,
- PROMPTS,
+ inputs_per_image,
model,
dtype=dtype,
max_tokens=max_tokens,
diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py
index a84999cfbf4f..7534f0c97798 100644
--- a/tests/models/multimodal/processing/test_common.py
+++ b/tests/models/multimodal/processing/test_common.py
@@ -29,8 +29,8 @@ def _test_processing_correctness(
model_config = ModelConfig(
model_id,
task="auto",
- tokenizer=model_id,
- tokenizer_mode="auto",
+ tokenizer=model_info.tokenizer or model_id,
+ tokenizer_mode=model_info.tokenizer_mode,
trust_remote_code=model_info.trust_remote_code,
seed=0,
dtype="float16",
@@ -151,6 +151,7 @@ def _test_processing_correctness(
"Salesforce/blip2-opt-2.7b",
"facebook/chameleon-7b",
"deepseek-ai/deepseek-vl2-tiny",
+ "microsoft/Florence-2-base",
"adept/fuyu-8b",
"THUDM/glm-4v-9b",
"h2oai/h2ovl-mississippi-800m",
diff --git a/tests/models/registry.py b/tests/models/registry.py
index 8614baf18f3b..95bda0293498 100644
--- a/tests/models/registry.py
+++ b/tests/models/registry.py
@@ -193,11 +193,6 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
# [Encoder-decoder]
"BartModel": _HfExamplesInfo("facebook/bart-base"),
"BartForConditionalGeneration": _HfExamplesInfo("facebook/bart-large-cnn"),
- # Florence-2 uses BartFastTokenizer which can't be loaded from AutoTokenizer
- # Therefore, we borrow the BartTokenizer from the original Bart model
- "Florence2ForConditionalGeneration": _HfExamplesInfo("microsoft/Florence-2-base", # noqa: E501
- tokenizer="facebook/bart-base",
- trust_remote_code=True), # noqa: E501
}
_EMBEDDING_EXAMPLE_MODELS = {
@@ -288,6 +283,11 @@ _MULTIMODAL_EXAMPLE_MODELS = {
extras={"v0.5": "fixie-ai/ultravox-v0_5-llama-3_2-1b"}, # noqa: E501
trust_remote_code=True),
# [Encoder-decoder]
+ # Florence-2 uses BartFastTokenizer which can't be loaded from AutoTokenizer
+ # Therefore, we borrow the BartTokenizer from the original Bart model
+ "Florence2ForConditionalGeneration": _HfExamplesInfo("microsoft/Florence-2-base", # noqa: E501
+ tokenizer="facebook/bart-base",
+ trust_remote_code=True), # noqa: E501
"MllamaForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-3.2-11B-Vision-Instruct"), # noqa: E501
"WhisperForConditionalGeneration": _HfExamplesInfo("openai/whisper-large-v3"), # noqa: E501
}
diff --git a/vllm/model_executor/models/bart.py b/vllm/model_executor/models/bart.py
index 5d2a8cdcb97d..93452696dca5 100644
--- a/vllm/model_executor/models/bart.py
+++ b/vllm/model_executor/models/bart.py
@@ -588,8 +588,12 @@ class BartEncoder(nn.Module):
self.layernorm_embedding = nn.LayerNorm(embed_dim)
- def forward(self, input_ids: torch.Tensor,
- positions: torch.Tensor) -> torch.Tensor:
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
r"""
Args:
input_ids
@@ -602,7 +606,8 @@ class BartEncoder(nn.Module):
Decoder output torch.Tensor
"""
# retrieve input_ids and inputs_embeds
- inputs_embeds = self.embed_tokens(input_ids)
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
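+        # Callers such as Florence-2 pass precomputed inputs_embeds with image
+        # features already merged, in which case the token lookup is skipped.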
embed_pos = self.embed_positions(positions)
embed_pos = embed_pos.to(inputs_embeds.device)
@@ -661,9 +666,13 @@ class BartDecoder(nn.Module):
self.layernorm_embedding = nn.LayerNorm(config.d_model)
- def forward(self, decoder_input_ids: torch.Tensor,
- decoder_positions: torch.Tensor,
- encoder_hidden_states: Optional[torch.Tensor]) -> torch.Tensor:
+ def forward(
+ self,
+ decoder_input_ids: torch.Tensor,
+ decoder_positions: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor],
+ inputs_embeds: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
r"""
Args:
decoder_input_ids
@@ -677,8 +686,10 @@ class BartDecoder(nn.Module):
Returns:
Decoder output torch.Tensor
"""
-
- inputs_embeds = self.embed_tokens(decoder_input_ids)
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(decoder_input_ids)
+ else:
+ decoder_positions = inputs_embeds[:, -1]
# embed positions
embed_pos = self.embed_positions(decoder_positions)
diff --git a/vllm/model_executor/models/florence2.py b/vllm/model_executor/models/florence2.py
index 06912bcfdc8a..b71d0de8d707 100644
--- a/vllm/model_executor/models/florence2.py
+++ b/vllm/model_executor/models/florence2.py
@@ -1,10 +1,15 @@
# SPDX-License-Identifier: Apache-2.0
import math
-from typing import Iterable, Optional, Set, Tuple
+from functools import cached_property
+from typing import (Iterable, List, Literal, Mapping, Optional, OrderedDict,
+ Set, Tuple, TypedDict, Union)
import torch
import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+from transformers import BatchFeature, PretrainedConfig
from vllm.config import VllmConfig
from vllm.model_executor.layers.logits_processor import LogitsProcessor
@@ -14,11 +19,567 @@ from vllm.model_executor.models.bart import (BartDecoder, BartEncoder,
BartParallelLMHead,
BartScaledWordEmbedding)
from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.multimodal import MULTIMODAL_REGISTRY, NestedTensors
+from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
+from vllm.multimodal.parse import MultiModalDataDict, MultiModalDataItems
+from vllm.multimodal.processing import (BaseProcessingInfo,
+ EncDecMultiModalProcessor,
+ PromptReplacement,
+ PromptReplacementDetails)
+from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
-from .utils import AutoWeightsLoader
+from .interfaces import SupportsMultiModal
+from .utils import AutoWeightsLoader, flatten_bn, merge_multimodal_embeddings
+class Florence2ImagePixelInputs(TypedDict):
+ type: Literal["pixel_values"]
+ data: torch.Tensor
+ """Shape: (batch_size, num_channel, height, width)"""
+
+
+# ViT implementation are all copied from
+# https://huggingface.co/microsoft/Florence-2-base/blob/main/modeling_florence2.py
+class LearnedAbsolutePositionEmbedding2D(nn.Module):
+ """
+ This module learns positional embeddings up to a fixed maximum size.
+ """
+
+ def __init__(self, embedding_dim=256, num_pos=50):
+ super().__init__()
+ self.row_embeddings = nn.Embedding(num_pos, embedding_dim // 2)
+ self.column_embeddings = nn.Embedding(
+ num_pos, embedding_dim - (embedding_dim // 2))
+
+ def forward(self, pixel_values):
+ """
+ pixel_values: (batch_size, height, width, num_channels)
+ returns: (batch_size, height, width, embedding_dim * 2)
+ """
+ if len(pixel_values.shape) != 4:
+ raise ValueError('pixel_values must be a 4D tensor')
+ height, width = pixel_values.shape[1:3]
+ width_values = torch.arange(width, device=pixel_values.device)
+ height_values = torch.arange(height, device=pixel_values.device)
+ x_emb = self.column_embeddings(width_values)
+ y_emb = self.row_embeddings(height_values)
+ # (height, width, embedding_dim * 2)
+ pos = torch.cat([
+ x_emb.unsqueeze(0).repeat(height, 1, 1),
+ y_emb.unsqueeze(1).repeat(1, width, 1)
+ ],
+ dim=-1)
+ # (embedding_dim * 2, height, width)
+ pos = pos.permute(2, 0, 1)
+ pos = pos.unsqueeze(0)
+ # (batch_size, embedding_dim * 2, height, width)
+ pos = pos.repeat(pixel_values.shape[0], 1, 1, 1)
+ # (batch_size, height, width, embedding_dim * 2)
+ pos = pos.permute(0, 2, 3, 1)
+ return pos
+
+
+class PositionalEmbeddingCosine1D(nn.Module):
+ """
+ This class implements a very simple positional encoding. It follows closely
+ the encoder from the link below:
+ https://pytorch.org/tutorials/beginner/translation_transformer.html
+ Args:
+ embed_dim: The dimension of the embeddings.
+ dropout_prob: The dropout probability.
+ max_seq_len: The maximum length to precompute the positional encodings.
+ """
+
+ def __init__(self, embed_dim: int = 512, max_seq_len: int = 1024) -> None:
+ super().__init__()
+ self.embed_dim = embed_dim
+ self.max_seq_len = max_seq_len
+ # Generate the sinusoidal arrays.
+ factor = math.log(10000)
+ denominator = torch.exp(-factor * torch.arange(0, self.embed_dim, 2) /
+ self.embed_dim)
+ # Matrix where rows correspond to a positional embedding as a function
+ # of the position index (i.e., the row index).
+ frequencies = \
+ torch.arange(0, self.max_seq_len) \
+ .reshape(self.max_seq_len, 1) * denominator
+ pos_idx_to_embed = torch.zeros((self.max_seq_len, self.embed_dim))
+ # Populate uneven entries.
+ pos_idx_to_embed[:, 0::2] = torch.sin(frequencies)
+ pos_idx_to_embed[:, 1::2] = torch.cos(frequencies)
+ # Save the positional embeddings in a constant buffer.
+ # self.register_buffer("pos_idx_to_embed", pos_idx_to_embed)
+ self.pos_idx_to_embed = nn.Parameter(pos_idx_to_embed,
+ requires_grad=False)
+
+ def forward(self, seq_embeds: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ seq_embeds: The sequence embeddings in order. Allowed size:
+ 1. [T, D], where T is the length of the sequence, and D is the
+ frame embedding dimension.
+ 2. [B, T, D], where B is the batch size and T and D are the
+ same as above.
+ Returns a tensor of with the same dimensions as the input: i.e.,
+ [1, T, D] or [T, D].
+ """
+ shape_len = len(seq_embeds.shape)
+ assert 2 <= shape_len <= 3
+ len_seq = seq_embeds.size(-2)
+ assert len_seq <= self.max_seq_len
+ pos_embeds = self.pos_idx_to_embed[0:seq_embeds.size(-2), :]
+ # Adapt pre-computed positional embeddings to the input.
+ if shape_len == 3:
+ pos_embeds = pos_embeds.view(
+ (1, pos_embeds.size(0), pos_embeds.size(1)))
+ return pos_embeds
+
+
+class MySequential(nn.Sequential):
+
+ def forward(self, *inputs):
+ for module in self._modules.values():
+ if isinstance(inputs, tuple):
+ inputs = module(*inputs)
+ else:
+ inputs = module(inputs)
+ return inputs
+
+
+class PreNorm(nn.Module):
+
+ def __init__(self, norm, fn):
+ super().__init__()
+ self.norm = norm
+ self.fn = fn
+
+ def forward(self, x, *args, **kwargs):
+ shortcut = x
+ if self.norm is not None:
+ x, size = self.fn(self.norm(x), *args, **kwargs)
+ else:
+ x, size = self.fn(x, *args, **kwargs)
+
+ x = shortcut + x
+
+ return x, size
+
+
+class Mlp(nn.Module):
+
+ def __init__(
+ self,
+ in_features,
+ hidden_features=None,
+ out_features=None,
+ act_layer=nn.GELU,
+ ):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.net = nn.Sequential(
+ OrderedDict([("fc1", nn.Linear(in_features, hidden_features)),
+ ("act", act_layer()),
+ ("fc2", nn.Linear(hidden_features, out_features))]))
+
+ def forward(self, x, size):
+ return self.net(x), size
+
+
+class DepthWiseConv2d(nn.Module):
+
+ def __init__(
+ self,
+ dim_in,
+ kernel_size,
+ padding,
+ stride,
+ bias=True,
+ ):
+ super().__init__()
+ self.dw = nn.Conv2d(dim_in,
+ dim_in,
+ kernel_size=kernel_size,
+ padding=padding,
+ groups=dim_in,
+ stride=stride,
+ bias=bias)
+
+ def forward(self, x, size):
+ B, N, C = x.shape
+ H, W = size
+ assert N == H * W
+
+ x = self.dw(x.transpose(1, 2).view(B, C, H, W))
+ size = (x.size(-2), x.size(-1))
+ x = x.flatten(2).transpose(1, 2)
+ return x, size
+
+
+class ConvEmbed(nn.Module):
+ """ Image to Patch Embedding
+ """
+
+ def __init__(self,
+ patch_size=7,
+ in_chans=3,
+ embed_dim=64,
+ stride=4,
+ padding=2,
+ norm_layer=None,
+ pre_norm=True):
+ super().__init__()
+ self.patch_size = patch_size
+
+ self.proj = nn.Conv2d(in_chans,
+ embed_dim,
+ kernel_size=patch_size,
+ stride=stride,
+ padding=padding)
+
+ dim_norm = in_chans if pre_norm else embed_dim
+ self.norm = norm_layer(dim_norm) if norm_layer else None
+
+ self.pre_norm = pre_norm
+
+ def forward(self, x, size):
+ H, W = size
+ if len(x.size()) == 3:
+ if self.norm and self.pre_norm:
+ x = self.norm(x)
+ x = rearrange(x, 'b (h w) c -> b c h w', h=H, w=W)
+
+ x = self.proj(x)
+
+ _, _, H, W = x.shape
+ x = rearrange(x, 'b c h w -> b (h w) c')
+ if self.norm and not self.pre_norm:
+ x = self.norm(x)
+
+ return x, (H, W)
+
+
+class ChannelAttention(nn.Module):
+
+ def __init__(self, dim, groups=8, qkv_bias=True):
+ super().__init__()
+
+ self.groups = groups
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.proj = nn.Linear(dim, dim)
+
+ def forward(self, x, size):
+ B, N, C = x.shape
+
+ qkv = self.qkv(x).reshape(B, N, 3, self.groups,
+ C // self.groups).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv[0], qkv[1], qkv[2]
+
+ q = q * (float(N)**-0.5)
+ attention = q.transpose(-1, -2) @ k
+ attention = attention.softmax(dim=-1)
+ x = (attention @ v.transpose(-1, -2)).transpose(-1, -2)
+ x = x.transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ return x, size
+
+
+class ChannelBlock(nn.Module):
+
+ def __init__(self,
+ dim,
+ groups,
+ mlp_ratio=4.,
+ qkv_bias=True,
+ drop_path_rate=0.,
+ act_layer=nn.GELU,
+ norm_layer=nn.LayerNorm,
+ conv_at_attn=True,
+ conv_at_ffn=True):
+ super().__init__()
+
+ self.conv1 = PreNorm(None, DepthWiseConv2d(
+ dim, 3, 1, 1)) if conv_at_attn else None
+ self.channel_attn = PreNorm(
+ norm_layer(dim),
+ ChannelAttention(dim, groups=groups, qkv_bias=qkv_bias),
+ )
+ self.conv2 = PreNorm(None, DepthWiseConv2d(dim, 3, 1,
+ 1)) if conv_at_ffn else None
+ self.ffn = PreNorm(
+ norm_layer(dim),
+ Mlp(in_features=dim,
+ hidden_features=int(dim * mlp_ratio),
+ act_layer=act_layer),
+ )
+
+ def forward(self, x, size):
+ if self.conv1:
+ x, size = self.conv1(x, size)
+ x, size = self.channel_attn(x, size)
+
+ if self.conv2:
+ x, size = self.conv2(x, size)
+ x, size = self.ffn(x, size)
+
+ return x, size
+
+
+def window_partition(x, window_size: int):
+ B, H, W, C = x.shape
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size,
+ C)
+ windows = x.permute(0, 1, 3, 2, 4,
+ 5).contiguous().view(-1, window_size, window_size, C)
+ return windows
+
+
+def window_reverse(windows, batch_size: int, window_size: int, H: int, W: int):
+ B = batch_size
+
+ x = windows.view(B, H // window_size, W // window_size, window_size,
+ window_size, -1)
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
+ return x
+
+
+class WindowAttention(nn.Module):
+
+ def __init__(self, dim, num_heads, window_size, qkv_bias=True):
+
+ super().__init__()
+ self.dim = dim
+ self.window_size = window_size
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = float(head_dim)**-0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.proj = nn.Linear(dim, dim)
+
+ self.softmax = nn.Softmax(dim=-1)
+
+ def forward(self, x, size):
+
+ H, W = size
+ B, L, C = x.shape
+ assert L == H * W, "input feature has wrong size"
+
+ x = x.view(B, H, W, C)
+
+ pad_l = pad_t = 0
+ pad_r = (self.window_size - W % self.window_size) % self.window_size
+ pad_b = (self.window_size - H % self.window_size) % self.window_size
+ x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
+ _, Hp, Wp, _ = x.shape
+
+ x = window_partition(x, self.window_size)
+ x = x.view(-1, self.window_size * self.window_size, C)
+
+ # W-MSA/SW-MSA
+ # attn_windows = self.attn(x_windows)
+
+ B_, N, C = x.shape
+ qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads,
+ C // self.num_heads).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv[0], qkv[1], qkv[2]
+
+ q = q * self.scale
+ attn = (q @ k.transpose(-2, -1))
+ attn = self.softmax(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
+ x = self.proj(x)
+
+ # merge windows
+ x = x.view(-1, self.window_size, self.window_size, C)
+ x = window_reverse(x, B, self.window_size, Hp, Wp)
+
+ if pad_r > 0 or pad_b > 0:
+ x = x[:, :H, :W, :].contiguous()
+
+ x = x.view(B, H * W, C)
+
+ return x, size
+
+
+class SpatialBlock(nn.Module):
+
+ def __init__(self,
+ dim,
+ num_heads,
+ window_size,
+ mlp_ratio=4.,
+ qkv_bias=True,
+ drop_path_rate=0.,
+ act_layer=nn.GELU,
+ norm_layer=nn.LayerNorm,
+ conv_at_attn=True,
+ conv_at_ffn=True):
+ super().__init__()
+
+ self.conv1 = PreNorm(None, DepthWiseConv2d(
+ dim, 3, 1, 1)) if conv_at_attn else None
+ self.window_attn = PreNorm(
+ norm_layer(dim),
+ WindowAttention(dim, num_heads, window_size, qkv_bias=qkv_bias),
+ )
+ self.conv2 = PreNorm(None, DepthWiseConv2d(dim, 3, 1,
+ 1)) if conv_at_ffn else None
+ self.ffn = PreNorm(
+ norm_layer(dim),
+ Mlp(in_features=dim,
+ hidden_features=int(dim * mlp_ratio),
+ act_layer=act_layer),
+ )
+
+ def forward(self, x, size):
+ if self.conv1:
+ x, size = self.conv1(x, size)
+ x, size = self.window_attn(x, size)
+
+ if self.conv2:
+ x, size = self.conv2(x, size)
+ x, size = self.ffn(x, size)
+ return x, size
+
+
+class DaViT(nn.Module):
+
+ def __init__(
+ self,
+ in_chans=3,
+ num_classes=1000,
+ depths=(1, 1, 3, 1),
+ patch_size=(7, 2, 2, 2),
+ patch_stride=(4, 2, 2, 2),
+ patch_padding=(3, 0, 0, 0),
+ patch_prenorm=(False, False, False, False),
+ embed_dims=(64, 128, 192, 256),
+ num_heads=(3, 6, 12, 24),
+ num_groups=(3, 6, 12, 24),
+ window_size=7,
+ mlp_ratio=4.,
+ qkv_bias=True,
+ drop_path_rate=0.1,
+ norm_layer=nn.LayerNorm,
+ enable_checkpoint=False,
+ conv_at_attn=True,
+ conv_at_ffn=True,
+ ):
+ super().__init__()
+
+ self.num_classes = num_classes
+ self.embed_dims = embed_dims
+ self.num_heads = num_heads
+ self.num_groups = num_groups
+ self.num_stages = len(self.embed_dims)
+ self.enable_checkpoint = enable_checkpoint
+ assert self.num_stages == len(self.num_heads) == len(self.num_groups)
+
+ num_stages = len(embed_dims)
+ dpr = [
+ x.item() for x in torch.linspace(0, drop_path_rate,
+ sum(depths) * 2)
+ ]
+
+ depth_offset = 0
+ convs = []
+ blocks = []
+ for i in range(num_stages):
+ conv_embed = ConvEmbed(
+ patch_size=patch_size[i],
+ stride=patch_stride[i],
+ padding=patch_padding[i],
+ in_chans=in_chans if i == 0 else self.embed_dims[i - 1],
+ embed_dim=self.embed_dims[i],
+ norm_layer=norm_layer,
+ pre_norm=patch_prenorm[i])
+ convs.append(conv_embed)
+
+ block = MySequential(*[
+ MySequential(
+ OrderedDict([('spatial_block',
+ SpatialBlock(
+ embed_dims[i],
+ num_heads[i],
+ window_size,
+ drop_path_rate=dpr[depth_offset + j * 2],
+ qkv_bias=qkv_bias,
+ mlp_ratio=mlp_ratio,
+ conv_at_attn=conv_at_attn,
+ conv_at_ffn=conv_at_ffn,
+ )),
+ ('channel_block',
+ ChannelBlock(
+ embed_dims[i],
+ num_groups[i],
+ drop_path_rate=dpr[depth_offset + j * 2 +
+ 1],
+ qkv_bias=qkv_bias,
+ mlp_ratio=mlp_ratio,
+ conv_at_attn=conv_at_attn,
+ conv_at_ffn=conv_at_ffn,
+ ))])) for j in range(depths[i])
+ ])
+ blocks.append(block)
+ depth_offset += depths[i] * 2
+
+ self.convs = nn.ModuleList(convs)
+ self.blocks = nn.ModuleList(blocks)
+
+ self.avgpool = nn.AdaptiveAvgPool1d(1)
+
+ @property
+ def dim_out(self):
+ return self.embed_dims[-1]
+
+ def forward_features_unpool(self, x):
+ """
+ forward until avg pooling
+ Args:
+ x (_type_): input image tensor
+ """
+ input_size = (x.size(2), x.size(3))
+ for conv, block in zip(self.convs, self.blocks):
+ x, input_size = conv(x, input_size)
+ x, input_size = block(x, input_size)
+ return x
+
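+    # NOTE: forward_features/forward below mirror the reference DaViT
+    # classification head and reference self.norms/self.head, which are not
+    # defined in this port; vLLM only calls forward_features_unpool.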
+ def forward_features(self, x):
+ x = self.forward_features_unpool(x)
+
+ # (batch_size, num_tokens, token_dim)
+ x = self.avgpool(x.transpose(1, 2))
+ # (batch_size, 1, num_tokens)
+ x = torch.flatten(x, 1)
+ x = self.norms(x)
+
+ return x
+
+ def forward(self, x):
+ x = self.forward_features(x)
+ x = self.head(x)
+ return x
+
+ @classmethod
+ def from_config(cls, config):
+ return cls(
+ depths=config.depths,
+ embed_dims=config.dim_embed,
+ num_heads=config.num_heads,
+ num_groups=config.num_groups,
+ patch_size=config.patch_size,
+ patch_stride=config.patch_stride,
+ patch_padding=config.patch_padding,
+ patch_prenorm=config.patch_prenorm,
+ drop_path_rate=config.drop_path_rate,
+ window_size=config.window_size,
+ )
+
+
+# Language backbone and processor implementation
class Florence2LanguageModel(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
@@ -47,9 +608,14 @@ class Florence2LanguageModel(nn.Module):
self.encoder.embed_tokens.weight = self.shared.weight
self.decoder.embed_tokens.weight = self.shared.weight
- def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
- encoder_input_ids: torch.Tensor,
- encoder_positions: torch.Tensor) -> torch.Tensor:
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ encoder_input_ids: torch.Tensor,
+ encoder_positions: torch.Tensor,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
r"""
Args:
input_ids
@@ -68,11 +634,12 @@ class Florence2LanguageModel(nn.Module):
encoder_hidden_states = None
- if encoder_input_ids.numel() > 0:
+ if inputs_embeds is not None or encoder_input_ids.numel() > 0:
# Run encoder attention if a non-zero number of encoder tokens
# are provided as input
encoder_hidden_states = self.encoder(input_ids=encoder_input_ids,
- positions=encoder_positions)
+ positions=encoder_positions,
+ inputs_embeds=inputs_embeds)
# decoder outputs consists of
# (dec_features, past_key_value, dec_hidden, dec_attn)
@@ -112,6 +679,7 @@ class Florence2LanguageForConditionalGeneration(nn.Module):
positions: torch.Tensor,
encoder_input_ids: torch.Tensor,
encoder_positions: torch.Tensor,
+ inputs_embeds: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
r"""
@@ -127,8 +695,15 @@ class Florence2LanguageForConditionalGeneration(nn.Module):
Returns:
Output torch.Tensor
"""
- return self.model(input_ids, positions, encoder_input_ids,
- encoder_positions)
+
+ return self.model(input_ids,
+ positions,
+ encoder_input_ids,
+ encoder_positions,
+ inputs_embeds=inputs_embeds)
+
+ def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+ return self.model.encoder.embed_tokens(input_ids)
def compute_logits(
self,
@@ -177,21 +752,312 @@ class Florence2LanguageForConditionalGeneration(nn.Module):
return loaded_params
-class Florence2ForConditionalGeneration(nn.Module):
+class Florence2ProcessingInfo(BaseProcessingInfo):
+
+ def get_hf_config(self):
+ return self.ctx.get_hf_config()
+
+ def get_hf_processor(self):
+ return self.ctx.get_hf_processor()
+
+ def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
+ return {"image": 1}
+
+ def get_max_image_tokens(self) -> int:
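+        # The HF image processor reports a fixed number of visual tokens per
+        # image via image_seq_length.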
+ processor_config = self.ctx.get_hf_image_processor_config()
+ return processor_config["image_seq_length"]
+
+ def get_mm_max_tokens_per_item(
+ self,
+ seq_len: int,
+ mm_counts: Mapping[str, int],
+ ) -> Mapping[str, int]:
+ return {"image": self.get_max_image_tokens()}
+
+
+class Florence2DummyInputsBuilder(
+ BaseDummyInputsBuilder[Florence2ProcessingInfo]):
+
+ def get_dummy_processor_inputs(
+ self,
+ seq_len: int,
+ mm_counts: Mapping[str, int],
+ ) -> ProcessorInputs:
+ num_images = mm_counts.get("image", 0)
+
+ target_width = target_height = self.info.get_hf_config().projection_dim
+
+ mm_data = {
+ "image":
+ self._get_dummy_images(width=target_width,
+ height=target_height,
+ num_images=num_images)
+ }
+
+ return ProcessorInputs(
+ prompt_text="",
+ mm_data=mm_data,
+ )
+
+
+class Florence2MultiModalProcessor(
+ EncDecMultiModalProcessor[Florence2ProcessingInfo]):
+
+ def _hf_processor_applies_repl(
+ self,
+ prompt_text: str,
+ mm_items: MultiModalDataItems,
+ hf_processor_mm_kwargs: Mapping[str, object],
+ ) -> bool:
+ return False
+
+ def create_encoder_prompt(
+ self,
+ prompt: Union[str, list[int]],
+ mm_data: MultiModalDataDict,
+ ) -> Union[str, list[int]]:
+ return prompt
+
+ def create_decoder_prompt(
+ self,
+ prompt: Union[str, list[int]],
+ mm_data: MultiModalDataDict,
+ ) -> Union[str, list[int]]:
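+        # Florence-2's BART-style decoder starts generation from the EOS token
+        # (decoder_start_token_id == eos_token_id).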
+ return [self.info.get_hf_config().eos_token_id]
+
+ def _call_hf_processor(
+ self,
+ prompt: str,
+ mm_data: Mapping[str, object],
+ mm_kwargs: Mapping[str, object],
+ ) -> BatchFeature:
+ if mm_data:
+ processed_outputs = super()._call_hf_processor(
+ prompt, mm_data, mm_kwargs)
+ else:
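+            # Text-only inputs bypass the HF image processor, so expand task
+            # tokens into their full text prompts before tokenizing.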
+ hf_processor = self.info.get_hf_processor()
+ tokenizer = hf_processor.tokenizer
+ prompt = hf_processor._construct_prompts([prompt])[0]
+ processed_outputs = tokenizer(prompt,
+ add_special_tokens=True,
+ return_tensors="pt")
+ return processed_outputs
+
+ def _get_mm_fields_config(
+ self,
+ hf_inputs: BatchFeature,
+ hf_processor_mm_kwargs: Mapping[str, object],
+ ) -> Mapping[str, MultiModalFieldConfig]:
+ return dict(pixel_values=MultiModalFieldConfig.batched("image"))
+
+ def _get_prompt_replacements(
+ self,
+ mm_items: MultiModalDataItems,
+ hf_processor_mm_kwargs: Mapping[str, object],
+ out_mm_kwargs: MultiModalKwargs,
+ ) -> list[PromptReplacement]:
+ hf_config = self.info.get_hf_config()
+ pad_token_id = hf_config.pad_token_id
+ bos_token_id = hf_config.bos_token_id
+ num_image_tokens = self.info.get_max_image_tokens()
+ image_tokens = [pad_token_id] * num_image_tokens
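+        # Prepend pad-token placeholders before BOS; their positions are later
+        # filled with the projected image features.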
+
+ return [
+ PromptReplacement(
+ modality="image",
+ target=[bos_token_id],
+ replacement=PromptReplacementDetails(
+ full=image_tokens + [bos_token_id],
+ features=image_tokens,
+ ),
+ )
+ ]
+
+
+@MULTIMODAL_REGISTRY.register_processor(
+ Florence2MultiModalProcessor,
+ info=Florence2ProcessingInfo,
+ dummy_inputs=Florence2DummyInputsBuilder)
+class Florence2ForConditionalGeneration(nn.Module, SupportsMultiModal):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
+ processor_config = vllm_config.model_config.hf_image_processor_config
- # TODO(Isotr0py): Add vision backbone
+ self.config = config
+ self.vision_config = config.vision_config
+ self.processor_config = processor_config
+ assert config.vision_config.model_type == 'davit', (
+ 'only DaViT is supported for now')
+ self.vision_tower = DaViT.from_config(config=config.vision_config)
+ self._build_image_projection_layers(config)
self.language_model = Florence2LanguageForConditionalGeneration(
vllm_config=vllm_config.with_hf_config(config.text_config),
prefix=f"{prefix}.language_model",
)
+ self.pad_token_id = config.pad_token_id
- @property
+ def _build_image_projection_layers(self, config: PretrainedConfig):
+ image_dim_out = config.vision_config.dim_embed[-1]
+ dim_projection = config.vision_config.projection_dim
+ self.image_projection = nn.Parameter(
+ torch.empty(image_dim_out, dim_projection))
+ self.image_proj_norm = nn.LayerNorm(dim_projection)
+ image_pos_embed_config = config.vision_config.image_pos_embed
+ if image_pos_embed_config['type'] == 'learned_abs_2d':
+ self.image_pos_embed = LearnedAbsolutePositionEmbedding2D(
+ embedding_dim=image_dim_out,
+ num_pos=image_pos_embed_config['max_pos_embeddings'])
+ else:
+ raise NotImplementedError("Florence2 only supports learned_abs_2d "
+ "as image position embedding.")
+
+ self.image_feature_source = config.vision_config.image_feature_source
+
+ # temporal embedding
+ visual_temporal_embedding_config = (
+ self.vision_config.visual_temporal_embedding)
+ if visual_temporal_embedding_config['type'] == 'COSINE':
+ self.visual_temporal_embed = PositionalEmbeddingCosine1D(
+ embed_dim=image_dim_out,
+ max_seq_len=visual_temporal_embedding_config[
+ 'max_temporal_embeddings'])
+ else:
+ raise NotImplementedError(
+ 'Florence2 only supports COSINE as temporal embedding.')
+
+ @cached_property
def sampler(self):
- return self.language_model.sampler
+ if hasattr(self.language_model, "sampler"):
+ return self.language_model.sampler
+ return get_sampler()
+
+ def _validate_pixel_values(
+ self, data: Union[torch.Tensor, List[torch.Tensor]]
+ ) -> Union[torch.Tensor, List[torch.Tensor]]:
+
+ size = self.processor_config["size"]
+ h, w = size["height"], size["width"]
+ expected_dims = (3, h, w)
+
+ def _validate_shape(d: torch.Tensor):
+ actual_dims = tuple(d.shape)
+
+ if actual_dims != expected_dims:
+                expected_expr = str(expected_dims)
+ raise ValueError(
+ "The expected shape of pixel values per batch "
+ f"is {expected_expr}. You supplied {tuple(d.shape)}.")
+
+ for d in data:
+ _validate_shape(d)
+
+ return data
+
+ def _parse_and_validate_image_input(self, **kwargs: object):
+ pixel_values: Optional[Union[List[List[torch.Tensor]],
+ List[torch.Tensor],
+ torch.Tensor]] = kwargs.pop(
+ "pixel_values", None)
+ image_embeds: Optional[Union[List[List[torch.Tensor]],
+ List[torch.Tensor],
+ torch.Tensor]] = kwargs.pop(
+ "image_embeds", None)
+
+ if pixel_values is None and image_embeds is None:
+ return None
+
+ if pixel_values is not None and image_embeds is not None:
+ raise ValueError(
+ "Both pixel values and image embeds are provided.")
+
+ if pixel_values is not None:
+ return Florence2ImagePixelInputs(
+ type="pixel_values",
+ data=self._validate_pixel_values(
+ flatten_bn(pixel_values, concat=True)),
+ )
+
+ if image_embeds is not None:
+ raise NotImplementedError
+
+ raise AssertionError("This line should be unreachable.")
+
+ def _encode_image(self, pixel_values: torch.Tensor) -> torch.Tensor:
+ dtype = next(self.vision_tower.parameters()).dtype
+ pixel_values = pixel_values.to(dtype)
+
+ batch_size, T = pixel_values.size(0), 1
+ x = self.vision_tower.forward_features_unpool(pixel_values)
+ if self.image_pos_embed is not None:
+ x = x.view(batch_size * T, -1, x.shape[-1])
+ num_tokens = x.shape[-2]
+ h, w = int(num_tokens**0.5), int(num_tokens**0.5)
+ assert h * w == num_tokens, (
+ 'only support square feature maps for now')
+ x = x.view(batch_size * T, h, w, x.shape[-1])
+ pos_embed = self.image_pos_embed(x)
+ x = x + pos_embed
+ x = x.view(batch_size, T * h * w, x.shape[-1])
+
+ if self.visual_temporal_embed is not None:
+ visual_temporal_embed = self.visual_temporal_embed(
+ x.view(batch_size, T, -1, x.shape[-1])[:, :, 0])
+ x = x.view(batch_size, T, -1,
+ x.shape[-1]) + visual_temporal_embed.view(
+ 1, T, 1, x.shape[-1])
+
+ x_feat_dict = {}
+
+ spatial_avg_pool_x = x.view(batch_size, T, -1, x.shape[-1]).mean(dim=2)
+ x_feat_dict['spatial_avg_pool'] = spatial_avg_pool_x
+
+ temporal_avg_pool_x = x.view(batch_size, T, -1,
+ x.shape[-1]).mean(dim=1)
+ x_feat_dict['temporal_avg_pool'] = temporal_avg_pool_x
+
+ x = x.view(batch_size, T, -1, x.shape[-1])[:, -1]
+ x_feat_dict['last_frame'] = x
+
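+        # Gather the feature sources configured in
+        # vision_config.image_feature_source and concatenate them.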
+ new_x = []
+ for _image_feature_source in self.image_feature_source:
+ if _image_feature_source not in x_feat_dict:
+ raise ValueError('invalid image feature source: {}'.format(
+ _image_feature_source))
+ new_x.append(x_feat_dict[_image_feature_source])
+
+ x = torch.cat(new_x, dim=1)
+
+ x = x @ self.image_projection
+ x = self.image_proj_norm(x)
+
+ return x
+
+ def _process_image_input(
+ self, image_input: Florence2ImagePixelInputs) -> torch.Tensor:
+ assert image_input["type"] == "pixel_values"
+ pixel_values = image_input["data"]
+ return self._encode_image(pixel_values)
+
+ def get_multimodal_embeddings(self, **kwargs: object) -> torch.Tensor:
+ image_input = self._parse_and_validate_image_input(**kwargs)
+ if image_input is None:
+ return None
+ vision_embeddings = self._process_image_input(image_input)
+ return vision_embeddings
+
+ def get_input_embeddings(
+ self,
+ input_ids: torch.Tensor,
+ multimodal_embeddings: Optional[NestedTensors] = None,
+ ) -> torch.Tensor:
+ inputs_embeds = self.language_model.get_input_embeddings(input_ids)
+ if multimodal_embeddings is not None:
+ inputs_embeds = merge_multimodal_embeddings(
+ input_ids, inputs_embeds, multimodal_embeddings,
+ self.pad_token_id)
+ return inputs_embeds
def forward(
self,
@@ -216,8 +1082,19 @@ class Florence2ForConditionalGeneration(nn.Module):
Returns:
Output torch.Tensor
"""
- return self.language_model(input_ids, positions, encoder_input_ids,
- encoder_positions)
+ vision_embeddings = self.get_multimodal_embeddings(**kwargs)
+ if encoder_input_ids.numel() > 0 or vision_embeddings is not None:
+ inputs_embeds = self.get_input_embeddings(encoder_input_ids,
+ vision_embeddings)
+ else:
+ inputs_embeds = None
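+        # During decode steps there are no new encoder tokens or images, so no
+        # encoder embeddings are built.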
+
+ hidden_states = self.language_model(input_ids,
+ positions,
+ encoder_input_ids,
+ encoder_positions,
+ inputs_embeds=inputs_embeds)
+ return hidden_states
def compute_logits(
self,
@@ -236,9 +1113,5 @@ class Florence2ForConditionalGeneration(nn.Module):
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
- skip_prefixes = [
- 'image_projection', "vision_tower", "image_proj_norm",
- "image_pos_embed", "visual_temporal_embed"
- ]
- loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
+ loader = AutoWeightsLoader(self)
return loader.load_weights(weights)
diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
index 58155905a7b7..75e31d557dd1 100644
--- a/vllm/model_executor/models/registry.py
+++ b/vllm/model_executor/models/registry.py
@@ -105,7 +105,6 @@ _TEXT_GENERATION_MODELS = {
# [Encoder-decoder]
"BartModel": ("bart", "BartForConditionalGeneration"),
"BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"),
- "Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"), # noqa: E501
}
_EMBEDDING_MODELS = {
@@ -182,6 +181,7 @@ _MULTIMODAL_MODELS = {
"Qwen2AudioForConditionalGeneration": ("qwen2_audio", "Qwen2AudioForConditionalGeneration"), # noqa: E501
"UltravoxModel": ("ultravox", "UltravoxModel"),
# [Encoder-decoder]
+ "Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"), # noqa: E501
"MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"), # noqa: E501
"WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"), # noqa: E501
}
diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py
index 93756364dea1..60b000e2b34f 100644
--- a/vllm/multimodal/processing.py
+++ b/vllm/multimodal/processing.py
@@ -1303,6 +1303,14 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
"""
raise NotImplementedError
+ def create_decoder_prompt(
+ self,
+ prompt: Union[str, list[int]],
+ mm_data: MultiModalDataDict,
+ ) -> Union[str, list[int]]:
+ """Create input prompt for the decoder."""
+ return prompt
+
def apply(
self,
prompt: Union[str, list[int]],
@@ -1323,17 +1331,15 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
hf_processor_mm_kwargs,
)
- # We assumed the decoder prompt text is copied from
- # the original encoder prompt without extra process
tokenizer = self.info.get_tokenizer()
- if isinstance(prompt, str):
- decoder_prompt = prompt
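+        # By default the decoder prompt mirrors the encoder prompt; processors
+        # such as Florence-2's override create_decoder_prompt to change this.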
+ decoder_prompt = self.create_decoder_prompt(prompt, mm_data)
+ if isinstance(decoder_prompt, str):
decoder_prompt_ids = encode_tokens(tokenizer,
- prompt,
+ decoder_prompt,
add_special_tokens=False)
else:
- decoder_prompt = decode_tokens(tokenizer, prompt)
- decoder_prompt_ids = prompt
+ decoder_prompt_ids = decoder_prompt
+ decoder_prompt = decode_tokens(tokenizer, decoder_prompt)
mm_inputs = MultiModalEncDecInputs(
encoder_prompt=encoder_inputs["prompt"],
diff --git a/vllm/multimodal/profiling.py b/vllm/multimodal/profiling.py
index 093f8b7a8179..3178b0f8c3e6 100644
--- a/vllm/multimodal/profiling.py
+++ b/vllm/multimodal/profiling.py
@@ -204,9 +204,11 @@ class MultiModalProfiler(Generic[_I]):
"and/or reduce `mm_counts`.", seq_len, total_len,
total_placeholders_by_modality)
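+        # Pad the dummy prompt with zeros up to seq_len so the profiled
+        # sequence keeps its full length.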
+ num_tokens_to_pad = max(total_len, seq_len) - total_len
+ prompt_token_ids.extend([0] * num_tokens_to_pad)
+
return DummyData(
- seq_data=SequenceData.from_prompt_token_counts(
- (0, max(seq_len, total_len))),
+ seq_data=SequenceData.from_seqs(prompt_token_ids),
multi_modal_data=None,
multi_modal_placeholders=None,
)