[VLM] Implement merged multimodal processor for Mllama (#11427)
This commit is contained in:
parent d88c8666a1
commit bc55d13070
@@ -7,11 +7,11 @@ import torch
from transformers import (AutoConfig, AutoModelForVision2Seq, AutoTokenizer,
                          BatchEncoding)

from vllm import LLM, SamplingParams
from vllm.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.attention.selector import (_Backend, _cached_get_attn_backend,
                                     global_force_attn_backend_context_manager)
from vllm.model_executor.models.mllama import (MLLAMA_IMAGE_TOKEN_ID,
                                               MllamaForConditionalGeneration)
from vllm.model_executor.models.mllama import MllamaForConditionalGeneration
from vllm.multimodal.image import rescale_image_size
from vllm.sequence import SampleLogprobs

@@ -21,6 +21,7 @@ from ....utils import large_gpu_test
from ...utils import check_logprobs_close

_LIMIT_IMAGE_PER_PROMPT = 3
MLLAMA_IMAGE_TOKEN_ID = 128256

LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS, _Backend.FLASH_ATTN]

@@ -396,6 +397,64 @@ def test_models_interleaved_images(hf_runner, vllm_runner, image_assets, model,
    )


@large_gpu_test(min_gb=48)
@pytest.mark.core_model
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [32])
def test_explicit_implicit_prompt(
    image_assets: _ImageAssets,
    model: str,
    dtype: str,
    max_tokens: int,
):
    stop_sign = image_assets[0].pil_image
    # yapf: disable
    prompts = [
        # explicit prompt
        {
            "encoder_prompt": {
                "prompt": "<|image|>",
                "multi_modal_data": {"image": stop_sign},
            },
            "decoder_prompt": {
                "prompt_token_ids": [128000, 791, 2262, 315, 279, 2217, 220, 128256, 374],  # noqa: E501
            }
        },
        {
            "encoder_prompt": "Not <|image|>",
            "decoder_prompt": "The color of the sky is blue but sometimes it can also be",  # noqa: E501
        },
        # implicit prompt
        {
            "prompt": "<|begin_of_text|>The content of the image <|image|> is",  # noqa: E501
            "multi_modal_data": {"image": stop_sign},
        },
        {
            "prompt": "The color of the sky is blue but sometimes it can also be",  # noqa: E501
        },
    ]
    # yapf: enable
    llm = LLM(
        model=model,
        dtype=dtype,
        max_model_len=4096,
        max_num_seqs=2,
        tensor_parallel_size=1,
        enforce_eager=True,
    )
    sampling_params = SamplingParams(
        temperature=0,
        max_tokens=max_tokens,
    )
    outputs = llm.generate(prompts, sampling_params)
    n_prompts = len(prompts)
    explicit_outputs = outputs[:n_prompts // 2]
    implicit_outputs = outputs[n_prompts // 2:]
    for exp_output, imp_output in zip(explicit_outputs, implicit_outputs):
        assert exp_output.outputs[0].text == imp_output.outputs[0].text
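
# How the explicit and implicit prompts above line up (a reading of the test;
# the tokenization is an assumption, not verified here): the hard-coded
# decoder ids [128000, 791, 2262, 315, 279, 2217, 220, 128256, 374] appear to
# be the tokenization of "<|begin_of_text|>The content of the image <|image|>
# is", with 128256 being the <|image|> token, so each explicit prompt is
# expected to generate the same text as its implicit counterpart, which is
# what the final assert checks.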

@large_gpu_test(min_gb=48)
@pytest.mark.core_model
@pytest.mark.parametrize("model", models)
@@ -458,6 +517,10 @@ def test_regression(vllm_runner, image_assets, model, dtype, max_tokens,
            images=images)


class DummyModel:
    image_token_id = MLLAMA_IMAGE_TOKEN_ID


@pytest.mark.core_model
@pytest.mark.parametrize(
    "input_indices_and_output",
@@ -499,7 +562,7 @@ def test_get_cross_attention_mask(input_indices_and_output) -> None:
        use_cuda_graph=False,
    )

    dummy: dict[str, str] = {}
    dummy = DummyModel()

    cross_attention_mask, kv_range_for_decode = MllamaForConditionalGeneration\
        .get_cross_attention_mask(dummy,
@@ -556,7 +619,7 @@ def test_get_full_text_row_masked_out_mask(input_indices) -> None:
        use_cuda_graph=False,
    )

    dummy: dict[str, str] = {}
    dummy = DummyModel()

    full_text_row_masked_out_mask = MllamaForConditionalGeneration\
        .get_full_text_row_masked_out_mask(dummy,

@@ -85,6 +85,14 @@ def _test_processing_correctness(
        partial(random_audio, rng, min_len=512, max_len=1024, sr=16000),
    }

    tokenizer_encode_kwargs = {}
    if model_config.hf_config.model_type == "mllama":
        # For Mllama, the tokenizer always adds a bos_token at the beginning
        # of the prompt by default, causing the hf_processor to output
        # incorrect token ids. So we need to use `add_special_tokens=False`
        # here and leave the bos_token to be added by the processor.
        tokenizer_encode_kwargs = {"add_special_tokens": False}
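        # A minimal sketch of the failure mode (token ids are illustrative
        # assumptions, not taken from this commit):
        #   tokenizer.encode("<|image|>What is this?")
        #       -> [128000, 128256, ...]   # bos already added by tokenizer
        #   and the HF processor would then prepend its own bos on top,
        #   whereas
        #   tokenizer.encode("<|image|>What is this?", add_special_tokens=False)
        #       -> [128256, ...]           # bos left for the processor to add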

    for batch_idx in range(num_batches):
        mm_data = {
            k:
@@ -122,7 +130,7 @@ def _test_processing_correctness(
            f"Failed ({batch_idx=}, {prompt=}, {mm_data=})")

        baseline_tokenized_result = baseline_processor.apply(
            tokenizer.encode(prompt),
            tokenizer.encode(prompt, **tokenizer_encode_kwargs),
            mm_data=mm_data,
            hf_processor_mm_kwargs={},
        )
@@ -131,7 +139,7 @@ def _test_processing_correctness(
            f"Failed ({batch_idx=}, {prompt=}, {mm_data=})")

        cached_tokenized_result = cached_processor.apply(
            tokenizer.encode(prompt),
            tokenizer.encode(prompt, **tokenizer_encode_kwargs),
            mm_data=mm_data,
            hf_processor_mm_kwargs={},
        )
@@ -155,6 +163,7 @@ def _test_processing_correctness(
    "llava-hf/llava-v1.6-mistral-7b-hf",
    "llava-hf/LLaVA-NeXT-Video-7B-hf",
    "llava-hf/llava-onevision-qwen2-0.5b-ov-hf",
    "meta-llama/Llama-3.2-11B-Vision-Instruct",
    "TIGER-Lab/Mantis-8B-siglip-llama3",
    "mistral-community/pixtral-12b",
    "openbmb/MiniCPM-o-2_6",

@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0

import asyncio
from typing import List, Mapping, Optional, Union
from typing import List, Mapping, Optional, Tuple, Union, cast

from typing_extensions import assert_never

@@ -9,7 +9,8 @@ from vllm.config import ModelConfig
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal.inputs import MultiModalDataDict, MultiModalInputs
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs,
                                    MultiModalInputs)
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup

@@ -495,6 +496,51 @@ class InputPreprocessor:
            decoder=decoder_inputs,
        )

    def _separate_enc_dec_inputs_from_mm_processor_outputs(
        self,
        inputs: SingletonInputs,
        decoder_inputs_to_override: Optional[SingletonInputs] = None,
    ) -> Tuple[SingletonInputs, SingletonInputs]:
        """
        For encoder/decoder models only:
        Separate Encoder/Decoder inputs from a MultiModalEncDecInputs
        """
        encoder_inputs: SingletonInputs
        decoder_inputs: SingletonInputs
        if inputs["type"] == "multimodal":
            # Multimodal data inputs
            assert ("encoder_prompt" in inputs
                    and "encoder_prompt_token_ids" in inputs)
            inputs = cast(MultiModalEncDecInputs, inputs)
            encoder_inputs = token_inputs(
                prompt=inputs["encoder_prompt"],
                prompt_token_ids=inputs["encoder_prompt_token_ids"],
            )
            if decoder_inputs_to_override is not None:
                decoder_inputs = MultiModalInputs(
                    type="multimodal",
                    prompt=decoder_inputs_to_override.get("prompt", ""),
                    prompt_token_ids=decoder_inputs_to_override[
                        "prompt_token_ids"],
                    mm_kwargs=inputs["mm_kwargs"],
                    mm_placeholders=inputs["mm_placeholders"],
                )
            else:
                decoder_inputs = MultiModalInputs(
                    type="multimodal",
                    prompt=inputs["prompt"],
                    prompt_token_ids=inputs["prompt_token_ids"],
                    mm_kwargs=inputs["mm_kwargs"],
                    mm_placeholders=inputs["mm_placeholders"],
                )
        elif inputs["type"] == "token":
            # Text-only inputs
            encoder_inputs = token_inputs(prompt="", prompt_token_ids=[])
            decoder_inputs = decoder_inputs_to_override or inputs
        else:
            assert_never(inputs)  # type: ignore[arg-type]
        return encoder_inputs, decoder_inputs
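
    # Roughly what the separation yields for an Mllama-style processor output
    # (an illustrative sketch; the concrete values are assumptions, not taken
    # from this commit): given
    #   {"type": "multimodal",
    #    "prompt": "<|image|>What is in the image?",
    #    "prompt_token_ids": [...],
    #    "encoder_prompt": "<|image|>",
    #    "encoder_prompt_token_ids": [128256, ...],   # image placeholder ids
    #    "mm_kwargs": ..., "mm_placeholders": ...}
    # the encoder side becomes
    #   token_inputs(prompt="<|image|>", prompt_token_ids=[128256, ...])
    # and the decoder side keeps the original prompt together with mm_kwargs
    # and mm_placeholders (or the explicit decoder prompt, if one was given).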

    def _process_encoder_decoder_prompt(
        self,
        prompt: PromptType,
@@ -539,7 +585,6 @@ class InputPreprocessor:
                prompt["encoder_prompt"],
                request_id=request_id,
            )

            if (decoder_input := prompt["decoder_prompt"]) is None:
                decoder_inputs = None
            else:
@@ -547,13 +592,28 @@ class InputPreprocessor:
                    decoder_input,
                    request_id=request_id,
                )
                # For multimodal model, override decoder prompt from processor
                # with explicit decoder prompt.
                if self.model_config.is_multimodal_model and (
                        self._can_process_multimodal()):
                    encoder_inputs, decoder_inputs = (
                        self._separate_enc_dec_inputs_from_mm_processor_outputs(
                            encoder_inputs, decoder_inputs))
        else:
            encoder_inputs = self._prompt_to_llm_inputs(
            inputs = self._prompt_to_llm_inputs(
                prompt,
                request_id=request_id,
            )
            if self.model_config.is_multimodal_model and (
                    self._can_process_multimodal()):
                # Encoder-Decoder Multimodal model
                encoder_inputs, decoder_inputs = (
                    self._separate_enc_dec_inputs_from_mm_processor_outputs(
                        inputs))
            else:
                encoder_inputs = inputs

            decoder_inputs = None
                decoder_inputs = None

        return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs)

@@ -583,13 +643,29 @@ class InputPreprocessor:

                encoder_inputs, decoder_inputs = await asyncio.gather(
                    encoder_task, decoder_task)

                # For multimodal model, override decoder prompt from processor
                # with explicit decoder prompt.
                if self.model_config.is_multimodal_model and (
                        self._can_process_multimodal()):
                    encoder_inputs, decoder_inputs = (
                        self._separate_enc_dec_inputs_from_mm_processor_outputs(
                            encoder_inputs, decoder_inputs))
        else:
            encoder_inputs = await self._prompt_to_llm_inputs_async(
            inputs = await self._prompt_to_llm_inputs_async(
                prompt,
                request_id=request_id,
            )
            if self.model_config.is_multimodal_model and (
                    self._can_process_multimodal()):
                # Encoder-Decoder Multimodal model
                encoder_inputs, decoder_inputs = (
                    self._separate_enc_dec_inputs_from_mm_processor_outputs(
                        inputs))
            else:
                encoder_inputs = inputs

            decoder_inputs = None
                decoder_inputs = None

        return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs)


@@ -350,7 +350,8 @@ class InputRegistry:
            )
            processor = mm_registry.create_processor(model_config, tokenizer)
            profiler = MultiModalProfiler(processor)
            dummy_data = profiler.get_dummy_data(seq_len)
            dummy_data = profiler.get_dummy_data(
                seq_len, is_encoder_data=is_encoder_data)
        else:
            model_cls, _ = get_model_architecture(model_config)
            if is_encoder_data:

@@ -23,14 +23,15 @@ import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers.models.mllama.configuration_mllama as config_mllama
from PIL import Image
from PIL.Image import Image
from torch import nn
from transformers import BatchFeature, MllamaConfig
from transformers.modeling_outputs import (BaseModelOutput,
                                           CausalLMOutputWithPast)
from transformers.models.mllama.image_processing_mllama import (
    get_optimal_tiled_canvas)
from transformers.models.mllama.processing_mllama import (
    get_cross_attention_token_mask)
    MllamaProcessor, get_cross_attention_token_mask)

import vllm.distributed.parallel_state as ps
from vllm.attention import Attention, AttentionMetadata, AttentionType
@@ -38,8 +39,6 @@ from vllm.attention.ops.paged_attn import PagedAttention
from vllm.attention.selector import _Backend
from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.inputs import (INPUT_REGISTRY, DummyData, EncoderDecoderInputs,
                         InputContext, TokenInputs, token_inputs)
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -54,8 +53,13 @@ from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import SequenceData
from vllm.utils import is_list_of
from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
                                   MultiModalDataDict, MultiModalDataItems)
from vllm.multimodal.processing import (BaseProcessingInfo,
                                        EncDecMultiModalProcessor,
                                        PromptReplacement)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs

from .clip import CLIPMLP
from .interfaces import SupportsMultiModal
@@ -63,8 +67,6 @@ from .llama import LlamaDecoderLayer, LlamaMLP
from .utils import maybe_prefix

logger = init_logger(__name__)
MLLAMA_IMAGE_TOKEN_ID = 128256
MLLAMA_IMAGE_TOKEN = "<|image|>"


class MllamaImagePixelInputs(TypedDict):
@@ -81,158 +83,191 @@ class MllamaImagePixelInputs(TypedDict):
    # TODO: support LlamaImageEmbeddingInputs


def _get_num_image_in_last_group(prompt_token_ids: List[int]) -> int:
    num_images = 0
    for token_id in prompt_token_ids[::-1]:
        if token_id == MLLAMA_IMAGE_TOKEN_ID:
            num_images += 1
        elif num_images > 0:
            break
    return num_images
def calc_token_per_chunk(image_size: int) -> int:
    assert image_size % 14 == 0, "chunk size should be multiple of 14"
    token_per_chunk = (image_size // 14)**2 + 1
    return token_per_chunk
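
# For example, assuming the Llama 3.2 Vision defaults (image_size=560,
# max_num_tiles=4; values assumed, not taken from this commit), each 560x560
# tile contributes (560 // 14) ** 2 + 1 == 1601 tokens, so a 4-tile image
# needs 4 * 1601 == 6404 encoder slots.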

def input_processor_for_mllama(
    ctx: InputContext,
    inputs: EncoderDecoderInputs,
) -> EncoderDecoderInputs:
    # Example input to processor:
    # {
    #     'encoder': {
    #         'type': 'token',
    #         'prompt_token_ids': [128000, 128256, 128000, 3923, 374, 279, 2262, 315, 420, 2217, 30],  # noqa: E501
    #         'prompt': '<|image|><|begin_of_text|>What is the content of this image?',  # noqa: E501
    #         'multi_modal_data': {'image': <PIL.Image.Image image mode=RGB size=1770x1180 at 0x7FDE2C624880>},  # noqa: E501
    #     },
    #     'decoder': {
    #         'type': 'token',
    #         'prompt_token_ids': [128000],
    #     },
    # }
class MllamaProcessingInfo(BaseProcessingInfo):

    # move encoder prompt to decoder
    dec_inputs = TokenInputs(**inputs["encoder"])
    def get_hf_config(self) -> MllamaConfig:
        return self.ctx.get_hf_config(MllamaConfig)

    multi_modal_data = dec_inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        # text-only
        return EncoderDecoderInputs(
            encoder=token_inputs([]),
            decoder=dec_inputs,
    def get_hf_processor(self) -> MllamaProcessor:
        return self.ctx.get_hf_processor(MllamaProcessor)

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        return {"image": None}

    def get_token_per_chunk_from_config(self) -> int:
        image_size = self.get_hf_config().vision_config.image_size
        return calc_token_per_chunk(image_size)

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        vision_config = self.get_hf_config().vision_config
        token_per_chunk = self.get_token_per_chunk_from_config()
        mm_max_tokens = vision_config.max_num_tiles * token_per_chunk
        return {"image": mm_max_tokens}

    def get_num_tiles_per_image(self, image_height: int,
                                image_width: int) -> int:
        vision_config = self.get_hf_config().vision_config
        max_num_tiles = vision_config.max_num_tiles
        image_size = vision_config.image_size
        tiled_height, tiled_width = get_optimal_tiled_canvas(
            image_height,
            image_width,
            max_num_tiles,
            tile_size=image_size,
        )
        num_tiles_height = tiled_height // image_size
        num_tiles_width = tiled_width // image_size
        return num_tiles_height * num_tiles_width
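
    # For example (a sketch, assuming image_size=560 and max_num_tiles=4 as in
    # the Llama 3.2 Vision configs): a 1024x1024 image is fitted to a
    # 1120x1120 optimal canvas, giving (1120 // 560) * (1120 // 560) == 4
    # tiles.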

    def get_image_size_with_most_features(self) -> ImageSize:
        vision_config = self.get_hf_config().vision_config
        image_size = vision_config.image_size
        max_num_tiles = vision_config.max_num_tiles
        # Result in the max possible feature size (h:w = 16:1)
        return ImageSize(height=max_num_tiles * image_size, width=image_size)


class MllamaDummyInputsBuilder(BaseDummyInputsBuilder[MllamaProcessingInfo]):

    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        num_images = mm_counts.get("image", 0)

        target_width, target_height = \
            self.info.get_image_size_with_most_features()

        mm_data = {
            "image":
            self._get_dummy_images(width=target_width,
                                   height=target_height,
                                   num_images=num_images)
        }

        hf_processor = self.info.get_hf_processor()
        image_token: str = hf_processor.image_token

        return ProcessorInputs(
            prompt_text=image_token * num_images,
            mm_data=mm_data,
        )
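
    # Concretely (a sketch, assuming image_size=560 and max_num_tiles=4): for
    # mm_counts={"image": 2} this builds the prompt "<|image|><|image|>" with
    # two dummy 560x2240 (width x height) images, the worst-case feature size.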

    image_data = multi_modal_data["image"]
    if isinstance(image_data, Image.Image):
        image_data = [image_data]

    assert is_list_of(image_data, Image.Image)
class MllamaMultiModalProcessor(EncDecMultiModalProcessor[MllamaProcessingInfo]
                                ):

    num_image_tokens = dec_inputs['prompt_token_ids'].count(
        MLLAMA_IMAGE_TOKEN_ID)
    if num_image_tokens != len(image_data):
        raise ValueError(
            f"The number of image tokens ({num_image_tokens}) must be"
            f" the same as the number of images ({len(image_data)})")
    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        tokenizer = self.info.get_tokenizer()
        if mm_data:
            num_tiles = [
                self.info.get_num_tiles_per_image(img.height, img.width)
                for img in mm_data["images"]
            ]
            processed_outputs = super()._call_hf_processor(
                prompt, mm_data, mm_kwargs)
            processed_outputs["num_tiles"] = torch.tensor(num_tiles)
            for k in ('pixel_values', 'aspect_ratio_ids', "aspect_ratio_mask"):
                processed_outputs[k] = processed_outputs[k].squeeze(0)
    # Example input to encoder and decoder:
    # {
    #     'encoder': {
    #         'type': 'token',
    #         'prompt_token_ids': [128256, 128000, 3923, 374, 279, 2262, 315, 420, 2217, 30],  # noqa: E501
    #         'prompt': '<|image|><|begin_of_text|>What is the content of this image?',  # noqa: E501
    #         'multi_modal_data': {'image': <PIL.Image.Image image mode=RGB size=1770x1180 at 0x7FDE2C624880>},  # noqa: E501
    #     },
    #     'decoder': {
    #         'type': 'token',
    #         'prompt_token_ids': [128000],
    #     },
    # }
            processed_token_ids = processed_outputs.pop("input_ids")
            start_idx, end_idx = 0, processed_token_ids.size(1)
            processed_prompt_text = tokenizer.decode(processed_token_ids[0])

    # Since only the last group of consecutive images
    # are attended by the decoded tokens, we only need to
    # get the number of tiles for those images.
    num_decode_images = _get_num_image_in_last_group(
        dec_inputs["prompt_token_ids"])
            hf_processor = self.info.get_hf_processor()
            bos_token = hf_processor.bos_token
            # Remove the bos_token at the start of the prompt, since the
            # encoder prompt is made up of image tokens only.
            if processed_prompt_text.startswith(bos_token):
                start_idx += 1
            # Remove the bos_token at the end of the prompt, since the
            # remaining text is empty in this case.
            if processed_prompt_text.endswith(bos_token):
                end_idx -= 1
            processed_outputs[
                "input_ids"] = processed_token_ids[:, start_idx:end_idx]
        else:
            processed_outputs = tokenizer(prompt,
                                          add_special_tokens=False,
                                          return_tensors="pt")
        return processed_outputs

    hf_config = ctx.model_config.hf_config
    vision_config = hf_config.vision_config

    num_tiles = 0
    for image in image_data[::-1]:
        width, height = image.size
        tile_size = vision_config.image_size
        canvas_height, canvas_width = get_optimal_tiled_canvas(
            image_height=height,
            image_width=width,
            max_image_tiles=vision_config.max_num_tiles,
            tile_size=tile_size,
    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        return dict(
            pixel_values=MultiModalFieldConfig.batched("image"),
            aspect_ratio_ids=MultiModalFieldConfig.batched("image"),
            aspect_ratio_mask=MultiModalFieldConfig.batched("image"),
            num_tiles=MultiModalFieldConfig.batched("image"),
        )
        num_tiles_height = canvas_height // tile_size
        num_tiles_width = canvas_width // tile_size
        num_tiles += num_tiles_height * num_tiles_width
        num_decode_images -= 1
        if num_decode_images == 0:
            break

    # Set encoder prompt length based on the number of tiles.
    # This tells the block manager to allocate correct number
    # of slots for encoder tokens.
    assert vision_config.image_size % 14 == 0, \
        "chunk size should be multiple of 14"
    token_per_chunk = (vision_config.image_size // 14)**2 + 1
    num_tokens = num_tiles * token_per_chunk
    def create_encoder_prompt(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
    ) -> Union[str, list[int]]:
        data = mm_data.get("image", [])
        num_images = 1 if isinstance(data, Image) else len(data)
        image_token_id = self.info.get_hf_config().image_token_index
        return [image_token_id] * num_images
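
    # For instance (a sketch): with two input images the encoder prompt above
    # is [128256, 128256]; _get_prompt_replacements below then expands each
    # image token into num_tiles * token_per_chunk placeholder tokens.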

    # Example output from processor:
    # {
    #     'encoder': {
    #         'type': 'token',
    #         'prompt_token_ids': [128256, 128256, ..., 128256],
    #         'prompt': '<|image|><|image|>...<|image|>',
    #         'multi_modal_data': {'image': <PIL.Image.Image image mode=RGB size=1770x1180 at 0x7FDE2C624880>},  # noqa: E501
    #     },
    #     'decoder': {
    #         'type': 'token',
    #         'prompt_token_ids': [128000, 128256, 128000, 3923, 374, 279, 2262, 315, 420, 2217, 30],  # noqa: E501
    #         'prompt': '<|image|><|begin_of_text|>What is the content of this image?',  # noqa: E501
    #         'multi_modal_data': {'image': <PIL.Image.Image image mode=RGB size=1770x1180 at 0x7FDE2C624880>},  # noqa: E501
    #     },
    # }
    return EncoderDecoderInputs(
        encoder=token_inputs(
            prompt_token_ids=[MLLAMA_IMAGE_TOKEN_ID] * num_tokens,
            prompt=MLLAMA_IMAGE_TOKEN * num_tokens,
            multi_modal_data=multi_modal_data,
        ),
        decoder=dec_inputs,
    )
    def _get_prompt_replacements(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> list[PromptReplacement]:
        token_per_chunk = self.info.get_token_per_chunk_from_config()
        image_token_id = self.info.get_hf_config().image_token_index

        def get_replacement_mllama(item_idx):
            images = mm_items.get_items("image", ImageProcessorItems)
            image_size = images.get_image_size(item_idx)
            num_tile = self.info.get_num_tiles_per_image(
                image_height=image_size.height,
                image_width=image_size.width,
            )
            num_tokens = num_tile * token_per_chunk
            return [image_token_id] * num_tokens

def get_max_mllama_image_tokens(ctx: InputContext) -> int:
    hf_config = ctx.model_config.hf_config
    token_per_chunk = (hf_config.vision_config.image_size // 14)**2 + 1
    return hf_config.vision_config.max_num_tiles * token_per_chunk


def dummy_decoder_seq_data(seq_len: int, num_images: int):
    # <|image|> * num_images + 0 * (seq_len - num_images)
    assert seq_len >= num_images, \
        "seq_len should be greater than or equal to num_images"

    return SequenceData.from_prompt_token_counts(
        (MLLAMA_IMAGE_TOKEN_ID, num_images),
        (0, seq_len - num_images),
    )


def dummy_encoder_seq_data(ctx: InputContext, num_images: int):
    num_tokens = get_max_mllama_image_tokens(ctx) * num_images

    return SequenceData.from_prompt_token_counts(
        (MLLAMA_IMAGE_TOKEN_ID, num_tokens))


def dummy_image(num_images: int, ):
    width = height = 1024
    image = Image.new("RGB", (width, height), color=0)
    return {"image": image if num_images == 1 else [image] * num_images}


def dummy_decoder_data_for_mllama(ctx: InputContext, seq_len: int,
                                  mm_counts: Mapping[str, int]):
    num_images = mm_counts["image"]
    return DummyData(dummy_decoder_seq_data(seq_len, num_images))


def dummy_encoder_data_for_mllama(ctx: InputContext, seq_len: int,
                                  mm_counts: Mapping[str, int]):
    num_images = mm_counts["image"]
    return DummyData(dummy_encoder_seq_data(ctx, num_images),
                     dummy_image(num_images))
        return [
            PromptReplacement(
                modality="image",
                target=[image_token_id],
                replacement=get_replacement_mllama,
            )
        ]


def _prepare_aspect_ratio_attention_mask(
@@ -1107,11 +1142,9 @@ class MllamaForCausalLM(nn.Module):
        return hidden_states


@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_mllama_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_decoder_data_for_mllama)
@INPUT_REGISTRY.register_dummy_encoder_data(dummy_encoder_data_for_mllama)
@INPUT_REGISTRY.register_input_processor(input_processor_for_mllama)
@MULTIMODAL_REGISTRY.register_processor(MllamaMultiModalProcessor,
                                        info=MllamaProcessingInfo,
                                        dummy_inputs=MllamaDummyInputsBuilder)
class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
@@ -1120,7 +1153,7 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        config: MllamaConfig = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.quant_config = quant_config
        self.vocab_size = config.text_config.vocab_size
@@ -1130,6 +1163,7 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
        self.pad_token_id = \
            config.pad_token_id if config.pad_token_id is not None else -1
        self.image_size = config.vision_config.image_size
        self.image_token_id = config.image_token_index

        self.vision_model = MllamaVisionModel(config.vision_config,
                                              quant_config,
@@ -1204,48 +1238,12 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
        if pixel_values is not None:
            assert aspect_ratio_ids is not None
            assert aspect_ratio_mask is not None
            max_num_images = max([len(x[0]) for x in pixel_values])
            if max_num_images == 0:
                raise ValueError("No images provided.")
            max_num_tiles = max(
                max([len(x) for x in y[0]]) for y in pixel_values)
            device = next(self.multi_modal_projector.parameters()).device
            bsz = len(pixel_values)
            out_num_tiles = []
            out_images = torch.zeros(
                bsz,
                max_num_images,
                max_num_tiles,
                3,
                self.image_size,
                self.image_size,
                dtype=torch.float32,
                device=device,
            )
            out_ar_ids = torch.ones(bsz,
                                    max_num_images,
                                    dtype=torch.int64,
                                    device=device)
            out_ar_mask = torch.zeros(bsz,
                                      max_num_images,
                                      max_num_tiles,
                                      dtype=torch.int64,
                                      device=device)
            for b in range(len(pixel_values)):
                _num_tiles = []
                for i in range(len(pixel_values[b][0])):
                    img = pixel_values[b][0][i]
                    out_images[b, i, :img.shape[0]] = img
                    out_ar_ids[b, i] = aspect_ratio_ids[b][0][i]
                    out_ar_mask[b, i] = aspect_ratio_mask[b][0][i]
                    _num_tiles.append(img.shape[0])
                out_num_tiles.append(_num_tiles)

            return MllamaImagePixelInputs(
                type="pixel_values",
                data=out_images,
                aspect_ratio_ids=out_ar_ids,
                aspect_ratio_mask=out_ar_mask,
                data=pixel_values,
                aspect_ratio_ids=aspect_ratio_ids,
                aspect_ratio_mask=aspect_ratio_mask,
            )

        if image_embeds is not None:
@@ -1312,7 +1310,7 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
            batch_token_ids.append(token_ids[start:start + seq_len])
            start += seq_len
        sparse_mask = [
            get_cross_attention_token_mask(t, MLLAMA_IMAGE_TOKEN_ID)
            get_cross_attention_token_mask(t, self.image_token_id)
            for t in batch_token_ids
        ]

@@ -1384,8 +1382,8 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
            # block manager to allocate blocks for those images only.
            # See input_processor_for_mllama() for more details.
            num_tiles_tensor = kwargs.pop("num_tiles")
            num_tiles = [t[0].tolist() for t in num_tiles_tensor]
            num_tokens_per_tile = (self.image_size // 14)**2 + 1
            num_tiles = [t.tolist() for t in num_tiles_tensor]
            num_tokens_per_tile = calc_token_per_chunk(self.image_size)
            actual_encoder_seq_lens = [
                sum(num_tile) * num_tokens_per_tile for num_tile in num_tiles
            ]

@@ -739,3 +739,19 @@ class MultiModalInputs(TypedDict):
    For each modality, information about the placeholder tokens in
    :code:`prompt_token_ids`.
    """


class MultiModalEncDecInputs(MultiModalInputs):
    """
    Represents the outputs of :class:`vllm.multimodal.EncDecMultiModalProcessor`
    ready to be passed to vLLM internals.
    """

    encoder_prompt: str
    """The processed encoder prompt text."""

    encoder_prompt_token_ids: list[int]
    """The processed token IDs of the encoder prompt."""

    encoder_token_type_ids: NotRequired[list[int]]
    """The token type IDs of the encoder prompt."""
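
    # A minimal example instance (a sketch; field values are assumptions, not
    # taken from this commit):
    #   MultiModalEncDecInputs(
    #       type="multimodal",
    #       prompt="<|image|>What is shown?",      # decoder prompt text
    #       prompt_token_ids=[128000, ...],        # decoder token ids
    #       encoder_prompt="<|image|>",
    #       encoder_prompt_token_ids=[128256],
    #       mm_kwargs=...,                         # processed image tensors
    #       mm_placeholders=...,
    #   )
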
@@ -20,9 +20,9 @@ from vllm.transformers_utils.tokenizer import (AnyTokenizer, decode_tokens,
from vllm.utils import LRUCache, flatten_2d_lists, full_groupby

from .hasher import MultiModalHasher
from .inputs import (MultiModalDataDict, MultiModalFieldConfig,
                     MultiModalInputs, MultiModalKwargs, MultiModalKwargsItem,
                     PlaceholderRange)
from .inputs import (MultiModalDataDict, MultiModalEncDecInputs,
                     MultiModalFieldConfig, MultiModalInputs, MultiModalKwargs,
                     MultiModalKwargsItem, PlaceholderRange)
from .parse import MultiModalDataItems, MultiModalDataParser

if TYPE_CHECKING:
@@ -1293,3 +1293,57 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
            mm_hashes=mm_hashes,
            mm_placeholders=mm_placeholder_ranges,
        )


class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):

    @abstractmethod
    def create_encoder_prompt(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
    ) -> Union[str, list[int]]:
        """Create input prompt for the encoder."""
        raise NotImplementedError

    def apply(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> MultiModalEncDecInputs:
        """
        Process multi-modal inputs to be used in vLLM.
        The main processing steps are modified to fit encoder-decoder models:
        1. Create the encoder prompt from the input prompt text.
        2. Apply the HF processor on the encoder prompt.
        3. Copy the input prompt text as decoder prompt inputs.
        """
        encoder_prompt = self.create_encoder_prompt(prompt, mm_data)
        encoder_inputs = super().apply(
            encoder_prompt,
            mm_data,
            hf_processor_mm_kwargs,
        )

        # We assume the decoder prompt text is copied from
        # the original encoder prompt without extra processing.
        tokenizer = self.info.get_tokenizer()
        if isinstance(prompt, str):
            decoder_prompt = prompt
            decoder_prompt_ids = encode_tokens(tokenizer,
                                               prompt,
                                               add_special_tokens=False)
        else:
            decoder_prompt = decode_tokens(tokenizer, prompt)
            decoder_prompt_ids = prompt

        mm_inputs = MultiModalEncDecInputs(
            encoder_prompt=encoder_inputs["prompt"],
            encoder_prompt_token_ids=encoder_inputs["prompt_token_ids"],
            **encoder_inputs)
        mm_inputs.update({
            "prompt": decoder_prompt,
            "prompt_token_ids": decoder_prompt_ids
        })
        return mm_inputs

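    # Usage sketch (illustrative; the processor instance and the arguments are
    # assumptions, not taken from this commit):
    #   processor = MULTIMODAL_REGISTRY.create_processor(model_config, tokenizer)
    #   out = processor.apply("<|image|>Describe the image",
    #                         mm_data={"image": pil_image},
    #                         hf_processor_mm_kwargs={})
    #   out["encoder_prompt_token_ids"]  # expanded image placeholder tokens
    #   out["prompt_token_ids"]          # decoder copy of the input prompt
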
@@ -144,7 +144,11 @@ class MultiModalProfiler(Generic[_I]):
            hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
        )

    def get_dummy_data(self, seq_len: int) -> DummyData:
    def get_dummy_data(
        self,
        seq_len: int,
        is_encoder_data: bool = False,
    ) -> DummyData:
        # Avoid circular import
        from vllm.sequence import SequenceData

@@ -183,16 +187,18 @@ class MultiModalProfiler(Generic[_I]):
        total_len = len(prompt_token_ids)

        # V0 does not support chunked prefill.
        if total_len > seq_len and not envs.VLLM_USE_V1:
            logger.warning(
                "The context length (%d) of the model is too short "
                "to hold the multi-modal embeddings in the worst case "
                "(%d tokens in total, out of which %s are reserved for "
                "multi-modal embeddings). This may cause certain multi-modal "
                "inputs to fail during inference, even when the input text is "
                "short. To avoid this, you should increase `max_model_len`, "
                "reduce `max_num_seqs`, and/or reduce `mm_counts`.", seq_len,
                total_len, total_placeholders_by_modality)
        if (total_len > seq_len and not envs.VLLM_USE_V1) or is_encoder_data:
            if total_len > seq_len:
                logger.warning(
                    "The context length (%d) of the model is too short "
                    "to hold the multi-modal embeddings in the worst case "
                    "(%d tokens in total, out of which %s are reserved for "
                    "multi-modal embeddings). This may cause certain "
                    "multi-modal inputs to fail during inference, even when "
                    "the input text is short. To avoid this, you should "
                    "increase `max_model_len`, reduce `max_num_seqs`, "
                    "and/or reduce `mm_counts`.", seq_len, total_len,
                    total_placeholders_by_modality)

            return DummyData(
                seq_data=SequenceData.from_prompt_token_counts((0, seq_len)),