[VLM] Implement merged multimodal processor for Mllama (#11427)

Isotr0py 2025-02-13 12:26:21 +08:00 committed by GitHub
parent d88c8666a1
commit bc55d13070
8 changed files with 444 additions and 221 deletions


@ -7,11 +7,11 @@ import torch
from transformers import (AutoConfig, AutoModelForVision2Seq, AutoTokenizer,
BatchEncoding)
from vllm import LLM, SamplingParams
from vllm.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.attention.selector import (_Backend, _cached_get_attn_backend,
global_force_attn_backend_context_manager)
from vllm.model_executor.models.mllama import (MLLAMA_IMAGE_TOKEN_ID,
MllamaForConditionalGeneration)
from vllm.model_executor.models.mllama import MllamaForConditionalGeneration
from vllm.multimodal.image import rescale_image_size
from vllm.sequence import SampleLogprobs
@ -21,6 +21,7 @@ from ....utils import large_gpu_test
from ...utils import check_logprobs_close
_LIMIT_IMAGE_PER_PROMPT = 3
MLLAMA_IMAGE_TOKEN_ID = 128256
LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS, _Backend.FLASH_ATTN]
@ -396,6 +397,64 @@ def test_models_interleaved_images(hf_runner, vllm_runner, image_assets, model,
)
@large_gpu_test(min_gb=48)
@pytest.mark.core_model
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [32])
def test_explicit_implicit_prompt(
image_assets: _ImageAssets,
model: str,
dtype: str,
max_tokens: int,
):
stop_sign = image_assets[0].pil_image
# yapf: disable
prompts = [
# explicit prompt
{
"encoder_prompt": {
"prompt": "<|image|>",
"multi_modal_data": {"image": stop_sign},
},
"decoder_prompt": {
"prompt_token_ids": [128000, 791, 2262, 315, 279, 2217, 220, 128256, 374], # noqa: E501
}
},
{
"encoder_prompt": "Not <|image|>",
"decoder_prompt": "The color of the sky is blue but sometimes it can also be", # noqa: E501
},
# implicit prompt
{
"prompt": "<|begin_of_text|>The content of the image <|image|> is", # noqa: E501
"multi_modal_data": {"image": stop_sign},
},
{
"prompt": "The color of the sky is blue but sometimes it can also be", # noqa: E501
},
]
# yapf: enable
llm = LLM(
model=model,
dtype=dtype,
max_model_len=4096,
max_num_seqs=2,
tensor_parallel_size=1,
enforce_eager=True,
)
sampling_params = SamplingParams(
temperature=0,
max_tokens=max_tokens,
)
outputs = llm.generate(prompts, sampling_params)
n_prompts = len(prompts)
explicit_outputs = outputs[:n_prompts // 2]
implicit_outputs = outputs[n_prompts // 2:]
for exp_output, imp_output in zip(explicit_outputs, implicit_outputs):
assert exp_output.outputs[0].text == imp_output.outputs[0].text
@large_gpu_test(min_gb=48)
@pytest.mark.core_model
@pytest.mark.parametrize("model", models)
@ -458,6 +517,10 @@ def test_regression(vllm_runner, image_assets, model, dtype, max_tokens,
images=images)
class DummyModel:
image_token_id = MLLAMA_IMAGE_TOKEN_ID
@pytest.mark.core_model
@pytest.mark.parametrize(
"input_indices_and_output",
@ -499,7 +562,7 @@ def test_get_cross_attention_mask(input_indices_and_output) -> None:
use_cuda_graph=False,
)
dummy: dict[str, str] = {}
dummy = DummyModel()
cross_attention_mask, kv_range_for_decode = MllamaForConditionalGeneration\
.get_cross_attention_mask(dummy,
@ -556,7 +619,7 @@ def test_get_full_text_row_masked_out_mask(input_indices) -> None:
use_cuda_graph=False,
)
dummy: dict[str, str] = {}
dummy = DummyModel()
full_text_row_masked_out_mask = MllamaForConditionalGeneration\
.get_full_text_row_masked_out_mask(dummy,


@ -85,6 +85,14 @@ def _test_processing_correctness(
partial(random_audio, rng, min_len=512, max_len=1024, sr=16000),
}
tokenizer_encode_kwargs = {}
if model_config.hf_config.model_type == "mllama":
# For Mllama, the tokenizer always adds a bos_token at the beginning of
# the prompt by default, causing the hf_processor to output incorrect
# token ids. So we need to use `add_special_tokens=False` here and leave
# the bos_token to be added by the processor.
tokenizer_encode_kwargs = {"add_special_tokens": False}
for batch_idx in range(num_batches):
mm_data = {
k:
@ -122,7 +130,7 @@ def _test_processing_correctness(
f"Failed ({batch_idx=}, {prompt=}, {mm_data=})")
baseline_tokenized_result = baseline_processor.apply(
tokenizer.encode(prompt),
tokenizer.encode(prompt, **tokenizer_encode_kwargs),
mm_data=mm_data,
hf_processor_mm_kwargs={},
)
@ -131,7 +139,7 @@ def _test_processing_correctness(
f"Failed ({batch_idx=}, {prompt=}, {mm_data=})")
cached_tokenized_result = cached_processor.apply(
tokenizer.encode(prompt),
tokenizer.encode(prompt, **tokenizer_encode_kwargs),
mm_data=mm_data,
hf_processor_mm_kwargs={},
)
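# A runnable sketch of the behavior described in the comment above; the model
# id is an assumption (it matches the Mllama entry added to the model list
# below) and the helper itself is illustrative, not part of this test suite.
def _demo_mllama_bos_handling(
        model_id: str = "meta-llama/Llama-3.2-11B-Vision-Instruct"):
    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained(model_id)
    with_special = tok.encode("<|image|>What is this?")
    without_special = tok.encode("<|image|>What is this?",
                                 add_special_tokens=False)
    # The tokenizer prepends a bos_token by default; encoding without special
    # tokens leaves the bos_token to be added (once) by the HF processor.
    assert with_special[0] == tok.bos_token_id
    assert without_special == with_special[1:]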
@ -155,6 +163,7 @@ def _test_processing_correctness(
"llava-hf/llava-v1.6-mistral-7b-hf",
"llava-hf/LLaVA-NeXT-Video-7B-hf",
"llava-hf/llava-onevision-qwen2-0.5b-ov-hf",
"meta-llama/Llama-3.2-11B-Vision-Instruct",
"TIGER-Lab/Mantis-8B-siglip-llama3",
"mistral-community/pixtral-12b",
"openbmb/MiniCPM-o-2_6",


@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
import asyncio
from typing import List, Mapping, Optional, Union
from typing import List, Mapping, Optional, Tuple, Union, cast
from typing_extensions import assert_never
@ -9,7 +9,8 @@ from vllm.config import ModelConfig
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal.inputs import MultiModalDataDict, MultiModalInputs
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs,
MultiModalInputs)
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
@ -495,6 +496,51 @@ class InputPreprocessor:
decoder=decoder_inputs,
)
def _separate_enc_dec_inputs_from_mm_processor_outputs(
self,
inputs: SingletonInputs,
decoder_inputs_to_override: Optional[SingletonInputs] = None,
) -> Tuple[SingletonInputs, SingletonInputs]:
"""
For encoder/decoder models only:
Separate encoder and decoder inputs from the merged
:class:`MultiModalEncDecInputs` produced by the multimodal processor.
"""
encoder_inputs: SingletonInputs
decoder_inputs: SingletonInputs
if inputs["type"] == "multimodal":
# Multimodal data inputs
assert ("encoder_prompt" in inputs
and "encoder_prompt_token_ids" in inputs)
inputs = cast(MultiModalEncDecInputs, inputs)
encoder_inputs = token_inputs(
prompt=inputs["encoder_prompt"],
prompt_token_ids=inputs["encoder_prompt_token_ids"],
)
if decoder_inputs_to_override is not None:
decoder_inputs = MultiModalInputs(
type="multimodal",
prompt=decoder_inputs_to_override.get("prompt", ""),
prompt_token_ids=decoder_inputs_to_override[
"prompt_token_ids"],
mm_kwargs=inputs["mm_kwargs"],
mm_placeholders=inputs["mm_placeholders"],
)
else:
decoder_inputs = MultiModalInputs(
type="multimodal",
prompt=inputs["prompt"],
prompt_token_ids=inputs["prompt_token_ids"],
mm_kwargs=inputs["mm_kwargs"],
mm_placeholders=inputs["mm_placeholders"],
)
elif inputs["type"] == "token":
# Text-only inputs
encoder_inputs = token_inputs(prompt="", prompt_token_ids=[])
decoder_inputs = decoder_inputs_to_override or inputs
else:
assert_never(inputs) # type: ignore[arg-type]
return encoder_inputs, decoder_inputs
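# A minimal standalone sketch of the separation above, using plain dicts in
# place of vLLM's SingletonInputs types (the helper name and dict shapes are
# illustrative only, not part of this change):
def _split_enc_dec_sketch(inputs: dict,
                          decoder_override: Optional[dict] = None):
    # The encoder keeps only its prompt; multimodal kwargs and placeholders
    # travel with the decoder inputs.
    encoder = {
        "type": "token",
        "prompt": inputs["encoder_prompt"],
        "prompt_token_ids": inputs["encoder_prompt_token_ids"],
    }
    base = decoder_override if decoder_override is not None else inputs
    decoder = {
        "type": "multimodal",
        "prompt": base.get("prompt", ""),
        "prompt_token_ids": base["prompt_token_ids"],
        "mm_kwargs": inputs["mm_kwargs"],
        "mm_placeholders": inputs["mm_placeholders"],
    }
    return encoder, decoder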
def _process_encoder_decoder_prompt(
self,
prompt: PromptType,
@ -539,7 +585,6 @@ class InputPreprocessor:
prompt["encoder_prompt"],
request_id=request_id,
)
if (decoder_input := prompt["decoder_prompt"]) is None:
decoder_inputs = None
else:
@ -547,11 +592,26 @@ class InputPreprocessor:
decoder_input,
request_id=request_id,
)
# For multimodal models, override the decoder prompt from the processor
# with the explicit decoder prompt.
if self.model_config.is_multimodal_model and (
self._can_process_multimodal()):
encoder_inputs, decoder_inputs = (
self._separate_enc_dec_inputs_from_mm_processor_outputs(
encoder_inputs, decoder_inputs))
else:
encoder_inputs = self._prompt_to_llm_inputs(
inputs = self._prompt_to_llm_inputs(
prompt,
request_id=request_id,
)
if self.model_config.is_multimodal_model and (
self._can_process_multimodal()):
# Encoder-Decoder Multimodal model
encoder_inputs, decoder_inputs = (
self._separate_enc_dec_inputs_from_mm_processor_outputs(
inputs))
else:
encoder_inputs = inputs
decoder_inputs = None
@ -583,11 +643,27 @@ class InputPreprocessor:
encoder_inputs, decoder_inputs = await asyncio.gather(
encoder_task, decoder_task)
# For multimodal models, override the decoder prompt from the processor
# with the explicit decoder prompt.
if self.model_config.is_multimodal_model and (
self._can_process_multimodal()):
encoder_inputs, decoder_inputs = (
self._separate_enc_dec_inputs_from_mm_processor_outputs(
encoder_inputs, decoder_inputs))
else:
encoder_inputs = await self._prompt_to_llm_inputs_async(
inputs = await self._prompt_to_llm_inputs_async(
prompt,
request_id=request_id,
)
if self.model_config.is_multimodal_model and (
self._can_process_multimodal()):
# Encoder-Decoder Multimodal model
encoder_inputs, decoder_inputs = (
self._separate_enc_dec_inputs_from_mm_processor_outputs(
inputs))
else:
encoder_inputs = inputs
decoder_inputs = None


@ -350,7 +350,8 @@ class InputRegistry:
)
processor = mm_registry.create_processor(model_config, tokenizer)
profiler = MultiModalProfiler(processor)
dummy_data = profiler.get_dummy_data(seq_len)
dummy_data = profiler.get_dummy_data(
seq_len, is_encoder_data=is_encoder_data)
else:
model_cls, _ = get_model_architecture(model_config)
if is_encoder_data:


@ -23,14 +23,15 @@ import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers.models.mllama.configuration_mllama as config_mllama
from PIL import Image
from PIL.Image import Image
from torch import nn
from transformers import BatchFeature, MllamaConfig
from transformers.modeling_outputs import (BaseModelOutput,
CausalLMOutputWithPast)
from transformers.models.mllama.image_processing_mllama import (
get_optimal_tiled_canvas)
from transformers.models.mllama.processing_mllama import (
get_cross_attention_token_mask)
MllamaProcessor, get_cross_attention_token_mask)
import vllm.distributed.parallel_state as ps
from vllm.attention import Attention, AttentionMetadata, AttentionType
@ -38,8 +39,6 @@ from vllm.attention.ops.paged_attn import PagedAttention
from vllm.attention.selector import _Backend
from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.inputs import (INPUT_REGISTRY, DummyData, EncoderDecoderInputs,
InputContext, TokenInputs, token_inputs)
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@ -54,8 +53,13 @@ from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import SequenceData
from vllm.utils import is_list_of
from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
MultiModalDataDict, MultiModalDataItems)
from vllm.multimodal.processing import (BaseProcessingInfo,
EncDecMultiModalProcessor,
PromptReplacement)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from .clip import CLIPMLP
from .interfaces import SupportsMultiModal
@ -63,8 +67,6 @@ from .llama import LlamaDecoderLayer, LlamaMLP
from .utils import maybe_prefix
logger = init_logger(__name__)
MLLAMA_IMAGE_TOKEN_ID = 128256
MLLAMA_IMAGE_TOKEN = "<|image|>"
class MllamaImagePixelInputs(TypedDict):
@ -81,25 +83,113 @@ class MllamaImagePixelInputs(TypedDict):
# TODO: support LlamaImageEmbeddingInputs
def _get_num_image_in_last_group(prompt_token_ids: List[int]) -> int:
num_images = 0
for token_id in prompt_token_ids[::-1]:
if token_id == MLLAMA_IMAGE_TOKEN_ID:
num_images += 1
elif num_images > 0:
break
return num_images
def calc_token_per_chunk(image_size: int) -> int:
assert image_size % 14 == 0, "chunk size should be multiple of 14"
token_per_chunk = (image_size // 14)**2 + 1
return token_per_chunk
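# Worked example, assuming the Llama-3.2-Vision defaults of image_size=560
# (patch size 14) and max_num_tiles=4 taken from the published HF config:
#   calc_token_per_chunk(560) == (560 // 14) ** 2 + 1 == 1601
# so get_mm_max_tokens_per_item() below reports 4 * 1601 == 6404 per image.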
def input_processor_for_mllama(
ctx: InputContext,
inputs: EncoderDecoderInputs,
) -> EncoderDecoderInputs:
# Example input to processor:
class MllamaProcessingInfo(BaseProcessingInfo):
def get_hf_config(self) -> MllamaConfig:
return self.ctx.get_hf_config(MllamaConfig)
def get_hf_processor(self) -> MllamaProcessor:
return self.ctx.get_hf_processor(MllamaProcessor)
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}
def get_token_per_chunk_from_config(self) -> int:
image_size = self.get_hf_config().vision_config.image_size
return calc_token_per_chunk(image_size)
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
vision_config = self.get_hf_config().vision_config
token_per_chunk = self.get_token_per_chunk_from_config()
mm_max_tokens = vision_config.max_num_tiles * token_per_chunk
return {"image": mm_max_tokens}
def get_num_tiles_per_image(self, image_height: int,
image_width: int) -> int:
vision_config = self.get_hf_config().vision_config
max_num_tiles = vision_config.max_num_tiles
image_size = vision_config.image_size
tiled_height, tiled_width = get_optimal_tiled_canvas(
image_height,
image_width,
max_num_tiles,
tile_size=image_size,
)
num_tiles_height = tiled_height // image_size
num_tiles_width = tiled_width // image_size
return num_tiles_height * num_tiles_width
def get_image_size_with_most_features(self) -> ImageSize:
vision_config = self.get_hf_config().vision_config
image_size = vision_config.image_size
max_num_tiles = vision_config.max_num_tiles
# Results in the max possible feature size (h:w = max_num_tiles:1)
return ImageSize(height=max_num_tiles * image_size, width=image_size)
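# Sketch of why this shape is the worst case, assuming the same defaults
# (image_size=560, max_num_tiles=4): a 2240x560 image fits a 4x1 tiled canvas
# exactly, so get_num_tiles_per_image(2240, 560) returns the maximum of 4
# tiles and hence the maximum encoder token count.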
class MllamaDummyInputsBuilder(BaseDummyInputsBuilder[MllamaProcessingInfo]):
def get_dummy_processor_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
num_images = mm_counts.get("image", 0)
target_width, target_height = \
self.info.get_image_size_with_most_features()
mm_data = {
"image":
self._get_dummy_images(width=target_width,
height=target_height,
num_images=num_images)
}
hf_processor = self.info.get_hf_processor()
image_token: str = hf_processor.image_token
return ProcessorInputs(
prompt_text=image_token * num_images,
mm_data=mm_data,
)
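# For example, with mm_counts={"image": 2} and the assumed defaults above,
# the dummy request is prompt_text="<|image|><|image|>" plus two 2240x560
# images, each of which tiles into the maximum number of chunks while
# profiling memory usage.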
class MllamaMultiModalProcessor(EncDecMultiModalProcessor[MllamaProcessingInfo]
):
def _call_hf_processor(
self,
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
) -> BatchFeature:
tokenizer = self.info.get_tokenizer()
if mm_data:
num_tiles = [
self.info.get_num_tiles_per_image(img.height, img.width)
for img in mm_data["images"]
]
processed_outputs = super()._call_hf_processor(
prompt, mm_data, mm_kwargs)
processed_outputs["num_tiles"] = torch.tensor(num_tiles)
for k in ('pixel_values', 'aspect_ratio_ids', "aspect_ratio_mask"):
processed_outputs[k] = processed_outputs[k].squeeze(0)
# Example input to encoder and decoder:
# {
# 'encoder': {
# 'type': 'token',
# 'prompt_token_ids': [128000, 128256, 128000, 3923, 374, 279, 2262, 315, 420, 2217, 30], # noqa: E501
# 'prompt_token_ids': [128256, 128000, 3923, 374, 279, 2262, 315, 420, 2217, 30], # noqa: E501
# 'prompt': '<|image|><|begin_of_text|>What is the content of this image?', # noqa: E501
# 'multi_modal_data': {'image': <PIL.Image.Image image mode=RGB size=1770x1180 at 0x7FDE2C624880>}, # noqa: E501
# },
@ -108,131 +198,76 @@ def input_processor_for_mllama(
# 'prompt_token_ids': [128000],
# },
# }
processed_token_ids = processed_outputs.pop("input_ids")
start_idx, end_idx = 0, processed_token_ids.size(1)
processed_prompt_text = tokenizer.decode(processed_token_ids[0])
# move encoder prompt to decoder
dec_inputs = TokenInputs(**inputs["encoder"])
hf_processor = self.info.get_hf_processor()
bos_token = hf_processor.bos_token
# Remove the bos_token from the start of the prompt, since the
# encoder prompt is known to consist only of image tokens.
if processed_prompt_text.startswith(bos_token):
start_idx += 1
# Remove the bos_token from the end of the prompt,
# because the text is empty in this case.
if processed_prompt_text.endswith(bos_token):
end_idx -= 1
processed_outputs[
"input_ids"] = processed_token_ids[:, start_idx:end_idx]
else:
processed_outputs = tokenizer(prompt,
add_special_tokens=False,
return_tensors="pt")
return processed_outputs
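# Before/after sketch of the bos trimming above (illustrative ids, using
# bos_token_id=128000 and image token id 128256 as elsewhere in this file):
#   processed ids [128256, 128256, 128000] -> kept [128256, 128256]
#     (text is empty, so the trailing bos_token is dropped)
#   processed ids [128000, 128256, 128256] -> kept [128256, 128256]
#     (a leading bos_token is dropped because the prompt is all image tokens)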
multi_modal_data = dec_inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data:
# text-only
return EncoderDecoderInputs(
encoder=token_inputs([]),
decoder=dec_inputs,
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return dict(
pixel_values=MultiModalFieldConfig.batched("image"),
aspect_ratio_ids=MultiModalFieldConfig.batched("image"),
aspect_ratio_mask=MultiModalFieldConfig.batched("image"),
num_tiles=MultiModalFieldConfig.batched("image"),
)
image_data = multi_modal_data["image"]
if isinstance(image_data, Image.Image):
image_data = [image_data]
def create_encoder_prompt(
self,
prompt: Union[str, list[int]],
mm_data: MultiModalDataDict,
) -> Union[str, list[int]]:
data = mm_data.get("image", [])
num_images = 1 if isinstance(data, Image) else len(data)
image_token_id = self.info.get_hf_config().image_token_index
return [image_token_id] * num_images
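# Example: for mm_data containing two images, this returns
# [image_token_id, image_token_id] (i.e. [128256, 128256] for
# Llama-3.2-Vision); _get_prompt_replacements() below then expands each
# placeholder to one token per encoder slot.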
assert is_list_of(image_data, Image.Image)
def _get_prompt_replacements(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
token_per_chunk = self.info.get_token_per_chunk_from_config()
image_token_id = self.info.get_hf_config().image_token_index
num_image_tokens = dec_inputs['prompt_token_ids'].count(
MLLAMA_IMAGE_TOKEN_ID)
if num_image_tokens != len(image_data):
raise ValueError(
f"The number of image tokens ({num_image_tokens}) must be"
f" the same as the number of images ({len(image_data)})")
# Since only the last group of consecutive images
# is attended to by the decoded tokens, we only need to
# get the number of tiles for those images.
num_decode_images = _get_num_image_in_last_group(
dec_inputs["prompt_token_ids"])
hf_config = ctx.model_config.hf_config
vision_config = hf_config.vision_config
num_tiles = 0
for image in image_data[::-1]:
width, height = image.size
tile_size = vision_config.image_size
canvas_height, canvas_width = get_optimal_tiled_canvas(
image_height=height,
image_width=width,
max_image_tiles=vision_config.max_num_tiles,
tile_size=tile_size,
def get_replacement_mllama(item_idx):
images = mm_items.get_items("image", ImageProcessorItems)
image_size = images.get_image_size(item_idx)
num_tile = self.info.get_num_tiles_per_image(
image_height=image_size.height,
image_width=image_size.width,
)
num_tiles_height = canvas_height // tile_size
num_tiles_width = canvas_width // tile_size
num_tiles += num_tiles_height * num_tiles_width
num_decode_images -= 1
if num_decode_images == 0:
break
num_tokens = num_tile * token_per_chunk
return [image_token_id] * num_tokens
# Set encoder prompt length based on the number of tiles.
# This tells the block manager to allocate the correct number
# of slots for encoder tokens.
assert vision_config.image_size % 14 == 0, \
"chunk size should be multiple of 14"
token_per_chunk = (vision_config.image_size // 14)**2 + 1
num_tokens = num_tiles * token_per_chunk
# Example output from processor:
# {
# 'encoder': {
# 'type': 'token',
# 'prompt_token_ids': [128256, 128256, ..., 128256],
# 'prompt': '<|image|><|image|>...<|image|>',
# 'multi_modal_data': {'image': <PIL.Image.Image image mode=RGB size=1770x1180 at 0x7FDE2C624880>}, # noqa: E501
# },
# 'decoder': {
# 'type': 'token',
# 'prompt_token_ids': [128000, 128256, 128000, 3923, 374, 279, 2262, 315, 420, 2217, 30], # noqa: E501
# 'prompt': '<|image|><|begin_of_text|>What is the content of this image?', # noqa: E501
# 'multi_modal_data': {'image': <PIL.Image.Image image mode=RGB size=1770x1180 at 0x7FDE2C624880>}, # noqa: E501
# },
# }
return EncoderDecoderInputs(
encoder=token_inputs(
prompt_token_ids=[MLLAMA_IMAGE_TOKEN_ID] * num_tokens,
prompt=MLLAMA_IMAGE_TOKEN * num_tokens,
multi_modal_data=multi_modal_data,
),
decoder=dec_inputs,
return [
PromptReplacement(
modality="image",
target=[image_token_id],
replacement=get_replacement_mllama,
)
def get_max_mllama_image_tokens(ctx: InputContext) -> int:
hf_config = ctx.model_config.hf_config
token_per_chunk = (hf_config.vision_config.image_size // 14)**2 + 1
return hf_config.vision_config.max_num_tiles * token_per_chunk
def dummy_decoder_seq_data(seq_len: int, num_images: int):
# <|image|> * num_images + 0 * (seq_len - num_images)
assert seq_len >= num_images, \
"seq_len should be greater than or equal to num_images"
return SequenceData.from_prompt_token_counts(
(MLLAMA_IMAGE_TOKEN_ID, num_images),
(0, seq_len - num_images),
)
def dummy_encoder_seq_data(ctx: InputContext, num_images: int):
num_tokens = get_max_mllama_image_tokens(ctx) * num_images
return SequenceData.from_prompt_token_counts(
(MLLAMA_IMAGE_TOKEN_ID, num_tokens))
def dummy_image(num_images: int, ):
width = height = 1024
image = Image.new("RGB", (width, height), color=0)
return {"image": image if num_images == 1 else [image] * num_images}
def dummy_decoder_data_for_mllama(ctx: InputContext, seq_len: int,
mm_counts: Mapping[str, int]):
num_images = mm_counts["image"]
return DummyData(dummy_decoder_seq_data(seq_len, num_images))
def dummy_encoder_data_for_mllama(ctx: InputContext, seq_len: int,
mm_counts: Mapping[str, int]):
num_images = mm_counts["image"]
return DummyData(dummy_encoder_seq_data(ctx, num_images),
dummy_image(num_images))
]
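# Worked example of the replacement above: an image that tiles into num_tiles
# chunks has its single image_token_id placeholder expanded to
# num_tiles * token_per_chunk copies of image_token_id, which is exactly the
# number of encoder slots the block manager must reserve for that image
# (matching get_mm_max_tokens_per_item() in the worst case).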
def _prepare_aspect_ratio_attention_mask(
@ -1107,11 +1142,9 @@ class MllamaForCausalLM(nn.Module):
return hidden_states
@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_mllama_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_decoder_data_for_mllama)
@INPUT_REGISTRY.register_dummy_encoder_data(dummy_encoder_data_for_mllama)
@INPUT_REGISTRY.register_input_processor(input_processor_for_mllama)
@MULTIMODAL_REGISTRY.register_processor(MllamaMultiModalProcessor,
info=MllamaProcessingInfo,
dummy_inputs=MllamaDummyInputsBuilder)
class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
@ -1120,7 +1153,7 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
config: MllamaConfig = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.quant_config = quant_config
self.vocab_size = config.text_config.vocab_size
@ -1130,6 +1163,7 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
self.pad_token_id = \
config.pad_token_id if config.pad_token_id is not None else -1
self.image_size = config.vision_config.image_size
self.image_token_id = config.image_token_index
self.vision_model = MllamaVisionModel(config.vision_config,
quant_config,
@ -1204,48 +1238,12 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
if pixel_values is not None:
assert aspect_ratio_ids is not None
assert aspect_ratio_mask is not None
max_num_images = max([len(x[0]) for x in pixel_values])
if max_num_images == 0:
raise ValueError("No images provided.")
max_num_tiles = max(
max([len(x) for x in y[0]]) for y in pixel_values)
device = next(self.multi_modal_projector.parameters()).device
bsz = len(pixel_values)
out_num_tiles = []
out_images = torch.zeros(
bsz,
max_num_images,
max_num_tiles,
3,
self.image_size,
self.image_size,
dtype=torch.float32,
device=device,
)
out_ar_ids = torch.ones(bsz,
max_num_images,
dtype=torch.int64,
device=device)
out_ar_mask = torch.zeros(bsz,
max_num_images,
max_num_tiles,
dtype=torch.int64,
device=device)
for b in range(len(pixel_values)):
_num_tiles = []
for i in range(len(pixel_values[b][0])):
img = pixel_values[b][0][i]
out_images[b, i, :img.shape[0]] = img
out_ar_ids[b, i] = aspect_ratio_ids[b][0][i]
out_ar_mask[b, i] = aspect_ratio_mask[b][0][i]
_num_tiles.append(img.shape[0])
out_num_tiles.append(_num_tiles)
return MllamaImagePixelInputs(
type="pixel_values",
data=out_images,
aspect_ratio_ids=out_ar_ids,
aspect_ratio_mask=out_ar_mask,
data=pixel_values,
aspect_ratio_ids=aspect_ratio_ids,
aspect_ratio_mask=aspect_ratio_mask,
)
if image_embeds is not None:
@ -1312,7 +1310,7 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
batch_token_ids.append(token_ids[start:start + seq_len])
start += seq_len
sparse_mask = [
get_cross_attention_token_mask(t, MLLAMA_IMAGE_TOKEN_ID)
get_cross_attention_token_mask(t, self.image_token_id)
for t in batch_token_ids
]
@ -1384,8 +1382,8 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
# block manager to allocate blocks for those images only.
# See input_processor_for_mllama() for more details.
num_tiles_tensor = kwargs.pop("num_tiles")
num_tiles = [t[0].tolist() for t in num_tiles_tensor]
num_tokens_per_tile = (self.image_size // 14)**2 + 1
num_tiles = [t.tolist() for t in num_tiles_tensor]
num_tokens_per_tile = calc_token_per_chunk(self.image_size)
actual_encoder_seq_lens = [
sum(num_tile) * num_tokens_per_tile for num_tile in num_tiles
]


@ -739,3 +739,19 @@ class MultiModalInputs(TypedDict):
For each modality, information about the placeholder tokens in
:code:`prompt_token_ids`.
"""
class MultiModalEncDecInputs(MultiModalInputs):
"""
Represents the outputs of :class:`vllm.multimodal.EncDecMultiModalProcessor`
ready to be passed to vLLM internals.
"""
encoder_prompt: str
"""The processed encoder prompt text."""
encoder_prompt_token_ids: list[int]
"""The processed token IDs of the encoder prompt."""
encoder_token_type_ids: NotRequired[list[int]]
"""The token type IDs of the encoder prompt."""


@ -20,9 +20,9 @@ from vllm.transformers_utils.tokenizer import (AnyTokenizer, decode_tokens,
from vllm.utils import LRUCache, flatten_2d_lists, full_groupby
from .hasher import MultiModalHasher
from .inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputs, MultiModalKwargs, MultiModalKwargsItem,
PlaceholderRange)
from .inputs import (MultiModalDataDict, MultiModalEncDecInputs,
MultiModalFieldConfig, MultiModalInputs, MultiModalKwargs,
MultiModalKwargsItem, PlaceholderRange)
from .parse import MultiModalDataItems, MultiModalDataParser
if TYPE_CHECKING:
@ -1293,3 +1293,57 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_hashes=mm_hashes,
mm_placeholders=mm_placeholder_ranges,
)
class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
@abstractmethod
def create_encoder_prompt(
self,
prompt: Union[str, list[int]],
mm_data: MultiModalDataDict,
) -> Union[str, list[int]]:
"""Create input prompt for the encoder."""
raise NotImplementedError
def apply(
self,
prompt: Union[str, list[int]],
mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object],
) -> MultiModalEncDecInputs:
"""
Process multi-modal inputs to be used in vLLM.
The main processing steps are modified to fit encoder-decoder models:
1. Create the encoder prompt from the input prompt text.
2. Apply the HF processor to the encoder prompt.
3. Copy the input prompt text as the decoder prompt inputs.
"""
encoder_prompt = self.create_encoder_prompt(prompt, mm_data)
encoder_inputs = super().apply(
encoder_prompt,
mm_data,
hf_processor_mm_kwargs,
)
# We assume the decoder prompt text is copied from
# the original input prompt without extra processing.
tokenizer = self.info.get_tokenizer()
if isinstance(prompt, str):
decoder_prompt = prompt
decoder_prompt_ids = encode_tokens(tokenizer,
prompt,
add_special_tokens=False)
else:
decoder_prompt = decode_tokens(tokenizer, prompt)
decoder_prompt_ids = prompt
mm_inputs = MultiModalEncDecInputs(
encoder_prompt=encoder_inputs["prompt"],
encoder_prompt_token_ids=encoder_inputs["prompt_token_ids"],
**encoder_inputs)
mm_inputs.update({
"prompt": decoder_prompt,
"prompt_token_ids": decoder_prompt_ids
})
return mm_inputs
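# End-to-end usage sketch (argument values are illustrative; `image` is
# assumed to be a PIL image and `processor` an EncDecMultiModalProcessor
# subclass instance):
#
#   inputs = processor.apply(
#       "<|image|><|begin_of_text|>What is the content of this image?",
#       mm_data={"image": image},
#       hf_processor_mm_kwargs={},
#   )
#   inputs["prompt_token_ids"]          # decoder ids, encoded without bos
#   inputs["encoder_prompt_token_ids"]  # encoder ids built from
#                                       # create_encoder_prompt()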


@ -144,7 +144,11 @@ class MultiModalProfiler(Generic[_I]):
hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
)
def get_dummy_data(self, seq_len: int) -> DummyData:
def get_dummy_data(
self,
seq_len: int,
is_encoder_data: bool = False,
) -> DummyData:
# Avoid circular import
from vllm.sequence import SequenceData
@ -183,16 +187,18 @@ class MultiModalProfiler(Generic[_I]):
total_len = len(prompt_token_ids)
# V0 does not support chunked prefill.
if total_len > seq_len and not envs.VLLM_USE_V1:
if (total_len > seq_len and not envs.VLLM_USE_V1) or is_encoder_data:
if total_len > seq_len:
logger.warning(
"The context length (%d) of the model is too short "
"to hold the multi-modal embeddings in the worst case "
"(%d tokens in total, out of which %s are reserved for "
"multi-modal embeddings). This may cause certain multi-modal "
"inputs to fail during inference, even when the input text is "
"short. To avoid this, you should increase `max_model_len`, "
"reduce `max_num_seqs`, and/or reduce `mm_counts`.", seq_len,
total_len, total_placeholders_by_modality)
"multi-modal embeddings). This may cause certain "
"multi-modal inputs to fail during inference, even when "
"the input text is short. To avoid this, you should "
"increase `max_model_len`, reduce `max_num_seqs`, "
"and/or reduce `mm_counts`.", seq_len, total_len,
total_placeholders_by_modality)
return DummyData(
seq_data=SequenceData.from_prompt_token_counts((0, seq_len)),