[VLM] Support multimodal inputs for Florence-2 models (#13320)

Isotr0py 2025-02-27 18:06:41 +08:00 committed by GitHub
parent 788f284b53
commit edf309ebbe
13 changed files with 1075 additions and 114 deletions

View File

@@ -715,6 +715,13 @@ See [this page](#generative-models) for more information on how to use generativ
*
* ✅︎
* ✅︎
- * `Florence2ForConditionalGeneration`
* Florence-2
* T + I
* `microsoft/Florence-2-base`, `microsoft/Florence-2-large` etc.
*
*
*
- * `FuyuForCausalLM`
* Fuyu
* T + I

View File

@@ -1,34 +1,45 @@
# SPDX-License-Identifier: Apache-2.0
'''
"""
Demonstrate prompting of text-to-text
encoder/decoder models, specifically Florence-2
'''
"""
# TODO(Isotr0py):
# Move to offline_inference/vision_language.py
# after porting vision backbone
from vllm import LLM, SamplingParams
dtype = "float"
from vllm.assets.image import ImageAsset
# Create a Florence-2 encoder/decoder model instance
llm = LLM(
model="microsoft/Florence-2-base",
tokenizer="facebook/bart-base",
dtype=dtype,
model="microsoft/Florence-2-large",
tokenizer="facebook/bart-large",
max_num_seqs=8,
trust_remote_code=True,
)
prompts = [
"<CAPTION>", "<DETAILED_CAPTION>", "<MORE_DETAILED_CAPTION>",
"<CAPTION_TO_PHRASE_GROUNDING>", "<OD>", "<DENSE_REGION_CAPTION>",
"<REGION_PROPOSAL>", "<OCR>", "<OCR_WITH_REGION>"
{ # implicit prompt with task token
"prompt": "<DETAILED_CAPTION>",
"multi_modal_data": {
"image": ImageAsset("stop_sign").pil_image
},
},
{ # explicit encoder/decoder prompt
"encoder_prompt": {
"prompt": "Describe in detail what is shown in the image.",
"multi_modal_data": {
"image": ImageAsset("cherry_blossom").pil_image
},
},
"decoder_prompt": "",
},
]
# Create a sampling params object.
sampling_params = SamplingParams(
temperature=0,
top_p=1.0,
min_tokens=0,
max_tokens=20,
max_tokens=128,
)
# Generate output tokens from the prompts. The output is a list of
@@ -38,9 +49,5 @@ outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
prompt = output.prompt
encoder_prompt = output.encoder_prompt
generated_text = output.outputs[0].text
print(f"Encoder prompt: {encoder_prompt!r}, "
f"Decoder prompt: {prompt!r}, "
f"Generated text: {generated_text!r}")
print(f"Generated text: {generated_text!r}")

View File

@@ -82,6 +82,22 @@ def run_deepseek_vl2(question: str, modality: str):
return llm, prompt, stop_token_ids
# Florence2
def run_florence2(question: str, modality: str):
assert modality == "image"
llm = LLM(model="microsoft/Florence-2-large",
tokenizer="facebook/bart-large",
max_num_seqs=8,
trust_remote_code=True,
dtype="bfloat16",
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
prompt = "<MORE_DETAILED_CAPTION>"
stop_token_ids = None
return llm, prompt, stop_token_ids
# Fuyu
def run_fuyu(question: str, modality: str):
assert modality == "image"
@@ -571,6 +587,7 @@ model_example_map = {
"blip-2": run_blip2,
"chameleon": run_chameleon,
"deepseek_vl_v2": run_deepseek_vl2,
"florence2": run_florence2,
"fuyu": run_fuyu,
"glm4v": run_glm4v,
"h2ovl_chat": run_h2ovl,

View File

@@ -600,8 +600,8 @@ class HfRunner:
if images is not None and images[i] is not None:
processor_kwargs["images"] = images[i]
encoder_input_ids = self.wrap_device(
self.processor(**processor_kwargs).input_ids,
encoder_inputs = self.wrap_device(
self.processor(**processor_kwargs),
device=self.model.device.type,
)
@@ -615,13 +615,13 @@
)
output = self.model.generate(
encoder_input_ids,
decoder_input_ids=decoder_input_ids,
use_cache=True,
do_sample=False,
max_new_tokens=max_tokens,
output_hidden_states=True,
return_dict_in_generate=True,
**encoder_inputs,
**kwargs,
)
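For context, a standalone sketch (not part of this commit) of the Hugging Face call the updated runner now makes: the Florence-2 processor output carries both input_ids and pixel_values, which is why the full BatchFeature is forwarded as **encoder_inputs rather than only input_ids. Names follow the Florence-2 model card; treat the exact output keys as an assumption.
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor

model_id = "microsoft/Florence-2-base"
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)

# Any RGB image works for this sketch; the tests use vLLM's image assets.
image = Image.new("RGB", (768, 768))
inputs = processor(text="<CAPTION>", images=image, return_tensors="pt")

# The BatchFeature holds input_ids and pixel_values for the vision tower.
output_ids = model.generate(input_ids=inputs["input_ids"],
                            pixel_values=inputs["pixel_values"],
                            max_new_tokens=20,
                            do_sample=False)
print(processor.batch_decode(output_ids, skip_special_tokens=False)[0])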

View File

@@ -15,7 +15,7 @@ from ....conftest import HfRunner, VllmRunner
from ....utils import RemoteOpenAIServer
from ...utils import check_logprobs_close
MODEL_NAME = "fixie-ai/ultravox-v0_5-llama-3_2-1b"
MODEL_NAME = "fixie-ai/ultravox-v0_4"
AudioTuple = Tuple[np.ndarray, int]
@@ -187,7 +187,7 @@ def run_multi_audio_test(
@pytest.mark.core_model
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("vllm_kwargs", [

View File

@@ -1,52 +1,59 @@
# SPDX-License-Identifier: Apache-2.0
from functools import partial
from typing import List, Optional, Tuple, Type
from typing import Optional, Type
import pytest
from PIL import Image
from vllm.inputs.data import ExplicitEncoderDecoderPrompt
from vllm.inputs.data import ExplicitEncoderDecoderPrompt, TextPrompt
from vllm.multimodal.image import rescale_image_size
from vllm.sequence import SampleLogprobs
from ....conftest import HfRunner, VllmRunner
from ....conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
from ...utils import check_logprobs_close
Florence2Prompt = partial(ExplicitEncoderDecoderPrompt,
decoder_prompt=None,
mm_processor_kwargs=None)
MODELS = ["microsoft/Florence-2-base"]
# Florence-2 uses BartFastTokenizer which can't be loaded from AutoTokenizer
# Therefore, we borrow the BartTokenizer from the original Bart model
TOKENIZER = "facebook/bart-base"
PROMPTS = [
Florence2Prompt(encoder_prompt="<CAPTION>"),
Florence2Prompt(encoder_prompt="<DETAILED_CAPTION>"),
Florence2Prompt(encoder_prompt="<MORE_DETAILED_CAPTION>"),
Florence2Prompt(encoder_prompt="<CAPTION_TO_PHRASE_GROUNDING>"),
Florence2Prompt(encoder_prompt="<DENSE_REGION_CAPTION>"),
Florence2Prompt(encoder_prompt="<REGION_PROPOSAL>"),
Florence2Prompt(encoder_prompt="<OCR_WITH_REGION>"),
Florence2Prompt(encoder_prompt="<OCR>"),
Florence2Prompt(encoder_prompt="<OD>"),
]
HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
"stop_sign":
"<CAPTION>", # special task token
"cherry_blossom":
"Describe in detail what is shown in the image.",
})
def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
Optional[SampleLogprobs]], ):
"""Sanitize vllm output to be comparable with hf output."""
output_ids, output_str, out_logprobs = vllm_output
def get_hf_images_prompts(
prompts_: list[ExplicitEncoderDecoderPrompt[str, TextPrompt]],
) -> tuple[list[ExplicitEncoderDecoderPrompt[str, str]], list[Image.Image]]:
prompts, images = [], []
for prompt in prompts_:
encoder_prompt = prompt["encoder_prompt"]
prompts.append(
ExplicitEncoderDecoderPrompt(
encoder_prompt=encoder_prompt["prompt"],
decoder_prompt=None,
))
images.append(encoder_prompt["multi_modal_data"]["image"])
return prompts, images
hf_output_str = "</s><s>" + output_str + "</s>"
return output_ids, hf_output_str, out_logprobs
def hf_to_vllm_output(hf_output: tuple[list[int], str,
Optional[SampleLogprobs]]):
"""Sanitize hf output to be comparable with vllm output."""
output_ids, output_str, out_logprobs = hf_output
output_str = output_str.replace("</s>", "").replace("<s>", "")
output_ids = [ids for ids in output_ids if ids not in [0, 2]]
return output_ids, output_str, out_logprobs
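# Illustration (made-up token ids; 0 and 2 are BART's <s>/</s>):
#   hf_to_vllm_output(([2, 0, 287, 1992, 2], "</s><s>A stop sign</s>", None))
#   returns ([287, 1992], "A stop sign", None)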
def run_test(
hf_runner: Type[HfRunner],
vllm_runner: Type[VllmRunner],
prompts: List[ExplicitEncoderDecoderPrompt],
inputs: list[list[ExplicitEncoderDecoderPrompt]],
model: str,
*,
dtype: str,
@@ -56,46 +63,76 @@ def run_test(
distributed_executor_backend: Optional[str] = None,
) -> None:
with vllm_runner(model,
max_num_seqs=8,
tokenizer_name=TOKENIZER,
dtype=dtype,
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend,
enforce_eager=True) as vllm_model:
vllm_outputs = vllm_model.generate_encoder_decoder_greedy_logprobs(
prompts, max_tokens, num_logprobs)
vllm_outputs_per_case = [
vllm_model.generate_encoder_decoder_greedy_logprobs(
prompts, max_tokens, num_logprobs=num_logprobs)
for prompts in inputs
]
hf_inputs = [get_hf_images_prompts(prompts) for prompts in inputs]
# Florence-2 processors require image inputs
dummy_image = Image.new(mode="RGB", size=(2, 2))
with hf_runner(model, dtype=dtype, skip_tokenizer_init=True) as hf_model:
hf_model.model.get_output_embeddings = lambda: \
hf_model.model.language_model.lm_head
hf_outputs = (hf_model.generate_encoder_decoder_greedy_logprobs_limit(
prompts,
max_tokens,
num_logprobs,
images=[dummy_image] * len(prompts),
))
hf_outputs_per_case = [
hf_model.generate_encoder_decoder_greedy_logprobs_limit(
prompts, max_tokens, num_logprobs=num_logprobs, images=images)
for prompts, images in hf_inputs
]
check_logprobs_close(
outputs_0_lst=hf_outputs,
outputs_1_lst=[
vllm_to_hf_output(vllm_output) for vllm_output in vllm_outputs
],
name_0="hf",
name_1="vllm",
)
for hf_outputs, vllm_outputs in zip(hf_outputs_per_case,
vllm_outputs_per_case):
check_logprobs_close(
outputs_0_lst=[hf_to_vllm_output(output) for output in hf_outputs],
outputs_1_lst=vllm_outputs,
name_0="hf",
name_1="vllm",
)
@pytest.mark.core_model
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float", "bfloat16"])
@pytest.mark.parametrize(
"size_factors",
[
# No image
[],
# Single-scale
[1.0],
# Single-scale, batched
[1.0, 1.0, 1.0],
# Multi-scale
[0.25, 0.5, 1.0],
],
)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(hf_runner, vllm_runner, model, dtype, max_tokens,
num_logprobs) -> None:
def test_models(hf_runner: Type[HfRunner], vllm_runner: Type[VllmRunner],
image_assets: _ImageAssets, model: str,
size_factors: list[float], dtype: str, max_tokens: int,
num_logprobs: int) -> None:
images = [asset.pil_image for asset in image_assets]
inputs_per_image = [[
ExplicitEncoderDecoderPrompt(
encoder_prompt=TextPrompt(
prompt=prompt,
multi_modal_data={"image": rescale_image_size(image, factor)}),
decoder_prompt=None,
) for factor in size_factors
] for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
run_test(
hf_runner,
vllm_runner,
PROMPTS,
inputs_per_image,
model,
dtype=dtype,
max_tokens=max_tokens,

View File

@@ -29,8 +29,8 @@ def _test_processing_correctness(
model_config = ModelConfig(
model_id,
task="auto",
tokenizer=model_id,
tokenizer_mode="auto",
tokenizer=model_info.tokenizer or model_id,
tokenizer_mode=model_info.tokenizer_mode,
trust_remote_code=model_info.trust_remote_code,
seed=0,
dtype="float16",
@@ -151,6 +151,7 @@ def _test_processing_correctness(
"Salesforce/blip2-opt-2.7b",
"facebook/chameleon-7b",
"deepseek-ai/deepseek-vl2-tiny",
"microsoft/Florence-2-base",
"adept/fuyu-8b",
"THUDM/glm-4v-9b",
"h2oai/h2ovl-mississippi-800m",

View File

@@ -193,11 +193,6 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
# [Encoder-decoder]
"BartModel": _HfExamplesInfo("facebook/bart-base"),
"BartForConditionalGeneration": _HfExamplesInfo("facebook/bart-large-cnn"),
# Florence-2 uses BartFastTokenizer which can't be loaded from AutoTokenizer
# Therefore, we borrow the BartTokenizer from the original Bart model
"Florence2ForConditionalGeneration": _HfExamplesInfo("microsoft/Florence-2-base", # noqa: E501
tokenizer="facebook/bart-base",
trust_remote_code=True), # noqa: E501
}
_EMBEDDING_EXAMPLE_MODELS = {
@@ -288,6 +283,11 @@ _MULTIMODAL_EXAMPLE_MODELS = {
extras={"v0.5": "fixie-ai/ultravox-v0_5-llama-3_2-1b"}, # noqa: E501
trust_remote_code=True),
# [Encoder-decoder]
# Florence-2 uses BartFastTokenizer which can't be loaded from AutoTokenizer
# Therefore, we borrow the BartTokenizer from the original Bart model
"Florence2ForConditionalGeneration": _HfExamplesInfo("microsoft/Florence-2-base", # noqa: E501
tokenizer="facebook/bart-base",
trust_remote_code=True), # noqa: E501
"MllamaForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-3.2-11B-Vision-Instruct"), # noqa: E501
"WhisperForConditionalGeneration": _HfExamplesInfo("openai/whisper-large-v3"), # noqa: E501
}

View File

@@ -588,8 +588,12 @@ class BartEncoder(nn.Module):
self.layernorm_embedding = nn.LayerNorm(embed_dim)
def forward(self, input_ids: torch.Tensor,
positions: torch.Tensor) -> torch.Tensor:
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
r"""
Args:
input_ids
@@ -602,7 +606,8 @@
Decoder output torch.Tensor
"""
# retrieve input_ids and inputs_embeds
inputs_embeds = self.embed_tokens(input_ids)
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
embed_pos = self.embed_positions(positions)
embed_pos = embed_pos.to(inputs_embeds.device)
@@ -661,9 +666,13 @@ class BartDecoder(nn.Module):
self.layernorm_embedding = nn.LayerNorm(config.d_model)
def forward(self, decoder_input_ids: torch.Tensor,
decoder_positions: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor]) -> torch.Tensor:
def forward(
self,
decoder_input_ids: torch.Tensor,
decoder_positions: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor],
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
r"""
Args:
decoder_input_ids
@@ -677,8 +686,10 @@
Returns:
Decoder output torch.Tensor
"""
inputs_embeds = self.embed_tokens(decoder_input_ids)
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(decoder_input_ids)
else:
decoder_positions = inputs_embeds[:, -1]
# embed positions
embed_pos = self.embed_positions(decoder_positions)

View File

@@ -1,10 +1,15 @@
# SPDX-License-Identifier: Apache-2.0
import math
from typing import Iterable, Optional, Set, Tuple
from functools import cached_property
from typing import (Iterable, List, Literal, Mapping, Optional, OrderedDict,
Set, Tuple, TypedDict, Union)
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from transformers import BatchFeature, PretrainedConfig
from vllm.config import VllmConfig
from vllm.model_executor.layers.logits_processor import LogitsProcessor
@@ -14,11 +19,567 @@ from vllm.model_executor.models.bart import (BartDecoder, BartEncoder,
BartParallelLMHead,
BartScaledWordEmbedding)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, NestedTensors
from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
from vllm.multimodal.parse import MultiModalDataDict, MultiModalDataItems
from vllm.multimodal.processing import (BaseProcessingInfo,
EncDecMultiModalProcessor,
PromptReplacement,
PromptReplacementDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from .utils import AutoWeightsLoader
from .interfaces import SupportsMultiModal
from .utils import AutoWeightsLoader, flatten_bn, merge_multimodal_embeddings
class Florence2ImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
data: torch.Tensor
"""Shape: (batch_size, num_channel, height, width)"""
# The ViT implementation below is copied from
# https://huggingface.co/microsoft/Florence-2-base/blob/main/modeling_florence2.py
class LearnedAbsolutePositionEmbedding2D(nn.Module):
"""
This module learns positional embeddings up to a fixed maximum size.
"""
def __init__(self, embedding_dim=256, num_pos=50):
super().__init__()
self.row_embeddings = nn.Embedding(num_pos, embedding_dim // 2)
self.column_embeddings = nn.Embedding(
num_pos, embedding_dim - (embedding_dim // 2))
def forward(self, pixel_values):
"""
pixel_values: (batch_size, height, width, num_channels)
returns: (batch_size, height, width, embedding_dim * 2)
"""
if len(pixel_values.shape) != 4:
raise ValueError('pixel_values must be a 4D tensor')
height, width = pixel_values.shape[1:3]
width_values = torch.arange(width, device=pixel_values.device)
height_values = torch.arange(height, device=pixel_values.device)
x_emb = self.column_embeddings(width_values)
y_emb = self.row_embeddings(height_values)
# (height, width, embedding_dim * 2)
pos = torch.cat([
x_emb.unsqueeze(0).repeat(height, 1, 1),
y_emb.unsqueeze(1).repeat(1, width, 1)
],
dim=-1)
# (embedding_dim * 2, height, width)
pos = pos.permute(2, 0, 1)
pos = pos.unsqueeze(0)
# (batch_size, embedding_dim * 2, height, width)
pos = pos.repeat(pixel_values.shape[0], 1, 1, 1)
# (batch_size, height, width, embedding_dim * 2)
pos = pos.permute(0, 2, 3, 1)
return pos
class PositionalEmbeddingCosine1D(nn.Module):
"""
This class implements a very simple positional encoding. It follows closely
the encoder from the link below:
https://pytorch.org/tutorials/beginner/translation_transformer.html
Args:
embed_dim: The dimension of the embeddings.
dropout_prob: The dropout probability.
max_seq_len: The maximum length to precompute the positional encodings.
"""
def __init__(self, embed_dim: int = 512, max_seq_len: int = 1024) -> None:
super().__init__()
self.embed_dim = embed_dim
self.max_seq_len = max_seq_len
# Generate the sinusoidal arrays.
factor = math.log(10000)
denominator = torch.exp(-factor * torch.arange(0, self.embed_dim, 2) /
self.embed_dim)
# Matrix where rows correspond to a positional embedding as a function
# of the position index (i.e., the row index).
frequencies = \
torch.arange(0, self.max_seq_len) \
.reshape(self.max_seq_len, 1) * denominator
pos_idx_to_embed = torch.zeros((self.max_seq_len, self.embed_dim))
# Populate uneven entries.
pos_idx_to_embed[:, 0::2] = torch.sin(frequencies)
pos_idx_to_embed[:, 1::2] = torch.cos(frequencies)
# Save the positional embeddings in a constant buffer.
# self.register_buffer("pos_idx_to_embed", pos_idx_to_embed)
self.pos_idx_to_embed = nn.Parameter(pos_idx_to_embed,
requires_grad=False)
def forward(self, seq_embeds: torch.Tensor) -> torch.Tensor:
"""
Args:
seq_embeds: The sequence embeddings in order. Allowed size:
1. [T, D], where T is the length of the sequence, and D is the
frame embedding dimension.
2. [B, T, D], where B is the batch size and T and D are the
same as above.
Returns a tensor with the same dimensions as the input, i.e.,
[1, T, D] or [T, D].
"""
shape_len = len(seq_embeds.shape)
assert 2 <= shape_len <= 3
len_seq = seq_embeds.size(-2)
assert len_seq <= self.max_seq_len
pos_embeds = self.pos_idx_to_embed[0:seq_embeds.size(-2), :]
# Adapt pre-computed positional embeddings to the input.
if shape_len == 3:
pos_embeds = pos_embeds.view(
(1, pos_embeds.size(0), pos_embeds.size(1)))
return pos_embeds
class MySequential(nn.Sequential):
def forward(self, *inputs):
for module in self._modules.values():
if isinstance(inputs, tuple):
inputs = module(*inputs)
else:
inputs = module(inputs)
return inputs
class PreNorm(nn.Module):
def __init__(self, norm, fn):
super().__init__()
self.norm = norm
self.fn = fn
def forward(self, x, *args, **kwargs):
shortcut = x
if self.norm is not None:
x, size = self.fn(self.norm(x), *args, **kwargs)
else:
x, size = self.fn(x, *args, **kwargs)
x = shortcut + x
return x, size
class Mlp(nn.Module):
def __init__(
self,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.GELU,
):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.net = nn.Sequential(
OrderedDict([("fc1", nn.Linear(in_features, hidden_features)),
("act", act_layer()),
("fc2", nn.Linear(hidden_features, out_features))]))
def forward(self, x, size):
return self.net(x), size
class DepthWiseConv2d(nn.Module):
def __init__(
self,
dim_in,
kernel_size,
padding,
stride,
bias=True,
):
super().__init__()
self.dw = nn.Conv2d(dim_in,
dim_in,
kernel_size=kernel_size,
padding=padding,
groups=dim_in,
stride=stride,
bias=bias)
def forward(self, x, size):
B, N, C = x.shape
H, W = size
assert N == H * W
x = self.dw(x.transpose(1, 2).view(B, C, H, W))
size = (x.size(-2), x.size(-1))
x = x.flatten(2).transpose(1, 2)
return x, size
class ConvEmbed(nn.Module):
""" Image to Patch Embedding
"""
def __init__(self,
patch_size=7,
in_chans=3,
embed_dim=64,
stride=4,
padding=2,
norm_layer=None,
pre_norm=True):
super().__init__()
self.patch_size = patch_size
self.proj = nn.Conv2d(in_chans,
embed_dim,
kernel_size=patch_size,
stride=stride,
padding=padding)
dim_norm = in_chans if pre_norm else embed_dim
self.norm = norm_layer(dim_norm) if norm_layer else None
self.pre_norm = pre_norm
def forward(self, x, size):
H, W = size
if len(x.size()) == 3:
if self.norm and self.pre_norm:
x = self.norm(x)
x = rearrange(x, 'b (h w) c -> b c h w', h=H, w=W)
x = self.proj(x)
_, _, H, W = x.shape
x = rearrange(x, 'b c h w -> b (h w) c')
if self.norm and not self.pre_norm:
x = self.norm(x)
return x, (H, W)
class ChannelAttention(nn.Module):
def __init__(self, dim, groups=8, qkv_bias=True):
super().__init__()
self.groups = groups
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.proj = nn.Linear(dim, dim)
def forward(self, x, size):
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.groups,
C // self.groups).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
q = q * (float(N)**-0.5)
attention = q.transpose(-1, -2) @ k
attention = attention.softmax(dim=-1)
x = (attention @ v.transpose(-1, -2)).transpose(-1, -2)
x = x.transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
return x, size
class ChannelBlock(nn.Module):
def __init__(self,
dim,
groups,
mlp_ratio=4.,
qkv_bias=True,
drop_path_rate=0.,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
conv_at_attn=True,
conv_at_ffn=True):
super().__init__()
self.conv1 = PreNorm(None, DepthWiseConv2d(
dim, 3, 1, 1)) if conv_at_attn else None
self.channel_attn = PreNorm(
norm_layer(dim),
ChannelAttention(dim, groups=groups, qkv_bias=qkv_bias),
)
self.conv2 = PreNorm(None, DepthWiseConv2d(dim, 3, 1,
1)) if conv_at_ffn else None
self.ffn = PreNorm(
norm_layer(dim),
Mlp(in_features=dim,
hidden_features=int(dim * mlp_ratio),
act_layer=act_layer),
)
def forward(self, x, size):
if self.conv1:
x, size = self.conv1(x, size)
x, size = self.channel_attn(x, size)
if self.conv2:
x, size = self.conv2(x, size)
x, size = self.ffn(x, size)
return x, size
def window_partition(x, window_size: int):
B, H, W, C = x.shape
x = x.view(B, H // window_size, window_size, W // window_size, window_size,
C)
windows = x.permute(0, 1, 3, 2, 4,
5).contiguous().view(-1, window_size, window_size, C)
return windows
def window_reverse(windows, batch_size: int, window_size: int, H: int, W: int):
B = batch_size
x = windows.view(B, H // window_size, W // window_size, window_size,
window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
return x
class WindowAttention(nn.Module):
def __init__(self, dim, num_heads, window_size, qkv_bias=True):
super().__init__()
self.dim = dim
self.window_size = window_size
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = float(head_dim)**-0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.proj = nn.Linear(dim, dim)
self.softmax = nn.Softmax(dim=-1)
def forward(self, x, size):
H, W = size
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
x = x.view(B, H, W, C)
pad_l = pad_t = 0
pad_r = (self.window_size - W % self.window_size) % self.window_size
pad_b = (self.window_size - H % self.window_size) % self.window_size
x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
_, Hp, Wp, _ = x.shape
x = window_partition(x, self.window_size)
x = x.view(-1, self.window_size * self.window_size, C)
# W-MSA/SW-MSA
# attn_windows = self.attn(x_windows)
B_, N, C = x.shape
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads,
C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
q = q * self.scale
attn = (q @ k.transpose(-2, -1))
attn = self.softmax(attn)
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
x = self.proj(x)
# merge windows
x = x.view(-1, self.window_size, self.window_size, C)
x = window_reverse(x, B, self.window_size, Hp, Wp)
if pad_r > 0 or pad_b > 0:
x = x[:, :H, :W, :].contiguous()
x = x.view(B, H * W, C)
return x, size
class SpatialBlock(nn.Module):
def __init__(self,
dim,
num_heads,
window_size,
mlp_ratio=4.,
qkv_bias=True,
drop_path_rate=0.,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
conv_at_attn=True,
conv_at_ffn=True):
super().__init__()
self.conv1 = PreNorm(None, DepthWiseConv2d(
dim, 3, 1, 1)) if conv_at_attn else None
self.window_attn = PreNorm(
norm_layer(dim),
WindowAttention(dim, num_heads, window_size, qkv_bias=qkv_bias),
)
self.conv2 = PreNorm(None, DepthWiseConv2d(dim, 3, 1,
1)) if conv_at_ffn else None
self.ffn = PreNorm(
norm_layer(dim),
Mlp(in_features=dim,
hidden_features=int(dim * mlp_ratio),
act_layer=act_layer),
)
def forward(self, x, size):
if self.conv1:
x, size = self.conv1(x, size)
x, size = self.window_attn(x, size)
if self.conv2:
x, size = self.conv2(x, size)
x, size = self.ffn(x, size)
return x, size
class DaViT(nn.Module):
def __init__(
self,
in_chans=3,
num_classes=1000,
depths=(1, 1, 3, 1),
patch_size=(7, 2, 2, 2),
patch_stride=(4, 2, 2, 2),
patch_padding=(3, 0, 0, 0),
patch_prenorm=(False, False, False, False),
embed_dims=(64, 128, 192, 256),
num_heads=(3, 6, 12, 24),
num_groups=(3, 6, 12, 24),
window_size=7,
mlp_ratio=4.,
qkv_bias=True,
drop_path_rate=0.1,
norm_layer=nn.LayerNorm,
enable_checkpoint=False,
conv_at_attn=True,
conv_at_ffn=True,
):
super().__init__()
self.num_classes = num_classes
self.embed_dims = embed_dims
self.num_heads = num_heads
self.num_groups = num_groups
self.num_stages = len(self.embed_dims)
self.enable_checkpoint = enable_checkpoint
assert self.num_stages == len(self.num_heads) == len(self.num_groups)
num_stages = len(embed_dims)
dpr = [
x.item() for x in torch.linspace(0, drop_path_rate,
sum(depths) * 2)
]
depth_offset = 0
convs = []
blocks = []
for i in range(num_stages):
conv_embed = ConvEmbed(
patch_size=patch_size[i],
stride=patch_stride[i],
padding=patch_padding[i],
in_chans=in_chans if i == 0 else self.embed_dims[i - 1],
embed_dim=self.embed_dims[i],
norm_layer=norm_layer,
pre_norm=patch_prenorm[i])
convs.append(conv_embed)
block = MySequential(*[
MySequential(
OrderedDict([('spatial_block',
SpatialBlock(
embed_dims[i],
num_heads[i],
window_size,
drop_path_rate=dpr[depth_offset + j * 2],
qkv_bias=qkv_bias,
mlp_ratio=mlp_ratio,
conv_at_attn=conv_at_attn,
conv_at_ffn=conv_at_ffn,
)),
('channel_block',
ChannelBlock(
embed_dims[i],
num_groups[i],
drop_path_rate=dpr[depth_offset + j * 2 +
1],
qkv_bias=qkv_bias,
mlp_ratio=mlp_ratio,
conv_at_attn=conv_at_attn,
conv_at_ffn=conv_at_ffn,
))])) for j in range(depths[i])
])
blocks.append(block)
depth_offset += depths[i] * 2
self.convs = nn.ModuleList(convs)
self.blocks = nn.ModuleList(blocks)
self.avgpool = nn.AdaptiveAvgPool1d(1)
@property
def dim_out(self):
return self.embed_dims[-1]
def forward_features_unpool(self, x):
"""
forward until avg pooling
Args:
x (_type_): input image tensor
"""
input_size = (x.size(2), x.size(3))
for conv, block in zip(self.convs, self.blocks):
x, input_size = conv(x, input_size)
x, input_size = block(x, input_size)
return x
def forward_features(self, x):
x = self.forward_features_unpool(x)
# (batch_size, num_tokens, token_dim)
x = self.avgpool(x.transpose(1, 2))
# (batch_size, 1, num_tokens)
x = torch.flatten(x, 1)
x = self.norms(x)
return x
def forward(self, x):
x = self.forward_features(x)
x = self.head(x)
return x
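# Note: forward_features and forward above are carried over from the reference
# DaViT implementation but are unused in this port (self.norms and self.head
# are never defined here); only forward_features_unpool is called.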
@classmethod
def from_config(cls, config):
return cls(
depths=config.depths,
embed_dims=config.dim_embed,
num_heads=config.num_heads,
num_groups=config.num_groups,
patch_size=config.patch_size,
patch_stride=config.patch_stride,
patch_padding=config.patch_padding,
patch_prenorm=config.patch_prenorm,
drop_path_rate=config.drop_path_rate,
window_size=config.window_size,
)
# Language backbone and processor implementation
class Florence2LanguageModel(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
@@ -47,9 +608,14 @@ class Florence2LanguageModel(nn.Module):
self.encoder.embed_tokens.weight = self.shared.weight
self.decoder.embed_tokens.weight = self.shared.weight
def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
encoder_input_ids: torch.Tensor,
encoder_positions: torch.Tensor) -> torch.Tensor:
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
encoder_input_ids: torch.Tensor,
encoder_positions: torch.Tensor,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
r"""
Args:
input_ids
@@ -68,11 +634,12 @@ class Florence2LanguageModel(nn.Module):
encoder_hidden_states = None
if encoder_input_ids.numel() > 0:
if inputs_embeds is not None or encoder_input_ids.numel() > 0:
# Run encoder attention if a non-zero number of encoder tokens
# are provided as input
encoder_hidden_states = self.encoder(input_ids=encoder_input_ids,
positions=encoder_positions)
positions=encoder_positions,
inputs_embeds=inputs_embeds)
# decoder outputs consists of
# (dec_features, past_key_value, dec_hidden, dec_attn)
@@ -112,6 +679,7 @@ class Florence2LanguageForConditionalGeneration(nn.Module):
positions: torch.Tensor,
encoder_input_ids: torch.Tensor,
encoder_positions: torch.Tensor,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
r"""
@@ -127,8 +695,15 @@
Returns:
Output torch.Tensor
"""
return self.model(input_ids, positions, encoder_input_ids,
encoder_positions)
return self.model(input_ids,
positions,
encoder_input_ids,
encoder_positions,
inputs_embeds=inputs_embeds)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.encoder.embed_tokens(input_ids)
def compute_logits(
self,
@@ -177,21 +752,312 @@ class Florence2LanguageForConditionalGeneration(nn.Module):
return loaded_params
class Florence2ForConditionalGeneration(nn.Module):
class Florence2ProcessingInfo(BaseProcessingInfo):
def get_hf_config(self):
return self.ctx.get_hf_config()
def get_hf_processor(self):
return self.ctx.get_hf_processor()
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": 1}
def get_max_image_tokens(self) -> int:
processor_config = self.ctx.get_hf_image_processor_config()
return processor_config["image_seq_length"]
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
return {"image": self.get_max_image_tokens()}
class Florence2DummyInputsBuilder(
BaseDummyInputsBuilder[Florence2ProcessingInfo]):
def get_dummy_processor_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
num_images = mm_counts.get("image", 0)
target_width = target_height = self.info.get_hf_config().projection_dim
mm_data = {
"image":
self._get_dummy_images(width=target_width,
height=target_height,
num_images=num_images)
}
return ProcessorInputs(
prompt_text="",
mm_data=mm_data,
)
class Florence2MultiModalProcessor(
EncDecMultiModalProcessor[Florence2ProcessingInfo]):
def _hf_processor_applies_repl(
self,
prompt_text: str,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
) -> bool:
return False
def create_encoder_prompt(
self,
prompt: Union[str, list[int]],
mm_data: MultiModalDataDict,
) -> Union[str, list[int]]:
return prompt
def create_decoder_prompt(
self,
prompt: Union[str, list[int]],
mm_data: MultiModalDataDict,
) -> Union[str, list[int]]:
return [self.info.get_hf_config().eos_token_id]
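# The decoder prompt is reduced to a single EOS id: BART-style models (and
# hence Florence-2's language model) use EOS as the decoder start token.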
def _call_hf_processor(
self,
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
) -> BatchFeature:
if mm_data:
processed_outputs = super()._call_hf_processor(
prompt, mm_data, mm_kwargs)
else:
hf_processor = self.info.get_hf_processor()
tokenizer = hf_processor.tokenizer
prompt = hf_processor._construct_prompts([prompt])[0]
processed_outputs = tokenizer(prompt,
add_special_tokens=True,
return_tensors="pt")
return processed_outputs
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return dict(pixel_values=MultiModalFieldConfig.batched("image"))
def _get_prompt_replacements(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
hf_config = self.info.get_hf_config()
pad_token_id = hf_config.pad_token_id
bos_token_id = hf_config.bos_token_id
num_image_tokens = self.info.get_max_image_tokens()
image_tokens = [pad_token_id] * num_image_tokens
return [
PromptReplacement(
modality="image",
target=[bos_token_id],
replacement=PromptReplacementDetails(
full=image_tokens + [bos_token_id],
features=image_tokens,
),
)
]
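# Worked illustration (token ids assumed from BART's vocab: bos=0, pad=1,
# eos=2): an encoder prompt such as "<CAPTION>", tokenized as [0, ..., 2],
# becomes [1] * num_image_tokens + [0, ..., 2]. The leading pad tokens mark
# where get_input_embeddings() later merges the projected DaViT features
# (it matches on pad_token_id), while the text tokens keep their ordinary
# embeddings.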
@MULTIMODAL_REGISTRY.register_processor(
Florence2MultiModalProcessor,
info=Florence2ProcessingInfo,
dummy_inputs=Florence2DummyInputsBuilder)
class Florence2ForConditionalGeneration(nn.Module, SupportsMultiModal):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
processor_config = vllm_config.model_config.hf_image_processor_config
# TODO(Isotr0py): Add vision backbone
self.config = config
self.vision_config = config.vision_config
self.processor_config = processor_config
assert config.vision_config.model_type == 'davit', (
'only DaViT is supported for now')
self.vision_tower = DaViT.from_config(config=config.vision_config)
self._build_image_projection_layers(config)
self.language_model = Florence2LanguageForConditionalGeneration(
vllm_config=vllm_config.with_hf_config(config.text_config),
prefix=f"{prefix}.language_model",
)
self.pad_token_id = config.pad_token_id
@property
def _build_image_projection_layers(self, config: PretrainedConfig):
image_dim_out = config.vision_config.dim_embed[-1]
dim_projection = config.vision_config.projection_dim
self.image_projection = nn.Parameter(
torch.empty(image_dim_out, dim_projection))
self.image_proj_norm = nn.LayerNorm(dim_projection)
image_pos_embed_config = config.vision_config.image_pos_embed
if image_pos_embed_config['type'] == 'learned_abs_2d':
self.image_pos_embed = LearnedAbsolutePositionEmbedding2D(
embedding_dim=image_dim_out,
num_pos=image_pos_embed_config['max_pos_embeddings'])
else:
raise NotImplementedError("Florence2 only supports learned_abs_2d "
"as image position embedding.")
self.image_feature_source = config.vision_config.image_feature_source
# temporal embedding
visual_temporal_embedding_config = (
self.vision_config.visual_temporal_embedding)
if visual_temporal_embedding_config['type'] == 'COSINE':
self.visual_temporal_embed = PositionalEmbeddingCosine1D(
embed_dim=image_dim_out,
max_seq_len=visual_temporal_embedding_config[
'max_temporal_embeddings'])
else:
raise NotImplementedError(
'Florence2 only supports COSINE as temporal embedding.')
@cached_property
def sampler(self):
return self.language_model.sampler
if hasattr(self.language_model, "sampler"):
return self.language_model.sampler
return get_sampler()
def _validate_pixel_values(
self, data: Union[torch.Tensor, List[torch.Tensor]]
) -> Union[torch.Tensor, List[torch.Tensor]]:
size = self.processor_config["size"]
h, w = size["height"], size["width"]
expected_dims = (3, h, w)
def _validate_shape(d: torch.Tensor):
actual_dims = tuple(d.shape)
if actual_dims != expected_dims:
expected_expr = str(expected_dims)
raise ValueError(
"The expected shape of pixel values per batch "
f"is {expected_expr}. You supplied {tuple(d.shape)}.")
for d in data:
_validate_shape(d)
return data
def _parse_and_validate_image_input(self, **kwargs: object):
pixel_values: Optional[Union[List[List[torch.Tensor]],
List[torch.Tensor],
torch.Tensor]] = kwargs.pop(
"pixel_values", None)
image_embeds: Optional[Union[List[List[torch.Tensor]],
List[torch.Tensor],
torch.Tensor]] = kwargs.pop(
"image_embeds", None)
if pixel_values is None and image_embeds is None:
return None
if pixel_values is not None and image_embeds is not None:
raise ValueError(
"Both pixel values and image embeds are provided.")
if pixel_values is not None:
return Florence2ImagePixelInputs(
type="pixel_values",
data=self._validate_pixel_values(
flatten_bn(pixel_values, concat=True)),
)
if image_embeds is not None:
raise NotImplementedError
raise AssertionError("This line should be unreachable.")
def _encode_image(self, pixel_values: torch.Tensor) -> torch.Tensor:
dtype = next(self.vision_tower.parameters()).dtype
pixel_values = pixel_values.to(dtype)
batch_size, T = pixel_values.size(0), 1
x = self.vision_tower.forward_features_unpool(pixel_values)
if self.image_pos_embed is not None:
x = x.view(batch_size * T, -1, x.shape[-1])
num_tokens = x.shape[-2]
h, w = int(num_tokens**0.5), int(num_tokens**0.5)
assert h * w == num_tokens, (
'only support square feature maps for now')
x = x.view(batch_size * T, h, w, x.shape[-1])
pos_embed = self.image_pos_embed(x)
x = x + pos_embed
x = x.view(batch_size, T * h * w, x.shape[-1])
if self.visual_temporal_embed is not None:
visual_temporal_embed = self.visual_temporal_embed(
x.view(batch_size, T, -1, x.shape[-1])[:, :, 0])
x = x.view(batch_size, T, -1,
x.shape[-1]) + visual_temporal_embed.view(
1, T, 1, x.shape[-1])
x_feat_dict = {}
spatial_avg_pool_x = x.view(batch_size, T, -1, x.shape[-1]).mean(dim=2)
x_feat_dict['spatial_avg_pool'] = spatial_avg_pool_x
temporal_avg_pool_x = x.view(batch_size, T, -1,
x.shape[-1]).mean(dim=1)
x_feat_dict['temporal_avg_pool'] = temporal_avg_pool_x
x = x.view(batch_size, T, -1, x.shape[-1])[:, -1]
x_feat_dict['last_frame'] = x
new_x = []
for _image_feature_source in self.image_feature_source:
if _image_feature_source not in x_feat_dict:
raise ValueError('invalid image feature source: {}'.format(
_image_feature_source))
new_x.append(x_feat_dict[_image_feature_source])
x = torch.cat(new_x, dim=1)
x = x @ self.image_projection
x = self.image_proj_norm(x)
return x
def _process_image_input(
self, image_input: Florence2ImagePixelInputs) -> torch.Tensor:
assert image_input["type"] == "pixel_values"
pixel_values = image_input["data"]
return self._encode_image(pixel_values)
def get_multimodal_embeddings(self, **kwargs: object) -> torch.Tensor:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return None
vision_embeddings = self._process_image_input(image_input)
return vision_embeddings
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[NestedTensors] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
self.pad_token_id)
return inputs_embeds
def forward(
self,
@@ -216,8 +1082,19 @@ class Florence2ForConditionalGeneration(nn.Module):
Returns:
Output torch.Tensor
"""
return self.language_model(input_ids, positions, encoder_input_ids,
encoder_positions)
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
if encoder_input_ids.numel() > 0 or vision_embeddings is not None:
inputs_embeds = self.get_input_embeddings(encoder_input_ids,
vision_embeddings)
else:
inputs_embeds = None
hidden_states = self.language_model(input_ids,
positions,
encoder_input_ids,
encoder_positions,
inputs_embeds=inputs_embeds)
return hidden_states
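# Control-flow note: on the encoder (prefill) pass, encoder_input_ids holds
# the placeholder-expanded prompt and pixel_values arrive via kwargs, so
# inputs_embeds is built with the image features merged in; on later decode
# steps encoder_input_ids is empty and no image kwargs are passed, so
# inputs_embeds stays None and only the decoder runs.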
def compute_logits(
self,
@@ -236,9 +1113,5 @@ class Florence2ForConditionalGeneration(nn.Module):
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
skip_prefixes = [
'image_projection', "vision_tower", "image_proj_norm",
"image_pos_embed", "visual_temporal_embed"
]
loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)

View File

@@ -105,7 +105,6 @@ _TEXT_GENERATION_MODELS = {
# [Encoder-decoder]
"BartModel": ("bart", "BartForConditionalGeneration"),
"BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"),
"Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"), # noqa: E501
}
_EMBEDDING_MODELS = {
@@ -182,6 +181,7 @@ _MULTIMODAL_MODELS = {
"Qwen2AudioForConditionalGeneration": ("qwen2_audio", "Qwen2AudioForConditionalGeneration"), # noqa: E501
"UltravoxModel": ("ultravox", "UltravoxModel"),
# [Encoder-decoder]
"Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"), # noqa: E501
"MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"), # noqa: E501
"WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"), # noqa: E501
}

View File

@@ -1303,6 +1303,14 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
"""
raise NotImplementedError
def create_decoder_prompt(
self,
prompt: Union[str, list[int]],
mm_data: MultiModalDataDict,
) -> Union[str, list[int]]:
"""Create input prompt for the decoder."""
return prompt
def apply(
self,
prompt: Union[str, list[int]],
@@ -1323,17 +1331,15 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
hf_processor_mm_kwargs,
)
# We assumed the decoder prompt text is copied from
# the original encoder prompt without extra process
tokenizer = self.info.get_tokenizer()
if isinstance(prompt, str):
decoder_prompt = prompt
decoder_prompt = self.create_decoder_prompt(prompt, mm_data)
if isinstance(decoder_prompt, str):
decoder_prompt_ids = encode_tokens(tokenizer,
prompt,
decoder_prompt,
add_special_tokens=False)
else:
decoder_prompt = decode_tokens(tokenizer, prompt)
decoder_prompt_ids = prompt
decoder_prompt_ids = decoder_prompt
decoder_prompt = decode_tokens(tokenizer, decoder_prompt)
mm_inputs = MultiModalEncDecInputs(
encoder_prompt=encoder_inputs["prompt"],

View File

@@ -204,9 +204,11 @@ class MultiModalProfiler(Generic[_I]):
"and/or reduce `mm_counts`.", seq_len, total_len,
total_placeholders_by_modality)
num_tokens_to_pad = max(total_len, seq_len) - total_len
prompt_token_ids.extend([0] * num_tokens_to_pad)
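# The dummy sequence keeps the processor's real prompt token ids and zero-pads
# them up to seq_len, rather than synthesizing an all-zero sequence of that
# length.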
return DummyData(
seq_data=SequenceData.from_prompt_token_counts(
(0, max(seq_len, total_len))),
seq_data=SequenceData.from_seqs(prompt_token_ids),
multi_modal_data=None,
multi_modal_placeholders=None,
)