[VLM] Support multimodal inputs for Florence-2 models (#13320)
This commit is contained in:
parent 788f284b53
commit edf309ebbe
@@ -715,6 +715,13 @@ See [this page](#generative-models) for more information on how to use generativ
  *
  * ✅︎
  * ✅︎
- * `Florence2ForConditionalGeneration`
  * Florence-2
  * T + I
  * `microsoft/Florence-2-base`, `microsoft/Florence-2-large` etc.
  *
  *
  *
- * `FuyuForCausalLM`
  * Fuyu
  * T + I

@@ -1,34 +1,45 @@
# SPDX-License-Identifier: Apache-2.0
'''
"""
Demonstrate prompting of text-to-text
encoder/decoder models, specifically Florence-2
'''
"""
# TODO(Isotr0py):
# Move to offline_inference/vision_language.py
# after porting vision backbone
from vllm import LLM, SamplingParams

dtype = "float"
from vllm.assets.image import ImageAsset

# Create a Florence-2 encoder/decoder model instance
llm = LLM(
    model="microsoft/Florence-2-base",
    tokenizer="facebook/bart-base",
    dtype=dtype,
    model="microsoft/Florence-2-large",
    tokenizer="facebook/bart-large",
    max_num_seqs=8,
    trust_remote_code=True,
)

prompts = [
    "<CAPTION>", "<DETAILED_CAPTION>", "<MORE_DETAILED_CAPTION>",
    "<CAPTION_TO_PHRASE_GROUNDING>", "<OD>", "<DENSE_REGION_CAPTION>",
    "<REGION_PROPOSAL>", "<OCR>", "<OCR_WITH_REGION>"
    {  # implicit prompt with task token
        "prompt": "<DETAILED_CAPTION>",
        "multi_modal_data": {
            "image": ImageAsset("stop_sign").pil_image
        },
    },
    {  # explicit encoder/decoder prompt
        "encoder_prompt": {
            "prompt": "Describe in detail what is shown in the image.",
            "multi_modal_data": {
                "image": ImageAsset("cherry_blossom").pil_image
            },
        },
        "decoder_prompt": "",
    },
]
# Create a sampling params object.
sampling_params = SamplingParams(
    temperature=0,
    top_p=1.0,
    min_tokens=0,
    max_tokens=20,
    max_tokens=128,
)

# Generate output tokens from the prompts. The output is a list of
@@ -38,9 +49,5 @@ outputs = llm.generate(prompts, sampling_params)

# Print the outputs.
for output in outputs:
    prompt = output.prompt
    encoder_prompt = output.encoder_prompt
    generated_text = output.outputs[0].text
    print(f"Encoder prompt: {encoder_prompt!r}, "
          f"Decoder prompt: {prompt!r}, "
          f"Generated text: {generated_text!r}")
    print(f"Generated text: {generated_text!r}")
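
For quick reference, the updated example exercises two input formats: an implicit prompt that carries only a Florence-2 task token (the image travels in multi_modal_data), and an explicit encoder/decoder prompt with free-form encoder text. A condensed, self-contained sketch of the same idea, using the model, tokenizer, and image assets already shown above:

from vllm import LLM, SamplingParams
from vllm.assets.image import ImageAsset

llm = LLM(
    model="microsoft/Florence-2-large",
    tokenizer="facebook/bart-large",
    max_num_seqs=8,
    trust_remote_code=True,
)

prompts = [
    # Implicit prompt: a Florence-2 task token plus the image.
    {
        "prompt": "<CAPTION>",
        "multi_modal_data": {"image": ImageAsset("stop_sign").pil_image},
    },
    # Explicit encoder/decoder prompt: encoder text plus an empty decoder prompt.
    {
        "encoder_prompt": {
            "prompt": "Describe in detail what is shown in the image.",
            "multi_modal_data": {"image": ImageAsset("cherry_blossom").pil_image},
        },
        "decoder_prompt": "",
    },
]

outputs = llm.generate(prompts, SamplingParams(temperature=0, max_tokens=128))
for output in outputs:
    print(output.outputs[0].text)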
@@ -82,6 +82,22 @@ def run_deepseek_vl2(question: str, modality: str):
    return llm, prompt, stop_token_ids


# Florence2
def run_florence2(question: str, modality: str):
    assert modality == "image"

    llm = LLM(model="microsoft/Florence-2-large",
              tokenizer="facebook/bart-large",
              max_num_seqs=8,
              trust_remote_code=True,
              dtype="bfloat16",
              disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)

    prompt = "<MORE_DETAILED_CAPTION>"
    stop_token_ids = None
    return llm, prompt, stop_token_ids


# Fuyu
def run_fuyu(question: str, modality: str):
    assert modality == "image"

@@ -571,6 +587,7 @@ model_example_map = {
    "blip-2": run_blip2,
    "chameleon": run_chameleon,
    "deepseek_vl_v2": run_deepseek_vl2,
    "florence2": run_florence2,
    "fuyu": run_fuyu,
    "glm4v": run_glm4v,
    "h2ovl_chat": run_h2ovl,
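
run_florence2 is wired into model_example_map, so the shared driver in this example script can select it by name. A minimal, self-contained sketch of that dispatch pattern; the args.disable_mm_preprocessor_cache flag from the function above is dropped so the snippet stands on its own:

from vllm import LLM, SamplingParams
from vllm.assets.image import ImageAsset


def run_florence2_example(question: str, modality: str):
    # Same shape as run_florence2 above, minus the CLI-driven cache flag.
    assert modality == "image"
    llm = LLM(model="microsoft/Florence-2-large",
              tokenizer="facebook/bart-large",
              max_num_seqs=8,
              trust_remote_code=True,
              dtype="bfloat16")
    return llm, "<MORE_DETAILED_CAPTION>", None


model_example_map = {"florence2": run_florence2_example}

llm, prompt, stop_token_ids = model_example_map["florence2"]("unused", "image")
outputs = llm.generate(
    {
        "prompt": prompt,
        "multi_modal_data": {"image": ImageAsset("cherry_blossom").pil_image},
    },
    SamplingParams(temperature=0, max_tokens=64, stop_token_ids=stop_token_ids),
)
print(outputs[0].outputs[0].text)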
@@ -600,8 +600,8 @@ class HfRunner:
            if images is not None and images[i] is not None:
                processor_kwargs["images"] = images[i]

            encoder_input_ids = self.wrap_device(
                self.processor(**processor_kwargs).input_ids,
            encoder_inputs = self.wrap_device(
                self.processor(**processor_kwargs),
                device=self.model.device.type,
            )

@@ -615,13 +615,13 @@ class HfRunner:
            )

            output = self.model.generate(
                encoder_input_ids,
                decoder_input_ids=decoder_input_ids,
                use_cache=True,
                do_sample=False,
                max_new_tokens=max_tokens,
                output_hidden_states=True,
                return_dict_in_generate=True,
                **encoder_inputs,
                **kwargs,
            )
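
The HF reference runner now forwards the full processor output instead of only input_ids, because the Florence-2 processor also returns pixel_values alongside the token IDs. A rough standalone sketch of the equivalent Hugging Face call this mirrors; the image path is a placeholder and the exact remote-code API is assumed from the Florence-2 model card rather than from this diff:

from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor

model = AutoModelForCausalLM.from_pretrained("microsoft/Florence-2-base",
                                             trust_remote_code=True)
processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base",
                                          trust_remote_code=True)

image = Image.open("example.jpg")  # placeholder image path
inputs = processor(text="<CAPTION>", images=image, return_tensors="pt")

# Passing **inputs forwards input_ids *and* pixel_values, mirroring the
# **encoder_inputs change in HfRunner above.
output = model.generate(**inputs, max_new_tokens=128, do_sample=False)
print(processor.batch_decode(output, skip_special_tokens=False)[0])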
@@ -15,7 +15,7 @@ from ....conftest import HfRunner, VllmRunner
from ....utils import RemoteOpenAIServer
from ...utils import check_logprobs_close

MODEL_NAME = "fixie-ai/ultravox-v0_5-llama-3_2-1b"
MODEL_NAME = "fixie-ai/ultravox-v0_4"

AudioTuple = Tuple[np.ndarray, int]

@@ -187,7 +187,7 @@ def run_multi_audio_test(


@pytest.mark.core_model
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("vllm_kwargs", [
@ -1,52 +1,59 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from functools import partial
|
||||
from typing import List, Optional, Tuple, Type
|
||||
from typing import Optional, Type
|
||||
|
||||
import pytest
|
||||
from PIL import Image
|
||||
|
||||
from vllm.inputs.data import ExplicitEncoderDecoderPrompt
|
||||
from vllm.inputs.data import ExplicitEncoderDecoderPrompt, TextPrompt
|
||||
from vllm.multimodal.image import rescale_image_size
|
||||
from vllm.sequence import SampleLogprobs
|
||||
|
||||
from ....conftest import HfRunner, VllmRunner
|
||||
from ....conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
|
||||
from ...utils import check_logprobs_close
|
||||
|
||||
Florence2Prompt = partial(ExplicitEncoderDecoderPrompt,
|
||||
decoder_prompt=None,
|
||||
mm_processor_kwargs=None)
|
||||
|
||||
MODELS = ["microsoft/Florence-2-base"]
|
||||
# Florence-2 uses BartFastTokenizer which can't be loaded from AutoTokenizer
|
||||
# Therefore, we borrow the BartTokenizer from the original Bart model
|
||||
TOKENIZER = "facebook/bart-base"
|
||||
PROMPTS = [
|
||||
Florence2Prompt(encoder_prompt="<CAPTION>"),
|
||||
Florence2Prompt(encoder_prompt="<DETAILED_CAPTION>"),
|
||||
Florence2Prompt(encoder_prompt="<MORE_DETAILED_CAPTION>"),
|
||||
Florence2Prompt(encoder_prompt="<CAPTION_TO_PHRASE_GROUNDING>"),
|
||||
Florence2Prompt(encoder_prompt="<DENSE_REGION_CAPTION>"),
|
||||
Florence2Prompt(encoder_prompt="<REGION_PROPOSAL>"),
|
||||
Florence2Prompt(encoder_prompt="<OCR_WITH_REGION>"),
|
||||
Florence2Prompt(encoder_prompt="<OCR>"),
|
||||
Florence2Prompt(encoder_prompt="<OD>"),
|
||||
]
|
||||
HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
|
||||
"stop_sign":
|
||||
"<CAPTION>", # special task token
|
||||
"cherry_blossom":
|
||||
"Describe in detail what is shown in the image.",
|
||||
})
|
||||
|
||||
|
||||
def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
|
||||
Optional[SampleLogprobs]], ):
|
||||
"""Sanitize vllm output to be comparable with hf output."""
|
||||
output_ids, output_str, out_logprobs = vllm_output
|
||||
def get_hf_images_prompts(
|
||||
prompts_: list[ExplicitEncoderDecoderPrompt[str, TextPrompt]],
|
||||
) -> tuple[list[ExplicitEncoderDecoderPrompt[str, str]], list[Image.Image]]:
|
||||
prompts, images = [], []
|
||||
for prompt in prompts_:
|
||||
encoder_prompt = prompt["encoder_prompt"]
|
||||
prompts.append(
|
||||
ExplicitEncoderDecoderPrompt(
|
||||
encoder_prompt=encoder_prompt["prompt"],
|
||||
decoder_prompt=None,
|
||||
))
|
||||
images.append(encoder_prompt["multi_modal_data"]["image"])
|
||||
return prompts, images
|
||||
|
||||
hf_output_str = "</s><s>" + output_str + "</s>"
|
||||
|
||||
return output_ids, hf_output_str, out_logprobs
|
||||
def hf_to_vllm_output(hf_output: tuple[list[int], str,
|
||||
Optional[SampleLogprobs]]):
|
||||
"""Sanitize hf output to be comparable with vllm output."""
|
||||
output_ids, output_str, out_logprobs = hf_output
|
||||
|
||||
output_str = output_str.replace("</s>", "").replace("<s>", "")
|
||||
output_ids = [ids for ids in output_ids if ids not in [0, 2]]
|
||||
|
||||
return output_ids, output_str, out_logprobs
|
||||
|
||||
|
||||
def run_test(
|
||||
hf_runner: Type[HfRunner],
|
||||
vllm_runner: Type[VllmRunner],
|
||||
prompts: List[ExplicitEncoderDecoderPrompt],
|
||||
inputs: list[list[ExplicitEncoderDecoderPrompt]],
|
||||
model: str,
|
||||
*,
|
||||
dtype: str,
|
||||
@ -56,46 +63,76 @@ def run_test(
|
||||
distributed_executor_backend: Optional[str] = None,
|
||||
) -> None:
|
||||
with vllm_runner(model,
|
||||
max_num_seqs=8,
|
||||
tokenizer_name=TOKENIZER,
|
||||
dtype=dtype,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
distributed_executor_backend=distributed_executor_backend,
|
||||
enforce_eager=True) as vllm_model:
|
||||
vllm_outputs = vllm_model.generate_encoder_decoder_greedy_logprobs(
|
||||
prompts, max_tokens, num_logprobs)
|
||||
vllm_outputs_per_case = [
|
||||
vllm_model.generate_encoder_decoder_greedy_logprobs(
|
||||
prompts, max_tokens, num_logprobs=num_logprobs)
|
||||
for prompts in inputs
|
||||
]
|
||||
|
||||
hf_inputs = [get_hf_images_prompts(prompts) for prompts in inputs]
|
||||
|
||||
# Florence-2 processors require image inputs
|
||||
dummy_image = Image.new(mode="RGB", size=(2, 2))
|
||||
with hf_runner(model, dtype=dtype, skip_tokenizer_init=True) as hf_model:
|
||||
hf_model.model.get_output_embeddings = lambda: \
|
||||
hf_model.model.language_model.lm_head
|
||||
hf_outputs = (hf_model.generate_encoder_decoder_greedy_logprobs_limit(
|
||||
prompts,
|
||||
max_tokens,
|
||||
num_logprobs,
|
||||
images=[dummy_image] * len(prompts),
|
||||
))
|
||||
hf_outputs_per_case = [
|
||||
hf_model.generate_encoder_decoder_greedy_logprobs_limit(
|
||||
prompts, max_tokens, num_logprobs=num_logprobs, images=images)
|
||||
for prompts, images in hf_inputs
|
||||
]
|
||||
|
||||
check_logprobs_close(
|
||||
outputs_0_lst=hf_outputs,
|
||||
outputs_1_lst=[
|
||||
vllm_to_hf_output(vllm_output) for vllm_output in vllm_outputs
|
||||
],
|
||||
name_0="hf",
|
||||
name_1="vllm",
|
||||
)
|
||||
for hf_outputs, vllm_outputs in zip(hf_outputs_per_case,
|
||||
vllm_outputs_per_case):
|
||||
check_logprobs_close(
|
||||
outputs_0_lst=[hf_to_vllm_output(output) for output in hf_outputs],
|
||||
outputs_1_lst=vllm_outputs,
|
||||
name_0="hf",
|
||||
name_1="vllm",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.core_model
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["float", "bfloat16"])
|
||||
@pytest.mark.parametrize(
|
||||
"size_factors",
|
||||
[
|
||||
# No image
|
||||
[],
|
||||
# Single-scale
|
||||
[1.0],
|
||||
# Single-scale, batched
|
||||
[1.0, 1.0, 1.0],
|
||||
# Multi-scale
|
||||
[0.25, 0.5, 1.0],
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("dtype", ["float"])
|
||||
@pytest.mark.parametrize("max_tokens", [64])
|
||||
@pytest.mark.parametrize("num_logprobs", [5])
|
||||
def test_models(hf_runner, vllm_runner, model, dtype, max_tokens,
|
||||
num_logprobs) -> None:
|
||||
def test_models(hf_runner: Type[HfRunner], vllm_runner: Type[VllmRunner],
|
||||
image_assets: _ImageAssets, model: str,
|
||||
size_factors: list[int], dtype: str, max_tokens: int,
|
||||
num_logprobs: int) -> None:
|
||||
images = [asset.pil_image for asset in image_assets]
|
||||
|
||||
inputs_per_image = [[
|
||||
ExplicitEncoderDecoderPrompt(
|
||||
encoder_prompt=TextPrompt(
|
||||
prompt=prompt,
|
||||
multi_modal_data={"image": rescale_image_size(image, factor)}),
|
||||
decoder_prompt=None,
|
||||
) for factor in size_factors
|
||||
] for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
|
||||
|
||||
run_test(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
PROMPTS,
|
||||
inputs_per_image,
|
||||
model,
|
||||
dtype=dtype,
|
||||
max_tokens=max_tokens,
|
||||
|
||||
@@ -29,8 +29,8 @@ def _test_processing_correctness(
    model_config = ModelConfig(
        model_id,
        task="auto",
        tokenizer=model_id,
        tokenizer_mode="auto",
        tokenizer=model_info.tokenizer or model_id,
        tokenizer_mode=model_info.tokenizer_mode,
        trust_remote_code=model_info.trust_remote_code,
        seed=0,
        dtype="float16",
@@ -151,6 +151,7 @@ def _test_processing_correctness(
    "Salesforce/blip2-opt-2.7b",
    "facebook/chameleon-7b",
    "deepseek-ai/deepseek-vl2-tiny",
    "microsoft/Florence-2-base",
    "adept/fuyu-8b",
    "THUDM/glm-4v-9b",
    "h2oai/h2ovl-mississippi-800m",
@@ -193,11 +193,6 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
    # [Encoder-decoder]
    "BartModel": _HfExamplesInfo("facebook/bart-base"),
    "BartForConditionalGeneration": _HfExamplesInfo("facebook/bart-large-cnn"),
    # Florence-2 uses BartFastTokenizer which can't be loaded from AutoTokenizer
    # Therefore, we borrow the BartTokenizer from the original Bart model
    "Florence2ForConditionalGeneration": _HfExamplesInfo("microsoft/Florence-2-base",  # noqa: E501
                                                         tokenizer="facebook/bart-base",
                                                         trust_remote_code=True),  # noqa: E501
}

_EMBEDDING_EXAMPLE_MODELS = {
@@ -288,6 +283,11 @@ _MULTIMODAL_EXAMPLE_MODELS = {
                                       extras={"v0.5": "fixie-ai/ultravox-v0_5-llama-3_2-1b"},  # noqa: E501
                                       trust_remote_code=True),
    # [Encoder-decoder]
    # Florence-2 uses BartFastTokenizer which can't be loaded from AutoTokenizer
    # Therefore, we borrow the BartTokenizer from the original Bart model
    "Florence2ForConditionalGeneration": _HfExamplesInfo("microsoft/Florence-2-base",  # noqa: E501
                                                         tokenizer="facebook/bart-base",
                                                         trust_remote_code=True),  # noqa: E501
    "MllamaForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-3.2-11B-Vision-Instruct"),  # noqa: E501
    "WhisperForConditionalGeneration": _HfExamplesInfo("openai/whisper-large-v3"),  # noqa: E501
}
@@ -588,8 +588,12 @@ class BartEncoder(nn.Module):

        self.layernorm_embedding = nn.LayerNorm(embed_dim)

    def forward(self, input_ids: torch.Tensor,
                positions: torch.Tensor) -> torch.Tensor:
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        r"""
        Args:
            input_ids
@@ -602,7 +606,8 @@ class BartEncoder(nn.Module):
            Decoder output torch.Tensor
        """
        # retrieve input_ids and inputs_embeds
        inputs_embeds = self.embed_tokens(input_ids)
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        embed_pos = self.embed_positions(positions)
        embed_pos = embed_pos.to(inputs_embeds.device)
@@ -661,9 +666,13 @@ class BartDecoder(nn.Module):

        self.layernorm_embedding = nn.LayerNorm(config.d_model)

    def forward(self, decoder_input_ids: torch.Tensor,
                decoder_positions: torch.Tensor,
                encoder_hidden_states: Optional[torch.Tensor]) -> torch.Tensor:
    def forward(
        self,
        decoder_input_ids: torch.Tensor,
        decoder_positions: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        r"""
        Args:
            decoder_input_ids
@@ -677,8 +686,10 @@ class BartDecoder(nn.Module):
        Returns:
            Decoder output torch.Tensor
        """

        inputs_embeds = self.embed_tokens(decoder_input_ids)
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(decoder_input_ids)
        else:
            decoder_positions = inputs_embeds[:, -1]

        # embed positions
        embed_pos = self.embed_positions(decoder_positions)
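
The new optional inputs_embeds argument lets the Florence-2 wrapper hand pre-computed text-plus-vision embeddings straight to the BART backbone instead of token IDs. A minimal sketch of that pattern with hypothetical names and shapes, assuming image features are spliced in at placeholder-token positions, in the spirit of merge_multimodal_embeddings:

import torch
import torch.nn as nn


def merge_embeddings(input_ids: torch.Tensor,
                     text_embeds: torch.Tensor,
                     image_embeds: torch.Tensor,
                     placeholder_id: int) -> torch.Tensor:
    # Replace the embeddings at placeholder positions with image features.
    merged = text_embeds.clone()
    mask = input_ids == placeholder_id
    merged[mask] = image_embeds.reshape(-1, image_embeds.shape[-1]).to(merged.dtype)
    return merged


embed = nn.Embedding(1000, 16)
input_ids = torch.tensor([3, 3, 3, 0, 7])        # 3 = assumed placeholder token id
text_embeds = embed(input_ids).detach()
image_embeds = torch.randn(1, 3, 16)             # three image tokens
inputs_embeds = merge_embeddings(input_ids, text_embeds, image_embeds, placeholder_id=3)
# The result is what would be passed as `inputs_embeds` to the encoder forward,
# e.g. encoder(input_ids=input_ids, positions=..., inputs_embeds=inputs_embeds).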
@@ -1,10 +1,15 @@
# SPDX-License-Identifier: Apache-2.0

import math
from typing import Iterable, Optional, Set, Tuple
from functools import cached_property
from typing import (Iterable, List, Literal, Mapping, Optional, OrderedDict,
                    Set, Tuple, TypedDict, Union)

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from transformers import BatchFeature, PretrainedConfig

from vllm.config import VllmConfig
from vllm.model_executor.layers.logits_processor import LogitsProcessor
@@ -14,11 +19,567 @@ from vllm.model_executor.models.bart import (BartDecoder, BartEncoder,
                                             BartParallelLMHead,
                                             BartScaledWordEmbedding)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, NestedTensors
from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
from vllm.multimodal.parse import MultiModalDataDict, MultiModalDataItems
from vllm.multimodal.processing import (BaseProcessingInfo,
                                        EncDecMultiModalProcessor,
                                        PromptReplacement,
                                        PromptReplacementDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors

from .utils import AutoWeightsLoader
from .interfaces import SupportsMultiModal
from .utils import AutoWeightsLoader, flatten_bn, merge_multimodal_embeddings


class Florence2ImagePixelInputs(TypedDict):
    type: Literal["pixel_values"]
    data: torch.Tensor
    """Shape: (batch_size, num_channel, height, width)"""


# The ViT implementation is copied from
# https://huggingface.co/microsoft/Florence-2-base/blob/main/modeling_florence2.py
class LearnedAbsolutePositionEmbedding2D(nn.Module):
|
||||
"""
|
||||
This module learns positional embeddings up to a fixed maximum size.
|
||||
"""
|
||||
|
||||
def __init__(self, embedding_dim=256, num_pos=50):
|
||||
super().__init__()
|
||||
self.row_embeddings = nn.Embedding(num_pos, embedding_dim // 2)
|
||||
self.column_embeddings = nn.Embedding(
|
||||
num_pos, embedding_dim - (embedding_dim // 2))
|
||||
|
||||
def forward(self, pixel_values):
|
||||
"""
|
||||
pixel_values: (batch_size, height, width, num_channels)
|
||||
returns: (batch_size, height, width, embedding_dim * 2)
|
||||
"""
|
||||
if len(pixel_values.shape) != 4:
|
||||
raise ValueError('pixel_values must be a 4D tensor')
|
||||
height, width = pixel_values.shape[1:3]
|
||||
width_values = torch.arange(width, device=pixel_values.device)
|
||||
height_values = torch.arange(height, device=pixel_values.device)
|
||||
x_emb = self.column_embeddings(width_values)
|
||||
y_emb = self.row_embeddings(height_values)
|
||||
# (height, width, embedding_dim * 2)
|
||||
pos = torch.cat([
|
||||
x_emb.unsqueeze(0).repeat(height, 1, 1),
|
||||
y_emb.unsqueeze(1).repeat(1, width, 1)
|
||||
],
|
||||
dim=-1)
|
||||
# (embedding_dim * 2, height, width)
|
||||
pos = pos.permute(2, 0, 1)
|
||||
pos = pos.unsqueeze(0)
|
||||
# (batch_size, embedding_dim * 2, height, width)
|
||||
pos = pos.repeat(pixel_values.shape[0], 1, 1, 1)
|
||||
# (batch_size, height, width, embedding_dim * 2)
|
||||
pos = pos.permute(0, 2, 3, 1)
|
||||
return pos
|
||||
|
||||
|
||||
class PositionalEmbeddingCosine1D(nn.Module):
|
||||
"""
|
||||
This class implements a very simple positional encoding. It follows closely
|
||||
the encoder from the link below:
|
||||
https://pytorch.org/tutorials/beginner/translation_transformer.html
|
||||
Args:
|
||||
embed_dim: The dimension of the embeddings.
|
||||
dropout_prob: The dropout probability.
|
||||
max_seq_len: The maximum length to precompute the positional encodings.
|
||||
"""
|
||||
|
||||
def __init__(self, embed_dim: int = 512, max_seq_len: int = 1024) -> None:
|
||||
super().__init__()
|
||||
self.embed_dim = embed_dim
|
||||
self.max_seq_len = max_seq_len
|
||||
# Generate the sinusoidal arrays.
|
||||
factor = math.log(10000)
|
||||
denominator = torch.exp(-factor * torch.arange(0, self.embed_dim, 2) /
|
||||
self.embed_dim)
|
||||
# Matrix where rows correspond to a positional embedding as a function
|
||||
# of the position index (i.e., the row index).
|
||||
frequencies = \
|
||||
torch.arange(0, self.max_seq_len) \
|
||||
.reshape(self.max_seq_len, 1) * denominator
|
||||
pos_idx_to_embed = torch.zeros((self.max_seq_len, self.embed_dim))
|
||||
# Populate uneven entries.
|
||||
pos_idx_to_embed[:, 0::2] = torch.sin(frequencies)
|
||||
pos_idx_to_embed[:, 1::2] = torch.cos(frequencies)
|
||||
# Save the positional embeddings in a constant buffer.
|
||||
# self.register_buffer("pos_idx_to_embed", pos_idx_to_embed)
|
||||
self.pos_idx_to_embed = nn.Parameter(pos_idx_to_embed,
|
||||
requires_grad=False)
|
||||
|
||||
def forward(self, seq_embeds: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
seq_embeds: The sequence embeddings in order. Allowed size:
|
||||
1. [T, D], where T is the length of the sequence, and D is the
|
||||
frame embedding dimension.
|
||||
2. [B, T, D], where B is the batch size and T and D are the
|
||||
same as above.
|
||||
Returns a tensor of with the same dimensions as the input: i.e.,
|
||||
[1, T, D] or [T, D].
|
||||
"""
|
||||
shape_len = len(seq_embeds.shape)
|
||||
assert 2 <= shape_len <= 3
|
||||
len_seq = seq_embeds.size(-2)
|
||||
assert len_seq <= self.max_seq_len
|
||||
pos_embeds = self.pos_idx_to_embed[0:seq_embeds.size(-2), :]
|
||||
# Adapt pre-computed positional embeddings to the input.
|
||||
if shape_len == 3:
|
||||
pos_embeds = pos_embeds.view(
|
||||
(1, pos_embeds.size(0), pos_embeds.size(1)))
|
||||
return pos_embeds
|
||||
|
||||
|
||||
class MySequential(nn.Sequential):
|
||||
|
||||
def forward(self, *inputs):
|
||||
for module in self._modules.values():
|
||||
if isinstance(inputs, tuple):
|
||||
inputs = module(*inputs)
|
||||
else:
|
||||
inputs = module(inputs)
|
||||
return inputs
|
||||
|
||||
|
||||
class PreNorm(nn.Module):
|
||||
|
||||
def __init__(self, norm, fn):
|
||||
super().__init__()
|
||||
self.norm = norm
|
||||
self.fn = fn
|
||||
|
||||
def forward(self, x, *args, **kwargs):
|
||||
shortcut = x
|
||||
if self.norm is not None:
|
||||
x, size = self.fn(self.norm(x), *args, **kwargs)
|
||||
else:
|
||||
x, size = self.fn(x, *args, **kwargs)
|
||||
|
||||
x = shortcut + x
|
||||
|
||||
return x, size
|
||||
|
||||
|
||||
class Mlp(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_features,
|
||||
hidden_features=None,
|
||||
out_features=None,
|
||||
act_layer=nn.GELU,
|
||||
):
|
||||
super().__init__()
|
||||
out_features = out_features or in_features
|
||||
hidden_features = hidden_features or in_features
|
||||
self.net = nn.Sequential(
|
||||
OrderedDict([("fc1", nn.Linear(in_features, hidden_features)),
|
||||
("act", act_layer()),
|
||||
("fc2", nn.Linear(hidden_features, out_features))]))
|
||||
|
||||
def forward(self, x, size):
|
||||
return self.net(x), size
|
||||
|
||||
|
||||
class DepthWiseConv2d(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim_in,
|
||||
kernel_size,
|
||||
padding,
|
||||
stride,
|
||||
bias=True,
|
||||
):
|
||||
super().__init__()
|
||||
self.dw = nn.Conv2d(dim_in,
|
||||
dim_in,
|
||||
kernel_size=kernel_size,
|
||||
padding=padding,
|
||||
groups=dim_in,
|
||||
stride=stride,
|
||||
bias=bias)
|
||||
|
||||
def forward(self, x, size):
|
||||
B, N, C = x.shape
|
||||
H, W = size
|
||||
assert N == H * W
|
||||
|
||||
x = self.dw(x.transpose(1, 2).view(B, C, H, W))
|
||||
size = (x.size(-2), x.size(-1))
|
||||
x = x.flatten(2).transpose(1, 2)
|
||||
return x, size
|
||||
|
||||
|
||||
class ConvEmbed(nn.Module):
|
||||
""" Image to Patch Embedding
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
patch_size=7,
|
||||
in_chans=3,
|
||||
embed_dim=64,
|
||||
stride=4,
|
||||
padding=2,
|
||||
norm_layer=None,
|
||||
pre_norm=True):
|
||||
super().__init__()
|
||||
self.patch_size = patch_size
|
||||
|
||||
self.proj = nn.Conv2d(in_chans,
|
||||
embed_dim,
|
||||
kernel_size=patch_size,
|
||||
stride=stride,
|
||||
padding=padding)
|
||||
|
||||
dim_norm = in_chans if pre_norm else embed_dim
|
||||
self.norm = norm_layer(dim_norm) if norm_layer else None
|
||||
|
||||
self.pre_norm = pre_norm
|
||||
|
||||
def forward(self, x, size):
|
||||
H, W = size
|
||||
if len(x.size()) == 3:
|
||||
if self.norm and self.pre_norm:
|
||||
x = self.norm(x)
|
||||
x = rearrange(x, 'b (h w) c -> b c h w', h=H, w=W)
|
||||
|
||||
x = self.proj(x)
|
||||
|
||||
_, _, H, W = x.shape
|
||||
x = rearrange(x, 'b c h w -> b (h w) c')
|
||||
if self.norm and not self.pre_norm:
|
||||
x = self.norm(x)
|
||||
|
||||
return x, (H, W)
|
||||
|
||||
|
||||
class ChannelAttention(nn.Module):
|
||||
|
||||
def __init__(self, dim, groups=8, qkv_bias=True):
|
||||
super().__init__()
|
||||
|
||||
self.groups = groups
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
|
||||
def forward(self, x, size):
|
||||
B, N, C = x.shape
|
||||
|
||||
qkv = self.qkv(x).reshape(B, N, 3, self.groups,
|
||||
C // self.groups).permute(2, 0, 3, 1, 4)
|
||||
q, k, v = qkv[0], qkv[1], qkv[2]
|
||||
|
||||
q = q * (float(N)**-0.5)
|
||||
attention = q.transpose(-1, -2) @ k
|
||||
attention = attention.softmax(dim=-1)
|
||||
x = (attention @ v.transpose(-1, -2)).transpose(-1, -2)
|
||||
x = x.transpose(1, 2).reshape(B, N, C)
|
||||
x = self.proj(x)
|
||||
return x, size
|
||||
|
||||
|
||||
class ChannelBlock(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
dim,
|
||||
groups,
|
||||
mlp_ratio=4.,
|
||||
qkv_bias=True,
|
||||
drop_path_rate=0.,
|
||||
act_layer=nn.GELU,
|
||||
norm_layer=nn.LayerNorm,
|
||||
conv_at_attn=True,
|
||||
conv_at_ffn=True):
|
||||
super().__init__()
|
||||
|
||||
self.conv1 = PreNorm(None, DepthWiseConv2d(
|
||||
dim, 3, 1, 1)) if conv_at_attn else None
|
||||
self.channel_attn = PreNorm(
|
||||
norm_layer(dim),
|
||||
ChannelAttention(dim, groups=groups, qkv_bias=qkv_bias),
|
||||
)
|
||||
self.conv2 = PreNorm(None, DepthWiseConv2d(dim, 3, 1,
|
||||
1)) if conv_at_ffn else None
|
||||
self.ffn = PreNorm(
|
||||
norm_layer(dim),
|
||||
Mlp(in_features=dim,
|
||||
hidden_features=int(dim * mlp_ratio),
|
||||
act_layer=act_layer),
|
||||
)
|
||||
|
||||
def forward(self, x, size):
|
||||
if self.conv1:
|
||||
x, size = self.conv1(x, size)
|
||||
x, size = self.channel_attn(x, size)
|
||||
|
||||
if self.conv2:
|
||||
x, size = self.conv2(x, size)
|
||||
x, size = self.ffn(x, size)
|
||||
|
||||
return x, size
|
||||
|
||||
|
||||
def window_partition(x, window_size: int):
|
||||
B, H, W, C = x.shape
|
||||
x = x.view(B, H // window_size, window_size, W // window_size, window_size,
|
||||
C)
|
||||
windows = x.permute(0, 1, 3, 2, 4,
|
||||
5).contiguous().view(-1, window_size, window_size, C)
|
||||
return windows
|
||||
|
||||
|
||||
def window_reverse(windows, batch_size: int, window_size: int, H: int, W: int):
|
||||
B = batch_size
|
||||
|
||||
x = windows.view(B, H // window_size, W // window_size, window_size,
|
||||
window_size, -1)
|
||||
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
|
||||
return x
|
||||
|
||||
|
||||
class WindowAttention(nn.Module):
|
||||
|
||||
def __init__(self, dim, num_heads, window_size, qkv_bias=True):
|
||||
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.window_size = window_size
|
||||
self.num_heads = num_heads
|
||||
head_dim = dim // num_heads
|
||||
self.scale = float(head_dim)**-0.5
|
||||
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
|
||||
self.softmax = nn.Softmax(dim=-1)
|
||||
|
||||
def forward(self, x, size):
|
||||
|
||||
H, W = size
|
||||
B, L, C = x.shape
|
||||
assert L == H * W, "input feature has wrong size"
|
||||
|
||||
x = x.view(B, H, W, C)
|
||||
|
||||
pad_l = pad_t = 0
|
||||
pad_r = (self.window_size - W % self.window_size) % self.window_size
|
||||
pad_b = (self.window_size - H % self.window_size) % self.window_size
|
||||
x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
|
||||
_, Hp, Wp, _ = x.shape
|
||||
|
||||
x = window_partition(x, self.window_size)
|
||||
x = x.view(-1, self.window_size * self.window_size, C)
|
||||
|
||||
# W-MSA/SW-MSA
|
||||
# attn_windows = self.attn(x_windows)
|
||||
|
||||
B_, N, C = x.shape
|
||||
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads,
|
||||
C // self.num_heads).permute(2, 0, 3, 1, 4)
|
||||
q, k, v = qkv[0], qkv[1], qkv[2]
|
||||
|
||||
q = q * self.scale
|
||||
attn = (q @ k.transpose(-2, -1))
|
||||
attn = self.softmax(attn)
|
||||
|
||||
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
|
||||
x = self.proj(x)
|
||||
|
||||
# merge windows
|
||||
x = x.view(-1, self.window_size, self.window_size, C)
|
||||
x = window_reverse(x, B, self.window_size, Hp, Wp)
|
||||
|
||||
if pad_r > 0 or pad_b > 0:
|
||||
x = x[:, :H, :W, :].contiguous()
|
||||
|
||||
x = x.view(B, H * W, C)
|
||||
|
||||
return x, size
|
||||
|
||||
|
||||
class SpatialBlock(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
dim,
|
||||
num_heads,
|
||||
window_size,
|
||||
mlp_ratio=4.,
|
||||
qkv_bias=True,
|
||||
drop_path_rate=0.,
|
||||
act_layer=nn.GELU,
|
||||
norm_layer=nn.LayerNorm,
|
||||
conv_at_attn=True,
|
||||
conv_at_ffn=True):
|
||||
super().__init__()
|
||||
|
||||
self.conv1 = PreNorm(None, DepthWiseConv2d(
|
||||
dim, 3, 1, 1)) if conv_at_attn else None
|
||||
self.window_attn = PreNorm(
|
||||
norm_layer(dim),
|
||||
WindowAttention(dim, num_heads, window_size, qkv_bias=qkv_bias),
|
||||
)
|
||||
self.conv2 = PreNorm(None, DepthWiseConv2d(dim, 3, 1,
|
||||
1)) if conv_at_ffn else None
|
||||
self.ffn = PreNorm(
|
||||
norm_layer(dim),
|
||||
Mlp(in_features=dim,
|
||||
hidden_features=int(dim * mlp_ratio),
|
||||
act_layer=act_layer),
|
||||
)
|
||||
|
||||
def forward(self, x, size):
|
||||
if self.conv1:
|
||||
x, size = self.conv1(x, size)
|
||||
x, size = self.window_attn(x, size)
|
||||
|
||||
if self.conv2:
|
||||
x, size = self.conv2(x, size)
|
||||
x, size = self.ffn(x, size)
|
||||
return x, size
|
||||
|
||||
|
||||
class DaViT(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_chans=3,
|
||||
num_classes=1000,
|
||||
depths=(1, 1, 3, 1),
|
||||
patch_size=(7, 2, 2, 2),
|
||||
patch_stride=(4, 2, 2, 2),
|
||||
patch_padding=(3, 0, 0, 0),
|
||||
patch_prenorm=(False, False, False, False),
|
||||
embed_dims=(64, 128, 192, 256),
|
||||
num_heads=(3, 6, 12, 24),
|
||||
num_groups=(3, 6, 12, 24),
|
||||
window_size=7,
|
||||
mlp_ratio=4.,
|
||||
qkv_bias=True,
|
||||
drop_path_rate=0.1,
|
||||
norm_layer=nn.LayerNorm,
|
||||
enable_checkpoint=False,
|
||||
conv_at_attn=True,
|
||||
conv_at_ffn=True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.num_classes = num_classes
|
||||
self.embed_dims = embed_dims
|
||||
self.num_heads = num_heads
|
||||
self.num_groups = num_groups
|
||||
self.num_stages = len(self.embed_dims)
|
||||
self.enable_checkpoint = enable_checkpoint
|
||||
assert self.num_stages == len(self.num_heads) == len(self.num_groups)
|
||||
|
||||
num_stages = len(embed_dims)
|
||||
dpr = [
|
||||
x.item() for x in torch.linspace(0, drop_path_rate,
|
||||
sum(depths) * 2)
|
||||
]
|
||||
|
||||
depth_offset = 0
|
||||
convs = []
|
||||
blocks = []
|
||||
for i in range(num_stages):
|
||||
conv_embed = ConvEmbed(
|
||||
patch_size=patch_size[i],
|
||||
stride=patch_stride[i],
|
||||
padding=patch_padding[i],
|
||||
in_chans=in_chans if i == 0 else self.embed_dims[i - 1],
|
||||
embed_dim=self.embed_dims[i],
|
||||
norm_layer=norm_layer,
|
||||
pre_norm=patch_prenorm[i])
|
||||
convs.append(conv_embed)
|
||||
|
||||
block = MySequential(*[
|
||||
MySequential(
|
||||
OrderedDict([('spatial_block',
|
||||
SpatialBlock(
|
||||
embed_dims[i],
|
||||
num_heads[i],
|
||||
window_size,
|
||||
drop_path_rate=dpr[depth_offset + j * 2],
|
||||
qkv_bias=qkv_bias,
|
||||
mlp_ratio=mlp_ratio,
|
||||
conv_at_attn=conv_at_attn,
|
||||
conv_at_ffn=conv_at_ffn,
|
||||
)),
|
||||
('channel_block',
|
||||
ChannelBlock(
|
||||
embed_dims[i],
|
||||
num_groups[i],
|
||||
drop_path_rate=dpr[depth_offset + j * 2 +
|
||||
1],
|
||||
qkv_bias=qkv_bias,
|
||||
mlp_ratio=mlp_ratio,
|
||||
conv_at_attn=conv_at_attn,
|
||||
conv_at_ffn=conv_at_ffn,
|
||||
))])) for j in range(depths[i])
|
||||
])
|
||||
blocks.append(block)
|
||||
depth_offset += depths[i] * 2
|
||||
|
||||
self.convs = nn.ModuleList(convs)
|
||||
self.blocks = nn.ModuleList(blocks)
|
||||
|
||||
self.avgpool = nn.AdaptiveAvgPool1d(1)
|
||||
|
||||
@property
|
||||
def dim_out(self):
|
||||
return self.embed_dims[-1]
|
||||
|
||||
def forward_features_unpool(self, x):
|
||||
"""
|
||||
forward until avg pooling
|
||||
Args:
|
||||
x (_type_): input image tensor
|
||||
"""
|
||||
input_size = (x.size(2), x.size(3))
|
||||
for conv, block in zip(self.convs, self.blocks):
|
||||
x, input_size = conv(x, input_size)
|
||||
x, input_size = block(x, input_size)
|
||||
return x
|
||||
|
||||
def forward_features(self, x):
|
||||
x = self.forward_features_unpool(x)
|
||||
|
||||
# (batch_size, num_tokens, token_dim)
|
||||
x = self.avgpool(x.transpose(1, 2))
|
||||
# (batch_size, 1, num_tokens)
|
||||
x = torch.flatten(x, 1)
|
||||
x = self.norms(x)
|
||||
|
||||
return x
|
||||
|
||||
def forward(self, x):
|
||||
x = self.forward_features(x)
|
||||
x = self.head(x)
|
||||
return x
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config):
|
||||
return cls(
|
||||
depths=config.depths,
|
||||
embed_dims=config.dim_embed,
|
||||
num_heads=config.num_heads,
|
||||
num_groups=config.num_groups,
|
||||
patch_size=config.patch_size,
|
||||
patch_stride=config.patch_stride,
|
||||
patch_padding=config.patch_padding,
|
||||
patch_prenorm=config.patch_prenorm,
|
||||
drop_path_rate=config.drop_path_rate,
|
||||
window_size=config.window_size,
|
||||
)
|
||||
|
||||
|
||||
# Language backbone and processor implementation
|
||||
class Florence2LanguageModel(nn.Module):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
@ -47,9 +608,14 @@ class Florence2LanguageModel(nn.Module):
|
||||
self.encoder.embed_tokens.weight = self.shared.weight
|
||||
self.decoder.embed_tokens.weight = self.shared.weight
|
||||
|
||||
def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
|
||||
encoder_input_ids: torch.Tensor,
|
||||
encoder_positions: torch.Tensor) -> torch.Tensor:
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
encoder_input_ids: torch.Tensor,
|
||||
encoder_positions: torch.Tensor,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
r"""
|
||||
Args:
|
||||
input_ids
|
||||
@ -68,11 +634,12 @@ class Florence2LanguageModel(nn.Module):
|
||||
|
||||
encoder_hidden_states = None
|
||||
|
||||
if encoder_input_ids.numel() > 0:
|
||||
if inputs_embeds is not None or encoder_input_ids.numel() > 0:
|
||||
# Run encoder attention if a non-zero number of encoder tokens
|
||||
# are provided as input
|
||||
encoder_hidden_states = self.encoder(input_ids=encoder_input_ids,
|
||||
positions=encoder_positions)
|
||||
positions=encoder_positions,
|
||||
inputs_embeds=inputs_embeds)
|
||||
|
||||
# decoder outputs consists of
|
||||
# (dec_features, past_key_value, dec_hidden, dec_attn)
|
||||
@ -112,6 +679,7 @@ class Florence2LanguageForConditionalGeneration(nn.Module):
|
||||
positions: torch.Tensor,
|
||||
encoder_input_ids: torch.Tensor,
|
||||
encoder_positions: torch.Tensor,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
r"""
|
||||
@ -127,8 +695,15 @@ class Florence2LanguageForConditionalGeneration(nn.Module):
|
||||
Returns:
|
||||
Output torch.Tensor
|
||||
"""
|
||||
return self.model(input_ids, positions, encoder_input_ids,
|
||||
encoder_positions)
|
||||
|
||||
return self.model(input_ids,
|
||||
positions,
|
||||
encoder_input_ids,
|
||||
encoder_positions,
|
||||
inputs_embeds=inputs_embeds)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.encoder.embed_tokens(input_ids)
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
@ -177,21 +752,312 @@ class Florence2LanguageForConditionalGeneration(nn.Module):
|
||||
return loaded_params
|
||||
|
||||
|
||||
class Florence2ForConditionalGeneration(nn.Module):
|
||||
class Florence2ProcessingInfo(BaseProcessingInfo):
|
||||
|
||||
def get_hf_config(self):
|
||||
return self.ctx.get_hf_config()
|
||||
|
||||
def get_hf_processor(self):
|
||||
return self.ctx.get_hf_processor()
|
||||
|
||||
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
|
||||
return {"image": 1}
|
||||
|
||||
def get_max_image_tokens(self) -> int:
|
||||
processor_config = self.ctx.get_hf_image_processor_config()
|
||||
return processor_config["image_seq_length"]
|
||||
|
||||
def get_mm_max_tokens_per_item(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
) -> Mapping[str, int]:
|
||||
return {"image": self.get_max_image_tokens()}
|
||||
|
||||
|
||||
class Florence2DummyInputsBuilder(
|
||||
BaseDummyInputsBuilder[Florence2ProcessingInfo]):
|
||||
|
||||
def get_dummy_processor_inputs(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
) -> ProcessorInputs:
|
||||
num_images = mm_counts.get("image", 0)
|
||||
|
||||
target_width = target_height = self.info.get_hf_config().projection_dim
|
||||
|
||||
mm_data = {
|
||||
"image":
|
||||
self._get_dummy_images(width=target_width,
|
||||
height=target_height,
|
||||
num_images=num_images)
|
||||
}
|
||||
|
||||
return ProcessorInputs(
|
||||
prompt_text="",
|
||||
mm_data=mm_data,
|
||||
)
|
||||
|
||||
|
||||
class Florence2MultiModalProcessor(
|
||||
EncDecMultiModalProcessor[Florence2ProcessingInfo]):
|
||||
|
||||
def _hf_processor_applies_repl(
|
||||
self,
|
||||
prompt_text: str,
|
||||
mm_items: MultiModalDataItems,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
) -> bool:
|
||||
return False
|
||||
|
||||
def create_encoder_prompt(
|
||||
self,
|
||||
prompt: Union[str, list[int]],
|
||||
mm_data: MultiModalDataDict,
|
||||
) -> Union[str, list[int]]:
|
||||
return prompt
|
||||
|
||||
def create_decoder_prompt(
|
||||
self,
|
||||
prompt: Union[str, list[int]],
|
||||
mm_data: MultiModalDataDict,
|
||||
) -> Union[str, list[int]]:
|
||||
return [self.info.get_hf_config().eos_token_id]
|
||||
|
||||
def _call_hf_processor(
|
||||
self,
|
||||
prompt: str,
|
||||
mm_data: Mapping[str, object],
|
||||
mm_kwargs: Mapping[str, object],
|
||||
) -> BatchFeature:
|
||||
if mm_data:
|
||||
processed_outputs = super()._call_hf_processor(
|
||||
prompt, mm_data, mm_kwargs)
|
||||
else:
|
||||
hf_processor = self.info.get_hf_processor()
|
||||
tokenizer = hf_processor.tokenizer
|
||||
prompt = hf_processor._construct_prompts([prompt])[0]
|
||||
processed_outputs = tokenizer(prompt,
|
||||
add_special_tokens=True,
|
||||
return_tensors="pt")
|
||||
return processed_outputs
|
||||
|
||||
def _get_mm_fields_config(
|
||||
self,
|
||||
hf_inputs: BatchFeature,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
) -> Mapping[str, MultiModalFieldConfig]:
|
||||
return dict(pixel_values=MultiModalFieldConfig.batched("image"))
|
||||
|
||||
def _get_prompt_replacements(
|
||||
self,
|
||||
mm_items: MultiModalDataItems,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
out_mm_kwargs: MultiModalKwargs,
|
||||
) -> list[PromptReplacement]:
|
||||
hf_config = self.info.get_hf_config()
|
||||
pad_token_id = hf_config.pad_token_id
|
||||
bos_token_id = hf_config.bos_token_id
|
||||
num_image_tokens = self.info.get_max_image_tokens()
|
||||
image_tokens = [pad_token_id] * num_image_tokens
|
||||
|
||||
return [
|
||||
PromptReplacement(
|
||||
modality="image",
|
||||
target=[bos_token_id],
|
||||
replacement=PromptReplacementDetails(
|
||||
full=image_tokens + [bos_token_id],
|
||||
features=image_tokens,
|
||||
),
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_processor(
|
||||
Florence2MultiModalProcessor,
|
||||
info=Florence2ProcessingInfo,
|
||||
dummy_inputs=Florence2DummyInputsBuilder)
|
||||
class Florence2ForConditionalGeneration(nn.Module, SupportsMultiModal):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
processor_config = vllm_config.model_config.hf_image_processor_config
|
||||
|
||||
# TODO(Isotr0py): Add vision backbone
|
||||
self.config = config
|
||||
self.vision_config = config.vision_config
|
||||
self.processor_config = processor_config
|
||||
assert config.vision_config.model_type == 'davit', (
|
||||
'only DaViT is supported for now')
|
||||
self.vision_tower = DaViT.from_config(config=config.vision_config)
|
||||
self._build_image_projection_layers(config)
|
||||
self.language_model = Florence2LanguageForConditionalGeneration(
|
||||
vllm_config=vllm_config.with_hf_config(config.text_config),
|
||||
prefix=f"{prefix}.language_model",
|
||||
)
|
||||
self.pad_token_id = config.pad_token_id
|
||||
|
||||
@property
|
||||
def _build_image_projection_layers(self, config: PretrainedConfig):
|
||||
image_dim_out = config.vision_config.dim_embed[-1]
|
||||
dim_projection = config.vision_config.projection_dim
|
||||
self.image_projection = nn.Parameter(
|
||||
torch.empty(image_dim_out, dim_projection))
|
||||
self.image_proj_norm = nn.LayerNorm(dim_projection)
|
||||
image_pos_embed_config = config.vision_config.image_pos_embed
|
||||
if image_pos_embed_config['type'] == 'learned_abs_2d':
|
||||
self.image_pos_embed = LearnedAbsolutePositionEmbedding2D(
|
||||
embedding_dim=image_dim_out,
|
||||
num_pos=image_pos_embed_config['max_pos_embeddings'])
|
||||
else:
|
||||
raise NotImplementedError("Florence2 only supports learned_abs_2d "
|
||||
"as image position embedding.")
|
||||
|
||||
self.image_feature_source = config.vision_config.image_feature_source
|
||||
|
||||
# temporal embedding
|
||||
visual_temporal_embedding_config = (
|
||||
self.vision_config.visual_temporal_embedding)
|
||||
if visual_temporal_embedding_config['type'] == 'COSINE':
|
||||
self.visual_temporal_embed = PositionalEmbeddingCosine1D(
|
||||
embed_dim=image_dim_out,
|
||||
max_seq_len=visual_temporal_embedding_config[
|
||||
'max_temporal_embeddings'])
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
'Florence2 only supports COSINE as temporal embedding.')
|
||||
|
||||
@cached_property
|
||||
def sampler(self):
|
||||
return self.language_model.sampler
|
||||
if hasattr(self.language_model, "sampler"):
|
||||
return self.language_model.sampler
|
||||
return get_sampler()
|
||||
|
||||
def _validate_pixel_values(
|
||||
self, data: Union[torch.Tensor, List[torch.Tensor]]
|
||||
) -> Union[torch.Tensor, List[torch.Tensor]]:
|
||||
|
||||
size = self.processor_config["size"]
|
||||
h, w = size["height"], size["width"]
|
||||
expected_dims = (3, h, w)
|
||||
|
||||
def _validate_shape(d: torch.Tensor):
|
||||
actual_dims = tuple(d.shape)
|
||||
|
||||
if actual_dims != expected_dims:
|
||||
expected_expr = tuple(*map(str, expected_dims))
|
||||
raise ValueError(
|
||||
"The expected shape of pixel values per batch "
|
||||
f"is {expected_expr}. You supplied {tuple(d.shape)}.")
|
||||
|
||||
for d in data:
|
||||
_validate_shape(d)
|
||||
|
||||
return data
|
||||
|
||||
def _parse_and_validate_image_input(self, **kwargs: object):
|
||||
pixel_values: Optional[Union[List[List[torch.Tensor]],
|
||||
List[torch.Tensor],
|
||||
torch.Tensor]] = kwargs.pop(
|
||||
"pixel_values", None)
|
||||
image_embeds: Optional[Union[List[List[torch.Tensor]],
|
||||
List[torch.Tensor],
|
||||
torch.Tensor]] = kwargs.pop(
|
||||
"image_embeds", None)
|
||||
|
||||
if pixel_values is None and image_embeds is None:
|
||||
return None
|
||||
|
||||
if pixel_values is not None and image_embeds is not None:
|
||||
raise ValueError(
|
||||
"Both pixel values and image embeds are provided.")
|
||||
|
||||
if pixel_values is not None:
|
||||
return Florence2ImagePixelInputs(
|
||||
type="pixel_values",
|
||||
data=self._validate_pixel_values(
|
||||
flatten_bn(pixel_values, concat=True)),
|
||||
)
|
||||
|
||||
if image_embeds is not None:
|
||||
raise NotImplementedError
|
||||
|
||||
raise AssertionError("This line should be unreachable.")
|
||||
|
||||
def _encode_image(self, pixel_values: torch.Tensor) -> torch.Tensor:
|
||||
dtype = next(self.vision_tower.parameters()).dtype
|
||||
pixel_values = pixel_values.to(dtype)
|
||||
|
||||
batch_size, T = pixel_values.size(0), 1
|
||||
x = self.vision_tower.forward_features_unpool(pixel_values)
|
||||
if self.image_pos_embed is not None:
|
||||
x = x.view(batch_size * T, -1, x.shape[-1])
|
||||
num_tokens = x.shape[-2]
|
||||
h, w = int(num_tokens**0.5), int(num_tokens**0.5)
|
||||
assert h * w == num_tokens, (
|
||||
'only support square feature maps for now')
|
||||
x = x.view(batch_size * T, h, w, x.shape[-1])
|
||||
pos_embed = self.image_pos_embed(x)
|
||||
x = x + pos_embed
|
||||
x = x.view(batch_size, T * h * w, x.shape[-1])
|
||||
|
||||
if self.visual_temporal_embed is not None:
|
||||
visual_temporal_embed = self.visual_temporal_embed(
|
||||
x.view(batch_size, T, -1, x.shape[-1])[:, :, 0])
|
||||
x = x.view(batch_size, T, -1,
|
||||
x.shape[-1]) + visual_temporal_embed.view(
|
||||
1, T, 1, x.shape[-1])
|
||||
|
||||
x_feat_dict = {}
|
||||
|
||||
spatial_avg_pool_x = x.view(batch_size, T, -1, x.shape[-1]).mean(dim=2)
|
||||
x_feat_dict['spatial_avg_pool'] = spatial_avg_pool_x
|
||||
|
||||
temporal_avg_pool_x = x.view(batch_size, T, -1,
|
||||
x.shape[-1]).mean(dim=1)
|
||||
x_feat_dict['temporal_avg_pool'] = temporal_avg_pool_x
|
||||
|
||||
x = x.view(batch_size, T, -1, x.shape[-1])[:, -1]
|
||||
x_feat_dict['last_frame'] = x
|
||||
|
||||
new_x = []
|
||||
for _image_feature_source in self.image_feature_source:
|
||||
if _image_feature_source not in x_feat_dict:
|
||||
raise ValueError('invalid image feature source: {}'.format(
|
||||
_image_feature_source))
|
||||
new_x.append(x_feat_dict[_image_feature_source])
|
||||
|
||||
x = torch.cat(new_x, dim=1)
|
||||
|
||||
x = x @ self.image_projection
|
||||
x = self.image_proj_norm(x)
|
||||
|
||||
return x
|
||||
|
||||
def _process_image_input(
|
||||
self, image_input: Florence2ImagePixelInputs) -> torch.Tensor:
|
||||
assert image_input["type"] == "pixel_values"
|
||||
pixel_values = image_input["data"]
|
||||
return self._encode_image(pixel_values)
|
||||
|
||||
def get_multimodal_embeddings(self, **kwargs: object) -> torch.Tensor:
|
||||
image_input = self._parse_and_validate_image_input(**kwargs)
|
||||
if image_input is None:
|
||||
return None
|
||||
vision_embeddings = self._process_image_input(image_input)
|
||||
return vision_embeddings
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
multimodal_embeddings: Optional[NestedTensors] = None,
|
||||
) -> torch.Tensor:
|
||||
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
|
||||
if multimodal_embeddings is not None:
|
||||
inputs_embeds = merge_multimodal_embeddings(
|
||||
input_ids, inputs_embeds, multimodal_embeddings,
|
||||
self.pad_token_id)
|
||||
return inputs_embeds
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -216,8 +1082,19 @@ class Florence2ForConditionalGeneration(nn.Module):
|
||||
Returns:
|
||||
Output torch.Tensor
|
||||
"""
|
||||
return self.language_model(input_ids, positions, encoder_input_ids,
|
||||
encoder_positions)
|
||||
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
|
||||
if encoder_input_ids.numel() > 0 or vision_embeddings is not None:
|
||||
inputs_embeds = self.get_input_embeddings(encoder_input_ids,
|
||||
vision_embeddings)
|
||||
else:
|
||||
inputs_embeds = None
|
||||
|
||||
hidden_states = self.language_model(input_ids,
|
||||
positions,
|
||||
encoder_input_ids,
|
||||
encoder_positions,
|
||||
inputs_embeds=inputs_embeds)
|
||||
return hidden_states
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
@ -236,9 +1113,5 @@ class Florence2ForConditionalGeneration(nn.Module):
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
skip_prefixes = [
|
||||
'image_projection', "vision_tower", "image_proj_norm",
|
||||
"image_pos_embed", "visual_temporal_embed"
|
||||
]
|
||||
loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
|
||||
loader = AutoWeightsLoader(self)
|
||||
return loader.load_weights(weights)
|
||||
|
||||
@@ -105,7 +105,6 @@ _TEXT_GENERATION_MODELS = {
    # [Encoder-decoder]
    "BartModel": ("bart", "BartForConditionalGeneration"),
    "BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"),
    "Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"),  # noqa: E501
}

_EMBEDDING_MODELS = {
@@ -182,6 +181,7 @@ _MULTIMODAL_MODELS = {
    "Qwen2AudioForConditionalGeneration": ("qwen2_audio", "Qwen2AudioForConditionalGeneration"),  # noqa: E501
    "UltravoxModel": ("ultravox", "UltravoxModel"),
    # [Encoder-decoder]
    "Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"),  # noqa: E501
    "MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"),  # noqa: E501
    "WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"),  # noqa: E501
}
@@ -1303,6 +1303,14 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
        """
        raise NotImplementedError

    def create_decoder_prompt(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
    ) -> Union[str, list[int]]:
        """Create input prompt for the decoder."""
        return prompt

    def apply(
        self,
        prompt: Union[str, list[int]],
@@ -1323,17 +1331,15 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
            hf_processor_mm_kwargs,
        )

        # We assumed the decoder prompt text is copied from
        # the original encoder prompt without extra process
        tokenizer = self.info.get_tokenizer()
        if isinstance(prompt, str):
            decoder_prompt = prompt
        decoder_prompt = self.create_decoder_prompt(prompt, mm_data)
        if isinstance(decoder_prompt, str):
            decoder_prompt_ids = encode_tokens(tokenizer,
                                               prompt,
                                               decoder_prompt,
                                               add_special_tokens=False)
        else:
            decoder_prompt = decode_tokens(tokenizer, prompt)
            decoder_prompt_ids = prompt
            decoder_prompt_ids = decoder_prompt
            decoder_prompt = decode_tokens(tokenizer, decoder_prompt)

        mm_inputs = MultiModalEncDecInputs(
            encoder_prompt=encoder_inputs["prompt"],
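
The new create_decoder_prompt hook lets an encoder/decoder processor decide what the decoder is primed with instead of always copying the encoder prompt. A minimal sketch of a subclass override in the spirit of the Florence-2 processor in this diff; the class name and token id are illustrative only:

from typing import Union


class MyEncDecProcessor:  # illustrative stand-in for an EncDecMultiModalProcessor subclass

    EOS_TOKEN_ID = 2  # assumed BART-style EOS id

    def create_decoder_prompt(
        self,
        prompt: Union[str, list[int]],
        mm_data: dict,
    ) -> Union[str, list[int]]:
        # Ignore the encoder prompt; the decoder starts from a single EOS token,
        # as Florence2MultiModalProcessor does above.
        return [self.EOS_TOKEN_ID]


processor = MyEncDecProcessor()
print(processor.create_decoder_prompt("<CAPTION>", {"image": None}))  # -> [2]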
@@ -204,9 +204,11 @@ class MultiModalProfiler(Generic[_I]):
                "and/or reduce `mm_counts`.", seq_len, total_len,
                total_placeholders_by_modality)

        num_tokens_to_pad = max(total_len, seq_len) - total_len
        prompt_token_ids.extend([0] * num_tokens_to_pad)

        return DummyData(
            seq_data=SequenceData.from_prompt_token_counts(
                (0, max(seq_len, total_len))),
            seq_data=SequenceData.from_seqs(prompt_token_ids),
            multi_modal_data=None,
            multi_modal_placeholders=None,
        )
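
The profiler now keeps the real dummy prompt tokens and only pads them with zeros up to the requested sequence length, rather than synthesizing a prompt from token counts. A tiny worked example of the padding arithmetic, with illustrative numbers:

# Illustrative numbers only: a dummy multimodal prompt of 600 tokens profiled
# against a target sequence length of 1024.
seq_len = 1024
prompt_token_ids = list(range(600))   # stand-in for the processor's dummy prompt
total_len = len(prompt_token_ids)

num_tokens_to_pad = max(total_len, seq_len) - total_len   # 1024 - 600 = 424
prompt_token_ids.extend([0] * num_tokens_to_pad)

assert len(prompt_token_ids) == 1024
# If total_len already exceeded seq_len, num_tokens_to_pad would be 0 and the
# prompt would be left as-is.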