From edf309ebbe25fcf55569c7fe94e3fe78428a8244 Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Thu, 27 Feb 2025 18:06:41 +0800 Subject: [PATCH] [VLM] Support multimodal inputs for Florence-2 models (#13320) --- docs/source/models/supported_models.md | 7 + .../offline_inference/florence2_inference.py | 39 +- examples/offline_inference/vision_language.py | 17 + tests/conftest.py | 6 +- .../audio_language/test_ultravox.py | 4 +- .../vision_language/test_florence2.py | 133 ++- .../multimodal/processing/test_common.py | 5 +- tests/models/registry.py | 10 +- vllm/model_executor/models/bart.py | 27 +- vllm/model_executor/models/florence2.py | 913 +++++++++++++++++- vllm/model_executor/models/registry.py | 2 +- vllm/multimodal/processing.py | 20 +- vllm/multimodal/profiling.py | 6 +- 13 files changed, 1075 insertions(+), 114 deletions(-) diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md index 9959f7233e86..4b1f3e180ed5 100644 --- a/docs/source/models/supported_models.md +++ b/docs/source/models/supported_models.md @@ -715,6 +715,13 @@ See [this page](#generative-models) for more information on how to use generativ * * ✅︎ * ✅︎ +- * `Florence2ForConditionalGeneration` + * Florence-2 + * T + I + * `microsoft/Florence-2-base`, `microsoft/Florence-2-large` etc. + * + * + * - * `FuyuForCausalLM` * Fuyu * T + I diff --git a/examples/offline_inference/florence2_inference.py b/examples/offline_inference/florence2_inference.py index 58610b0fd2a5..27aceee43cbf 100644 --- a/examples/offline_inference/florence2_inference.py +++ b/examples/offline_inference/florence2_inference.py @@ -1,34 +1,45 @@ # SPDX-License-Identifier: Apache-2.0 -''' +""" Demonstrate prompting of text-to-text encoder/decoder models, specifically Florence-2 -''' +""" # TODO(Isotr0py): # Move to offline_inference/vision_language.py # after porting vision backbone from vllm import LLM, SamplingParams - -dtype = "float" +from vllm.assets.image import ImageAsset # Create a Florence-2 encoder/decoder model instance llm = LLM( - model="microsoft/Florence-2-base", - tokenizer="facebook/bart-base", - dtype=dtype, + model="microsoft/Florence-2-large", + tokenizer="facebook/bart-large", + max_num_seqs=8, trust_remote_code=True, ) prompts = [ - "", "", "", - "", "", "", - "", "", "" + { # implicit prompt with task token + "prompt": "", + "multi_modal_data": { + "image": ImageAsset("stop_sign").pil_image + }, + }, + { # explicit encoder/decoder prompt + "encoder_prompt": { + "prompt": "Describe in detail what is shown in the image.", + "multi_modal_data": { + "image": ImageAsset("cherry_blossom").pil_image + }, + }, + "decoder_prompt": "", + }, ] # Create a sampling params object. sampling_params = SamplingParams( temperature=0, top_p=1.0, min_tokens=0, - max_tokens=20, + max_tokens=128, ) # Generate output tokens from the prompts. The output is a list of @@ -38,9 +49,5 @@ outputs = llm.generate(prompts, sampling_params) # Print the outputs. 
for output in outputs: - prompt = output.prompt - encoder_prompt = output.encoder_prompt generated_text = output.outputs[0].text - print(f"Encoder prompt: {encoder_prompt!r}, " - f"Decoder prompt: {prompt!r}, " - f"Generated text: {generated_text!r}") + print(f"Generated text: {generated_text!r}") diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py index 5f05389faf80..e2ec36211b86 100644 --- a/examples/offline_inference/vision_language.py +++ b/examples/offline_inference/vision_language.py @@ -82,6 +82,22 @@ def run_deepseek_vl2(question: str, modality: str): return llm, prompt, stop_token_ids +# Florence2 +def run_florence2(question: str, modality: str): + assert modality == "image" + + llm = LLM(model="microsoft/Florence-2-large", + tokenizer="facebook/bart-large", + max_num_seqs=8, + trust_remote_code=True, + dtype="bfloat16", + disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache) + + prompt = "" + stop_token_ids = None + return llm, prompt, stop_token_ids + + # Fuyu def run_fuyu(question: str, modality: str): assert modality == "image" @@ -571,6 +587,7 @@ model_example_map = { "blip-2": run_blip2, "chameleon": run_chameleon, "deepseek_vl_v2": run_deepseek_vl2, + "florence2": run_florence2, "fuyu": run_fuyu, "glm4v": run_glm4v, "h2ovl_chat": run_h2ovl, diff --git a/tests/conftest.py b/tests/conftest.py index dd339030e5e4..871f0b62c532 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -600,8 +600,8 @@ class HfRunner: if images is not None and images[i] is not None: processor_kwargs["images"] = images[i] - encoder_input_ids = self.wrap_device( - self.processor(**processor_kwargs).input_ids, + encoder_inputs = self.wrap_device( + self.processor(**processor_kwargs), device=self.model.device.type, ) @@ -615,13 +615,13 @@ class HfRunner: ) output = self.model.generate( - encoder_input_ids, decoder_input_ids=decoder_input_ids, use_cache=True, do_sample=False, max_new_tokens=max_tokens, output_hidden_states=True, return_dict_in_generate=True, + **encoder_inputs, **kwargs, ) diff --git a/tests/models/decoder_only/audio_language/test_ultravox.py b/tests/models/decoder_only/audio_language/test_ultravox.py index d1f643a8fdb7..0ea17247028f 100644 --- a/tests/models/decoder_only/audio_language/test_ultravox.py +++ b/tests/models/decoder_only/audio_language/test_ultravox.py @@ -15,7 +15,7 @@ from ....conftest import HfRunner, VllmRunner from ....utils import RemoteOpenAIServer from ...utils import check_logprobs_close -MODEL_NAME = "fixie-ai/ultravox-v0_5-llama-3_2-1b" +MODEL_NAME = "fixie-ai/ultravox-v0_4" AudioTuple = Tuple[np.ndarray, int] @@ -187,7 +187,7 @@ def run_multi_audio_test( @pytest.mark.core_model -@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("dtype", ["bfloat16"]) @pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("num_logprobs", [5]) @pytest.mark.parametrize("vllm_kwargs", [ diff --git a/tests/models/encoder_decoder/vision_language/test_florence2.py b/tests/models/encoder_decoder/vision_language/test_florence2.py index a1d15679918b..de18deab11f6 100644 --- a/tests/models/encoder_decoder/vision_language/test_florence2.py +++ b/tests/models/encoder_decoder/vision_language/test_florence2.py @@ -1,52 +1,59 @@ # SPDX-License-Identifier: Apache-2.0 -from functools import partial -from typing import List, Optional, Tuple, Type +from typing import Optional, Type import pytest from PIL import Image -from vllm.inputs.data import ExplicitEncoderDecoderPrompt +from 
vllm.inputs.data import ExplicitEncoderDecoderPrompt, TextPrompt +from vllm.multimodal.image import rescale_image_size from vllm.sequence import SampleLogprobs -from ....conftest import HfRunner, VllmRunner +from ....conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets from ...utils import check_logprobs_close -Florence2Prompt = partial(ExplicitEncoderDecoderPrompt, - decoder_prompt=None, - mm_processor_kwargs=None) - MODELS = ["microsoft/Florence-2-base"] # Florence-2 uses BartFastTokenizer which can't be loaded from AutoTokenizer # Therefore, we borrow the BartTokenizer from the original Bart model TOKENIZER = "facebook/bart-base" -PROMPTS = [ - Florence2Prompt(encoder_prompt=""), - Florence2Prompt(encoder_prompt=""), - Florence2Prompt(encoder_prompt=""), - Florence2Prompt(encoder_prompt=""), - Florence2Prompt(encoder_prompt=""), - Florence2Prompt(encoder_prompt=""), - Florence2Prompt(encoder_prompt=""), - Florence2Prompt(encoder_prompt=""), - Florence2Prompt(encoder_prompt=""), -] +HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ + "stop_sign": + "", # special task token + "cherry_blossom": + "Describe in detail what is shown in the image.", +}) -def vllm_to_hf_output(vllm_output: Tuple[List[int], str, - Optional[SampleLogprobs]], ): - """Sanitize vllm output to be comparable with hf output.""" - output_ids, output_str, out_logprobs = vllm_output +def get_hf_images_prompts( + prompts_: list[ExplicitEncoderDecoderPrompt[str, TextPrompt]], +) -> tuple[list[ExplicitEncoderDecoderPrompt[str, str]], list[Image.Image]]: + prompts, images = [], [] + for prompt in prompts_: + encoder_prompt = prompt["encoder_prompt"] + prompts.append( + ExplicitEncoderDecoderPrompt( + encoder_prompt=encoder_prompt["prompt"], + decoder_prompt=None, + )) + images.append(encoder_prompt["multi_modal_data"]["image"]) + return prompts, images - hf_output_str = "" + output_str + "" - return output_ids, hf_output_str, out_logprobs +def hf_to_vllm_output(hf_output: tuple[list[int], str, + Optional[SampleLogprobs]]): + """Sanitize hf output to be comparable with vllm output.""" + output_ids, output_str, out_logprobs = hf_output + + output_str = output_str.replace("", "").replace("", "") + output_ids = [ids for ids in output_ids if ids not in [0, 2]] + + return output_ids, output_str, out_logprobs def run_test( hf_runner: Type[HfRunner], vllm_runner: Type[VllmRunner], - prompts: List[ExplicitEncoderDecoderPrompt], + inputs: list[list[ExplicitEncoderDecoderPrompt]], model: str, *, dtype: str, @@ -56,46 +63,76 @@ def run_test( distributed_executor_backend: Optional[str] = None, ) -> None: with vllm_runner(model, + max_num_seqs=8, tokenizer_name=TOKENIZER, dtype=dtype, tensor_parallel_size=tensor_parallel_size, distributed_executor_backend=distributed_executor_backend, enforce_eager=True) as vllm_model: - vllm_outputs = vllm_model.generate_encoder_decoder_greedy_logprobs( - prompts, max_tokens, num_logprobs) + vllm_outputs_per_case = [ + vllm_model.generate_encoder_decoder_greedy_logprobs( + prompts, max_tokens, num_logprobs=num_logprobs) + for prompts in inputs + ] + + hf_inputs = [get_hf_images_prompts(prompts) for prompts in inputs] - # Florence-2 processors require image inputs - dummy_image = Image.new(mode="RGB", size=(2, 2)) with hf_runner(model, dtype=dtype, skip_tokenizer_init=True) as hf_model: hf_model.model.get_output_embeddings = lambda: \ hf_model.model.language_model.lm_head - hf_outputs = (hf_model.generate_encoder_decoder_greedy_logprobs_limit( - prompts, - max_tokens, - num_logprobs, - 
images=[dummy_image] * len(prompts), - )) + hf_outputs_per_case = [ + hf_model.generate_encoder_decoder_greedy_logprobs_limit( + prompts, max_tokens, num_logprobs=num_logprobs, images=images) + for prompts, images in hf_inputs + ] - check_logprobs_close( - outputs_0_lst=hf_outputs, - outputs_1_lst=[ - vllm_to_hf_output(vllm_output) for vllm_output in vllm_outputs - ], - name_0="hf", - name_1="vllm", - ) + for hf_outputs, vllm_outputs in zip(hf_outputs_per_case, + vllm_outputs_per_case): + check_logprobs_close( + outputs_0_lst=[hf_to_vllm_output(output) for output in hf_outputs], + outputs_1_lst=vllm_outputs, + name_0="hf", + name_1="vllm", + ) +@pytest.mark.core_model @pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["float", "bfloat16"]) +@pytest.mark.parametrize( + "size_factors", + [ + # No image + [], + # Single-scale + [1.0], + # Single-scale, batched + [1.0, 1.0, 1.0], + # Multi-scale + [0.25, 0.5, 1.0], + ], +) +@pytest.mark.parametrize("dtype", ["float"]) @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("num_logprobs", [5]) -def test_models(hf_runner, vllm_runner, model, dtype, max_tokens, - num_logprobs) -> None: +def test_models(hf_runner: Type[HfRunner], vllm_runner: Type[VllmRunner], + image_assets: _ImageAssets, model: str, + size_factors: list[int], dtype: str, max_tokens: int, + num_logprobs: int) -> None: + images = [asset.pil_image for asset in image_assets] + + inputs_per_image = [[ + ExplicitEncoderDecoderPrompt( + encoder_prompt=TextPrompt( + prompt=prompt, + multi_modal_data={"image": rescale_image_size(image, factor)}), + decoder_prompt=None, + ) for factor in size_factors + ] for image, prompt in zip(images, HF_IMAGE_PROMPTS)] + run_test( hf_runner, vllm_runner, - PROMPTS, + inputs_per_image, model, dtype=dtype, max_tokens=max_tokens, diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py index a84999cfbf4f..7534f0c97798 100644 --- a/tests/models/multimodal/processing/test_common.py +++ b/tests/models/multimodal/processing/test_common.py @@ -29,8 +29,8 @@ def _test_processing_correctness( model_config = ModelConfig( model_id, task="auto", - tokenizer=model_id, - tokenizer_mode="auto", + tokenizer=model_info.tokenizer or model_id, + tokenizer_mode=model_info.tokenizer_mode, trust_remote_code=model_info.trust_remote_code, seed=0, dtype="float16", @@ -151,6 +151,7 @@ def _test_processing_correctness( "Salesforce/blip2-opt-2.7b", "facebook/chameleon-7b", "deepseek-ai/deepseek-vl2-tiny", + "microsoft/Florence-2-base", "adept/fuyu-8b", "THUDM/glm-4v-9b", "h2oai/h2ovl-mississippi-800m", diff --git a/tests/models/registry.py b/tests/models/registry.py index 8614baf18f3b..95bda0293498 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -193,11 +193,6 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { # [Encoder-decoder] "BartModel": _HfExamplesInfo("facebook/bart-base"), "BartForConditionalGeneration": _HfExamplesInfo("facebook/bart-large-cnn"), - # Florence-2 uses BartFastTokenizer which can't be loaded from AutoTokenizer - # Therefore, we borrow the BartTokenizer from the original Bart model - "Florence2ForConditionalGeneration": _HfExamplesInfo("microsoft/Florence-2-base", # noqa: E501 - tokenizer="facebook/bart-base", - trust_remote_code=True), # noqa: E501 } _EMBEDDING_EXAMPLE_MODELS = { @@ -288,6 +283,11 @@ _MULTIMODAL_EXAMPLE_MODELS = { extras={"v0.5": "fixie-ai/ultravox-v0_5-llama-3_2-1b"}, # noqa: E501 trust_remote_code=True), # [Encoder-decoder] + # 
Florence-2 uses BartFastTokenizer which can't be loaded from AutoTokenizer + # Therefore, we borrow the BartTokenizer from the original Bart model + "Florence2ForConditionalGeneration": _HfExamplesInfo("microsoft/Florence-2-base", # noqa: E501 + tokenizer="facebook/bart-base", + trust_remote_code=True), # noqa: E501 "MllamaForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-3.2-11B-Vision-Instruct"), # noqa: E501 "WhisperForConditionalGeneration": _HfExamplesInfo("openai/whisper-large-v3"), # noqa: E501 } diff --git a/vllm/model_executor/models/bart.py b/vllm/model_executor/models/bart.py index 5d2a8cdcb97d..93452696dca5 100644 --- a/vllm/model_executor/models/bart.py +++ b/vllm/model_executor/models/bart.py @@ -588,8 +588,12 @@ class BartEncoder(nn.Module): self.layernorm_embedding = nn.LayerNorm(embed_dim) - def forward(self, input_ids: torch.Tensor, - positions: torch.Tensor) -> torch.Tensor: + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: r""" Args: input_ids @@ -602,7 +606,8 @@ class BartEncoder(nn.Module): Decoder output torch.Tensor """ # retrieve input_ids and inputs_embeds - inputs_embeds = self.embed_tokens(input_ids) + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) embed_pos = self.embed_positions(positions) embed_pos = embed_pos.to(inputs_embeds.device) @@ -661,9 +666,13 @@ class BartDecoder(nn.Module): self.layernorm_embedding = nn.LayerNorm(config.d_model) - def forward(self, decoder_input_ids: torch.Tensor, - decoder_positions: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor]) -> torch.Tensor: + def forward( + self, + decoder_input_ids: torch.Tensor, + decoder_positions: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: r""" Args: decoder_input_ids @@ -677,8 +686,10 @@ class BartDecoder(nn.Module): Returns: Decoder output torch.Tensor """ - - inputs_embeds = self.embed_tokens(decoder_input_ids) + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(decoder_input_ids) + else: + decoder_positions = inputs_embeds[:, -1] # embed positions embed_pos = self.embed_positions(decoder_positions) diff --git a/vllm/model_executor/models/florence2.py b/vllm/model_executor/models/florence2.py index 06912bcfdc8a..b71d0de8d707 100644 --- a/vllm/model_executor/models/florence2.py +++ b/vllm/model_executor/models/florence2.py @@ -1,10 +1,15 @@ # SPDX-License-Identifier: Apache-2.0 import math -from typing import Iterable, Optional, Set, Tuple +from functools import cached_property +from typing import (Iterable, List, Literal, Mapping, Optional, OrderedDict, + Set, Tuple, TypedDict, Union) import torch import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from transformers import BatchFeature, PretrainedConfig from vllm.config import VllmConfig from vllm.model_executor.layers.logits_processor import LogitsProcessor @@ -14,11 +19,567 @@ from vllm.model_executor.models.bart import (BartDecoder, BartEncoder, BartParallelLMHead, BartScaledWordEmbedding) from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY, NestedTensors +from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs +from vllm.multimodal.parse import MultiModalDataDict, MultiModalDataItems +from vllm.multimodal.processing import (BaseProcessingInfo, + EncDecMultiModalProcessor, + 
PromptReplacement, + PromptReplacementDetails) +from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors -from .utils import AutoWeightsLoader +from .interfaces import SupportsMultiModal +from .utils import AutoWeightsLoader, flatten_bn, merge_multimodal_embeddings +class Florence2ImagePixelInputs(TypedDict): + type: Literal["pixel_values"] + data: torch.Tensor + """Shape: (batch_size, num_channel, height, width)""" + + +# ViT implementation are all copied from +# https://huggingface.co/microsoft/Florence-2-base/blob/main/modeling_florence2.py +class LearnedAbsolutePositionEmbedding2D(nn.Module): + """ + This module learns positional embeddings up to a fixed maximum size. + """ + + def __init__(self, embedding_dim=256, num_pos=50): + super().__init__() + self.row_embeddings = nn.Embedding(num_pos, embedding_dim // 2) + self.column_embeddings = nn.Embedding( + num_pos, embedding_dim - (embedding_dim // 2)) + + def forward(self, pixel_values): + """ + pixel_values: (batch_size, height, width, num_channels) + returns: (batch_size, height, width, embedding_dim * 2) + """ + if len(pixel_values.shape) != 4: + raise ValueError('pixel_values must be a 4D tensor') + height, width = pixel_values.shape[1:3] + width_values = torch.arange(width, device=pixel_values.device) + height_values = torch.arange(height, device=pixel_values.device) + x_emb = self.column_embeddings(width_values) + y_emb = self.row_embeddings(height_values) + # (height, width, embedding_dim * 2) + pos = torch.cat([ + x_emb.unsqueeze(0).repeat(height, 1, 1), + y_emb.unsqueeze(1).repeat(1, width, 1) + ], + dim=-1) + # (embedding_dim * 2, height, width) + pos = pos.permute(2, 0, 1) + pos = pos.unsqueeze(0) + # (batch_size, embedding_dim * 2, height, width) + pos = pos.repeat(pixel_values.shape[0], 1, 1, 1) + # (batch_size, height, width, embedding_dim * 2) + pos = pos.permute(0, 2, 3, 1) + return pos + + +class PositionalEmbeddingCosine1D(nn.Module): + """ + This class implements a very simple positional encoding. It follows closely + the encoder from the link below: + https://pytorch.org/tutorials/beginner/translation_transformer.html + Args: + embed_dim: The dimension of the embeddings. + dropout_prob: The dropout probability. + max_seq_len: The maximum length to precompute the positional encodings. + """ + + def __init__(self, embed_dim: int = 512, max_seq_len: int = 1024) -> None: + super().__init__() + self.embed_dim = embed_dim + self.max_seq_len = max_seq_len + # Generate the sinusoidal arrays. + factor = math.log(10000) + denominator = torch.exp(-factor * torch.arange(0, self.embed_dim, 2) / + self.embed_dim) + # Matrix where rows correspond to a positional embedding as a function + # of the position index (i.e., the row index). + frequencies = \ + torch.arange(0, self.max_seq_len) \ + .reshape(self.max_seq_len, 1) * denominator + pos_idx_to_embed = torch.zeros((self.max_seq_len, self.embed_dim)) + # Populate uneven entries. + pos_idx_to_embed[:, 0::2] = torch.sin(frequencies) + pos_idx_to_embed[:, 1::2] = torch.cos(frequencies) + # Save the positional embeddings in a constant buffer. + # self.register_buffer("pos_idx_to_embed", pos_idx_to_embed) + self.pos_idx_to_embed = nn.Parameter(pos_idx_to_embed, + requires_grad=False) + + def forward(self, seq_embeds: torch.Tensor) -> torch.Tensor: + """ + Args: + seq_embeds: The sequence embeddings in order. Allowed size: + 1. [T, D], where T is the length of the sequence, and D is the + frame embedding dimension. + 2. 
[B, T, D], where B is the batch size and T and D are the + same as above. + Returns a tensor of with the same dimensions as the input: i.e., + [1, T, D] or [T, D]. + """ + shape_len = len(seq_embeds.shape) + assert 2 <= shape_len <= 3 + len_seq = seq_embeds.size(-2) + assert len_seq <= self.max_seq_len + pos_embeds = self.pos_idx_to_embed[0:seq_embeds.size(-2), :] + # Adapt pre-computed positional embeddings to the input. + if shape_len == 3: + pos_embeds = pos_embeds.view( + (1, pos_embeds.size(0), pos_embeds.size(1))) + return pos_embeds + + +class MySequential(nn.Sequential): + + def forward(self, *inputs): + for module in self._modules.values(): + if isinstance(inputs, tuple): + inputs = module(*inputs) + else: + inputs = module(inputs) + return inputs + + +class PreNorm(nn.Module): + + def __init__(self, norm, fn): + super().__init__() + self.norm = norm + self.fn = fn + + def forward(self, x, *args, **kwargs): + shortcut = x + if self.norm is not None: + x, size = self.fn(self.norm(x), *args, **kwargs) + else: + x, size = self.fn(x, *args, **kwargs) + + x = shortcut + x + + return x, size + + +class Mlp(nn.Module): + + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.net = nn.Sequential( + OrderedDict([("fc1", nn.Linear(in_features, hidden_features)), + ("act", act_layer()), + ("fc2", nn.Linear(hidden_features, out_features))])) + + def forward(self, x, size): + return self.net(x), size + + +class DepthWiseConv2d(nn.Module): + + def __init__( + self, + dim_in, + kernel_size, + padding, + stride, + bias=True, + ): + super().__init__() + self.dw = nn.Conv2d(dim_in, + dim_in, + kernel_size=kernel_size, + padding=padding, + groups=dim_in, + stride=stride, + bias=bias) + + def forward(self, x, size): + B, N, C = x.shape + H, W = size + assert N == H * W + + x = self.dw(x.transpose(1, 2).view(B, C, H, W)) + size = (x.size(-2), x.size(-1)) + x = x.flatten(2).transpose(1, 2) + return x, size + + +class ConvEmbed(nn.Module): + """ Image to Patch Embedding + """ + + def __init__(self, + patch_size=7, + in_chans=3, + embed_dim=64, + stride=4, + padding=2, + norm_layer=None, + pre_norm=True): + super().__init__() + self.patch_size = patch_size + + self.proj = nn.Conv2d(in_chans, + embed_dim, + kernel_size=patch_size, + stride=stride, + padding=padding) + + dim_norm = in_chans if pre_norm else embed_dim + self.norm = norm_layer(dim_norm) if norm_layer else None + + self.pre_norm = pre_norm + + def forward(self, x, size): + H, W = size + if len(x.size()) == 3: + if self.norm and self.pre_norm: + x = self.norm(x) + x = rearrange(x, 'b (h w) c -> b c h w', h=H, w=W) + + x = self.proj(x) + + _, _, H, W = x.shape + x = rearrange(x, 'b c h w -> b (h w) c') + if self.norm and not self.pre_norm: + x = self.norm(x) + + return x, (H, W) + + +class ChannelAttention(nn.Module): + + def __init__(self, dim, groups=8, qkv_bias=True): + super().__init__() + + self.groups = groups + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.proj = nn.Linear(dim, dim) + + def forward(self, x, size): + B, N, C = x.shape + + qkv = self.qkv(x).reshape(B, N, 3, self.groups, + C // self.groups).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + + q = q * (float(N)**-0.5) + attention = q.transpose(-1, -2) @ k + attention = attention.softmax(dim=-1) + x = (attention @ v.transpose(-1, -2)).transpose(-1, -2) + x = x.transpose(1, 
2).reshape(B, N, C) + x = self.proj(x) + return x, size + + +class ChannelBlock(nn.Module): + + def __init__(self, + dim, + groups, + mlp_ratio=4., + qkv_bias=True, + drop_path_rate=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + conv_at_attn=True, + conv_at_ffn=True): + super().__init__() + + self.conv1 = PreNorm(None, DepthWiseConv2d( + dim, 3, 1, 1)) if conv_at_attn else None + self.channel_attn = PreNorm( + norm_layer(dim), + ChannelAttention(dim, groups=groups, qkv_bias=qkv_bias), + ) + self.conv2 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, + 1)) if conv_at_ffn else None + self.ffn = PreNorm( + norm_layer(dim), + Mlp(in_features=dim, + hidden_features=int(dim * mlp_ratio), + act_layer=act_layer), + ) + + def forward(self, x, size): + if self.conv1: + x, size = self.conv1(x, size) + x, size = self.channel_attn(x, size) + + if self.conv2: + x, size = self.conv2(x, size) + x, size = self.ffn(x, size) + + return x, size + + +def window_partition(x, window_size: int): + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, + C) + windows = x.permute(0, 1, 3, 2, 4, + 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, batch_size: int, window_size: int, H: int, W: int): + B = batch_size + + x = windows.view(B, H // window_size, W // window_size, window_size, + window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + + def __init__(self, dim, num_heads, window_size, qkv_bias=True): + + super().__init__() + self.dim = dim + self.window_size = window_size + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = float(head_dim)**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.proj = nn.Linear(dim, dim) + + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, size): + + H, W = size + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + + x = x.view(B, H, W, C) + + pad_l = pad_t = 0 + pad_r = (self.window_size - W % self.window_size) % self.window_size + pad_b = (self.window_size - H % self.window_size) % self.window_size + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hp, Wp, _ = x.shape + + x = window_partition(x, self.window_size) + x = x.view(-1, self.window_size * self.window_size, C) + + # W-MSA/SW-MSA + # attn_windows = self.attn(x_windows) + + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, + C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + attn = self.softmax(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + + # merge windows + x = x.view(-1, self.window_size, self.window_size, C) + x = window_reverse(x, B, self.window_size, Hp, Wp) + + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + + x = x.view(B, H * W, C) + + return x, size + + +class SpatialBlock(nn.Module): + + def __init__(self, + dim, + num_heads, + window_size, + mlp_ratio=4., + qkv_bias=True, + drop_path_rate=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + conv_at_attn=True, + conv_at_ffn=True): + super().__init__() + + self.conv1 = PreNorm(None, DepthWiseConv2d( + dim, 3, 1, 1)) if conv_at_attn else None + self.window_attn = PreNorm( + norm_layer(dim), + WindowAttention(dim, num_heads, window_size, qkv_bias=qkv_bias), + ) + self.conv2 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, + 1)) if conv_at_ffn else None + self.ffn = 
PreNorm( + norm_layer(dim), + Mlp(in_features=dim, + hidden_features=int(dim * mlp_ratio), + act_layer=act_layer), + ) + + def forward(self, x, size): + if self.conv1: + x, size = self.conv1(x, size) + x, size = self.window_attn(x, size) + + if self.conv2: + x, size = self.conv2(x, size) + x, size = self.ffn(x, size) + return x, size + + +class DaViT(nn.Module): + + def __init__( + self, + in_chans=3, + num_classes=1000, + depths=(1, 1, 3, 1), + patch_size=(7, 2, 2, 2), + patch_stride=(4, 2, 2, 2), + patch_padding=(3, 0, 0, 0), + patch_prenorm=(False, False, False, False), + embed_dims=(64, 128, 192, 256), + num_heads=(3, 6, 12, 24), + num_groups=(3, 6, 12, 24), + window_size=7, + mlp_ratio=4., + qkv_bias=True, + drop_path_rate=0.1, + norm_layer=nn.LayerNorm, + enable_checkpoint=False, + conv_at_attn=True, + conv_at_ffn=True, + ): + super().__init__() + + self.num_classes = num_classes + self.embed_dims = embed_dims + self.num_heads = num_heads + self.num_groups = num_groups + self.num_stages = len(self.embed_dims) + self.enable_checkpoint = enable_checkpoint + assert self.num_stages == len(self.num_heads) == len(self.num_groups) + + num_stages = len(embed_dims) + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, + sum(depths) * 2) + ] + + depth_offset = 0 + convs = [] + blocks = [] + for i in range(num_stages): + conv_embed = ConvEmbed( + patch_size=patch_size[i], + stride=patch_stride[i], + padding=patch_padding[i], + in_chans=in_chans if i == 0 else self.embed_dims[i - 1], + embed_dim=self.embed_dims[i], + norm_layer=norm_layer, + pre_norm=patch_prenorm[i]) + convs.append(conv_embed) + + block = MySequential(*[ + MySequential( + OrderedDict([('spatial_block', + SpatialBlock( + embed_dims[i], + num_heads[i], + window_size, + drop_path_rate=dpr[depth_offset + j * 2], + qkv_bias=qkv_bias, + mlp_ratio=mlp_ratio, + conv_at_attn=conv_at_attn, + conv_at_ffn=conv_at_ffn, + )), + ('channel_block', + ChannelBlock( + embed_dims[i], + num_groups[i], + drop_path_rate=dpr[depth_offset + j * 2 + + 1], + qkv_bias=qkv_bias, + mlp_ratio=mlp_ratio, + conv_at_attn=conv_at_attn, + conv_at_ffn=conv_at_ffn, + ))])) for j in range(depths[i]) + ]) + blocks.append(block) + depth_offset += depths[i] * 2 + + self.convs = nn.ModuleList(convs) + self.blocks = nn.ModuleList(blocks) + + self.avgpool = nn.AdaptiveAvgPool1d(1) + + @property + def dim_out(self): + return self.embed_dims[-1] + + def forward_features_unpool(self, x): + """ + forward until avg pooling + Args: + x (_type_): input image tensor + """ + input_size = (x.size(2), x.size(3)) + for conv, block in zip(self.convs, self.blocks): + x, input_size = conv(x, input_size) + x, input_size = block(x, input_size) + return x + + def forward_features(self, x): + x = self.forward_features_unpool(x) + + # (batch_size, num_tokens, token_dim) + x = self.avgpool(x.transpose(1, 2)) + # (batch_size, 1, num_tokens) + x = torch.flatten(x, 1) + x = self.norms(x) + + return x + + def forward(self, x): + x = self.forward_features(x) + x = self.head(x) + return x + + @classmethod + def from_config(cls, config): + return cls( + depths=config.depths, + embed_dims=config.dim_embed, + num_heads=config.num_heads, + num_groups=config.num_groups, + patch_size=config.patch_size, + patch_stride=config.patch_stride, + patch_padding=config.patch_padding, + patch_prenorm=config.patch_prenorm, + drop_path_rate=config.drop_path_rate, + window_size=config.window_size, + ) + + +# Language backbone and processor implementation class Florence2LanguageModel(nn.Module): def 
__init__(self, *, vllm_config: VllmConfig, prefix: str = ""): @@ -47,9 +608,14 @@ class Florence2LanguageModel(nn.Module): self.encoder.embed_tokens.weight = self.shared.weight self.decoder.embed_tokens.weight = self.shared.weight - def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, - encoder_input_ids: torch.Tensor, - encoder_positions: torch.Tensor) -> torch.Tensor: + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + encoder_input_ids: torch.Tensor, + encoder_positions: torch.Tensor, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: r""" Args: input_ids @@ -68,11 +634,12 @@ class Florence2LanguageModel(nn.Module): encoder_hidden_states = None - if encoder_input_ids.numel() > 0: + if inputs_embeds is not None or encoder_input_ids.numel() > 0: # Run encoder attention if a non-zero number of encoder tokens # are provided as input encoder_hidden_states = self.encoder(input_ids=encoder_input_ids, - positions=encoder_positions) + positions=encoder_positions, + inputs_embeds=inputs_embeds) # decoder outputs consists of # (dec_features, past_key_value, dec_hidden, dec_attn) @@ -112,6 +679,7 @@ class Florence2LanguageForConditionalGeneration(nn.Module): positions: torch.Tensor, encoder_input_ids: torch.Tensor, encoder_positions: torch.Tensor, + inputs_embeds: Optional[torch.Tensor] = None, **kwargs, ) -> torch.Tensor: r""" @@ -127,8 +695,15 @@ class Florence2LanguageForConditionalGeneration(nn.Module): Returns: Output torch.Tensor """ - return self.model(input_ids, positions, encoder_input_ids, - encoder_positions) + + return self.model(input_ids, + positions, + encoder_input_ids, + encoder_positions, + inputs_embeds=inputs_embeds) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.encoder.embed_tokens(input_ids) def compute_logits( self, @@ -177,21 +752,312 @@ class Florence2LanguageForConditionalGeneration(nn.Module): return loaded_params -class Florence2ForConditionalGeneration(nn.Module): +class Florence2ProcessingInfo(BaseProcessingInfo): + + def get_hf_config(self): + return self.ctx.get_hf_config() + + def get_hf_processor(self): + return self.ctx.get_hf_processor() + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": 1} + + def get_max_image_tokens(self) -> int: + processor_config = self.ctx.get_hf_image_processor_config() + return processor_config["image_seq_length"] + + def get_mm_max_tokens_per_item( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> Mapping[str, int]: + return {"image": self.get_max_image_tokens()} + + +class Florence2DummyInputsBuilder( + BaseDummyInputsBuilder[Florence2ProcessingInfo]): + + def get_dummy_processor_inputs( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> ProcessorInputs: + num_images = mm_counts.get("image", 0) + + target_width = target_height = self.info.get_hf_config().projection_dim + + mm_data = { + "image": + self._get_dummy_images(width=target_width, + height=target_height, + num_images=num_images) + } + + return ProcessorInputs( + prompt_text="", + mm_data=mm_data, + ) + + +class Florence2MultiModalProcessor( + EncDecMultiModalProcessor[Florence2ProcessingInfo]): + + def _hf_processor_applies_repl( + self, + prompt_text: str, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> bool: + return False + + def create_encoder_prompt( + self, + prompt: Union[str, list[int]], + mm_data: MultiModalDataDict, + ) -> Union[str, list[int]]: + return 
prompt + + def create_decoder_prompt( + self, + prompt: Union[str, list[int]], + mm_data: MultiModalDataDict, + ) -> Union[str, list[int]]: + return [self.info.get_hf_config().eos_token_id] + + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + ) -> BatchFeature: + if mm_data: + processed_outputs = super()._call_hf_processor( + prompt, mm_data, mm_kwargs) + else: + hf_processor = self.info.get_hf_processor() + tokenizer = hf_processor.tokenizer + prompt = hf_processor._construct_prompts([prompt])[0] + processed_outputs = tokenizer(prompt, + add_special_tokens=True, + return_tensors="pt") + return processed_outputs + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict(pixel_values=MultiModalFieldConfig.batched("image")) + + def _get_prompt_replacements( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargs, + ) -> list[PromptReplacement]: + hf_config = self.info.get_hf_config() + pad_token_id = hf_config.pad_token_id + bos_token_id = hf_config.bos_token_id + num_image_tokens = self.info.get_max_image_tokens() + image_tokens = [pad_token_id] * num_image_tokens + + return [ + PromptReplacement( + modality="image", + target=[bos_token_id], + replacement=PromptReplacementDetails( + full=image_tokens + [bos_token_id], + features=image_tokens, + ), + ) + ] + + +@MULTIMODAL_REGISTRY.register_processor( + Florence2MultiModalProcessor, + info=Florence2ProcessingInfo, + dummy_inputs=Florence2DummyInputsBuilder) +class Florence2ForConditionalGeneration(nn.Module, SupportsMultiModal): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config + processor_config = vllm_config.model_config.hf_image_processor_config - # TODO(Isotr0py): Add vision backbone + self.config = config + self.vision_config = config.vision_config + self.processor_config = processor_config + assert config.vision_config.model_type == 'davit', ( + 'only DaViT is supported for now') + self.vision_tower = DaViT.from_config(config=config.vision_config) + self._build_image_projection_layers(config) self.language_model = Florence2LanguageForConditionalGeneration( vllm_config=vllm_config.with_hf_config(config.text_config), prefix=f"{prefix}.language_model", ) + self.pad_token_id = config.pad_token_id - @property + def _build_image_projection_layers(self, config: PretrainedConfig): + image_dim_out = config.vision_config.dim_embed[-1] + dim_projection = config.vision_config.projection_dim + self.image_projection = nn.Parameter( + torch.empty(image_dim_out, dim_projection)) + self.image_proj_norm = nn.LayerNorm(dim_projection) + image_pos_embed_config = config.vision_config.image_pos_embed + if image_pos_embed_config['type'] == 'learned_abs_2d': + self.image_pos_embed = LearnedAbsolutePositionEmbedding2D( + embedding_dim=image_dim_out, + num_pos=image_pos_embed_config['max_pos_embeddings']) + else: + raise NotImplementedError("Florence2 only supports learned_abs_2d " + "as image position embedding.") + + self.image_feature_source = config.vision_config.image_feature_source + + # temporal embedding + visual_temporal_embedding_config = ( + self.vision_config.visual_temporal_embedding) + if visual_temporal_embedding_config['type'] == 'COSINE': + self.visual_temporal_embed = PositionalEmbeddingCosine1D( + embed_dim=image_dim_out, 
+ max_seq_len=visual_temporal_embedding_config[ + 'max_temporal_embeddings']) + else: + raise NotImplementedError( + 'Florence2 only supports COSINE as temporal embedding.') + + @cached_property def sampler(self): - return self.language_model.sampler + if hasattr(self.language_model, "sampler"): + return self.language_model.sampler + return get_sampler() + + def _validate_pixel_values( + self, data: Union[torch.Tensor, List[torch.Tensor]] + ) -> Union[torch.Tensor, List[torch.Tensor]]: + + size = self.processor_config["size"] + h, w = size["height"], size["width"] + expected_dims = (3, h, w) + + def _validate_shape(d: torch.Tensor): + actual_dims = tuple(d.shape) + + if actual_dims != expected_dims: + expected_expr = tuple(*map(str, expected_dims)) + raise ValueError( + "The expected shape of pixel values per batch " + f"is {expected_expr}. You supplied {tuple(d.shape)}.") + + for d in data: + _validate_shape(d) + + return data + + def _parse_and_validate_image_input(self, **kwargs: object): + pixel_values: Optional[Union[List[List[torch.Tensor]], + List[torch.Tensor], + torch.Tensor]] = kwargs.pop( + "pixel_values", None) + image_embeds: Optional[Union[List[List[torch.Tensor]], + List[torch.Tensor], + torch.Tensor]] = kwargs.pop( + "image_embeds", None) + + if pixel_values is None and image_embeds is None: + return None + + if pixel_values is not None and image_embeds is not None: + raise ValueError( + "Both pixel values and image embeds are provided.") + + if pixel_values is not None: + return Florence2ImagePixelInputs( + type="pixel_values", + data=self._validate_pixel_values( + flatten_bn(pixel_values, concat=True)), + ) + + if image_embeds is not None: + raise NotImplementedError + + raise AssertionError("This line should be unreachable.") + + def _encode_image(self, pixel_values: torch.Tensor) -> torch.Tensor: + dtype = next(self.vision_tower.parameters()).dtype + pixel_values = pixel_values.to(dtype) + + batch_size, T = pixel_values.size(0), 1 + x = self.vision_tower.forward_features_unpool(pixel_values) + if self.image_pos_embed is not None: + x = x.view(batch_size * T, -1, x.shape[-1]) + num_tokens = x.shape[-2] + h, w = int(num_tokens**0.5), int(num_tokens**0.5) + assert h * w == num_tokens, ( + 'only support square feature maps for now') + x = x.view(batch_size * T, h, w, x.shape[-1]) + pos_embed = self.image_pos_embed(x) + x = x + pos_embed + x = x.view(batch_size, T * h * w, x.shape[-1]) + + if self.visual_temporal_embed is not None: + visual_temporal_embed = self.visual_temporal_embed( + x.view(batch_size, T, -1, x.shape[-1])[:, :, 0]) + x = x.view(batch_size, T, -1, + x.shape[-1]) + visual_temporal_embed.view( + 1, T, 1, x.shape[-1]) + + x_feat_dict = {} + + spatial_avg_pool_x = x.view(batch_size, T, -1, x.shape[-1]).mean(dim=2) + x_feat_dict['spatial_avg_pool'] = spatial_avg_pool_x + + temporal_avg_pool_x = x.view(batch_size, T, -1, + x.shape[-1]).mean(dim=1) + x_feat_dict['temporal_avg_pool'] = temporal_avg_pool_x + + x = x.view(batch_size, T, -1, x.shape[-1])[:, -1] + x_feat_dict['last_frame'] = x + + new_x = [] + for _image_feature_source in self.image_feature_source: + if _image_feature_source not in x_feat_dict: + raise ValueError('invalid image feature source: {}'.format( + _image_feature_source)) + new_x.append(x_feat_dict[_image_feature_source]) + + x = torch.cat(new_x, dim=1) + + x = x @ self.image_projection + x = self.image_proj_norm(x) + + return x + + def _process_image_input( + self, image_input: Florence2ImagePixelInputs) -> torch.Tensor: + assert 
image_input["type"] == "pixel_values" + pixel_values = image_input["data"] + return self._encode_image(pixel_values) + + def get_multimodal_embeddings(self, **kwargs: object) -> torch.Tensor: + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is None: + return None + vision_embeddings = self._process_image_input(image_input) + return vision_embeddings + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[NestedTensors] = None, + ) -> torch.Tensor: + inputs_embeds = self.language_model.get_input_embeddings(input_ids) + if multimodal_embeddings is not None: + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, multimodal_embeddings, + self.pad_token_id) + return inputs_embeds def forward( self, @@ -216,8 +1082,19 @@ class Florence2ForConditionalGeneration(nn.Module): Returns: Output torch.Tensor """ - return self.language_model(input_ids, positions, encoder_input_ids, - encoder_positions) + vision_embeddings = self.get_multimodal_embeddings(**kwargs) + if encoder_input_ids.numel() > 0 or vision_embeddings is not None: + inputs_embeds = self.get_input_embeddings(encoder_input_ids, + vision_embeddings) + else: + inputs_embeds = None + + hidden_states = self.language_model(input_ids, + positions, + encoder_input_ids, + encoder_positions, + inputs_embeds=inputs_embeds) + return hidden_states def compute_logits( self, @@ -236,9 +1113,5 @@ class Florence2ForConditionalGeneration(nn.Module): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: - skip_prefixes = [ - 'image_projection', "vision_tower", "image_proj_norm", - "image_pos_embed", "visual_temporal_embed" - ] - loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes) + loader = AutoWeightsLoader(self) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 58155905a7b7..75e31d557dd1 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -105,7 +105,6 @@ _TEXT_GENERATION_MODELS = { # [Encoder-decoder] "BartModel": ("bart", "BartForConditionalGeneration"), "BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"), - "Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"), # noqa: E501 } _EMBEDDING_MODELS = { @@ -182,6 +181,7 @@ _MULTIMODAL_MODELS = { "Qwen2AudioForConditionalGeneration": ("qwen2_audio", "Qwen2AudioForConditionalGeneration"), # noqa: E501 "UltravoxModel": ("ultravox", "UltravoxModel"), # [Encoder-decoder] + "Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"), # noqa: E501 "MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"), # noqa: E501 "WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"), # noqa: E501 } diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index 93756364dea1..60b000e2b34f 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -1303,6 +1303,14 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]): """ raise NotImplementedError + def create_decoder_prompt( + self, + prompt: Union[str, list[int]], + mm_data: MultiModalDataDict, + ) -> Union[str, list[int]]: + """Create input prompt for the decoder.""" + return prompt + def apply( self, prompt: Union[str, list[int]], @@ -1323,17 +1331,15 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]): 
hf_processor_mm_kwargs, ) - # We assumed the decoder prompt text is copied from - # the original encoder prompt without extra process tokenizer = self.info.get_tokenizer() - if isinstance(prompt, str): - decoder_prompt = prompt + decoder_prompt = self.create_decoder_prompt(prompt, mm_data) + if isinstance(decoder_prompt, str): decoder_prompt_ids = encode_tokens(tokenizer, - prompt, + decoder_prompt, add_special_tokens=False) else: - decoder_prompt = decode_tokens(tokenizer, prompt) - decoder_prompt_ids = prompt + decoder_prompt_ids = decoder_prompt + decoder_prompt = decode_tokens(tokenizer, decoder_prompt) mm_inputs = MultiModalEncDecInputs( encoder_prompt=encoder_inputs["prompt"], diff --git a/vllm/multimodal/profiling.py b/vllm/multimodal/profiling.py index 093f8b7a8179..3178b0f8c3e6 100644 --- a/vllm/multimodal/profiling.py +++ b/vllm/multimodal/profiling.py @@ -204,9 +204,11 @@ class MultiModalProfiler(Generic[_I]): "and/or reduce `mm_counts`.", seq_len, total_len, total_placeholders_by_modality) + num_tokens_to_pad = max(total_len, seq_len) - total_len + prompt_token_ids.extend([0] * num_tokens_to_pad) + return DummyData( - seq_data=SequenceData.from_prompt_token_counts( - (0, max(seq_len, total_len))), + seq_data=SequenceData.from_seqs(prompt_token_ids), multi_modal_data=None, multi_modal_placeholders=None, )
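# ---------------------------------------------------------------------------
# Usage sketch (not part of the diff above): a condensed, runnable version of
# the offline example this patch adds in
# examples/offline_inference/florence2_inference.py, showing both prompt forms
# that the new Florence-2 multimodal support accepts. The "<DETAILED_CAPTION>"
# task token is an assumption taken from Florence-2's documented prompt format;
# any of the model's task tokens (e.g. "<CAPTION>", "<OCR>") can be substituted
# in the implicit prompt below.
from vllm import LLM, SamplingParams
from vllm.assets.image import ImageAsset

llm = LLM(
    model="microsoft/Florence-2-large",
    # Florence-2 uses BartFastTokenizer, which can't be loaded via
    # AutoTokenizer, so the original BART tokenizer is borrowed
    # (same workaround as in the tests and examples in this patch).
    tokenizer="facebook/bart-large",
    max_num_seqs=8,
    trust_remote_code=True,
)

prompts = [
    {   # implicit prompt: the task token serves as the encoder prompt
        "prompt": "<DETAILED_CAPTION>",
        "multi_modal_data": {"image": ImageAsset("stop_sign").pil_image},
    },
    {   # explicit encoder/decoder prompt with an empty decoder prompt
        "encoder_prompt": {
            "prompt": "Describe in detail what is shown in the image.",
            "multi_modal_data": {"image": ImageAsset("cherry_blossom").pil_image},
        },
        "decoder_prompt": "",
    },
]

sampling_params = SamplingParams(temperature=0, top_p=1.0, max_tokens=128)

for output in llm.generate(prompts, sampling_params):
    print(f"Generated text: {output.outputs[0].text!r}")
# ---------------------------------------------------------------------------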