[VLM] Implement merged multimodal processor for Mllama (#11427)
This commit is contained in:
parent d88c8666a1
commit bc55d13070
@@ -7,11 +7,11 @@ import torch
 from transformers import (AutoConfig, AutoModelForVision2Seq, AutoTokenizer,
                           BatchEncoding)
 
+from vllm import LLM, SamplingParams
 from vllm.attention.backends.flash_attn import FlashAttentionMetadata
 from vllm.attention.selector import (_Backend, _cached_get_attn_backend,
                                      global_force_attn_backend_context_manager)
-from vllm.model_executor.models.mllama import (MLLAMA_IMAGE_TOKEN_ID,
-                                               MllamaForConditionalGeneration)
+from vllm.model_executor.models.mllama import MllamaForConditionalGeneration
 from vllm.multimodal.image import rescale_image_size
 from vllm.sequence import SampleLogprobs
 
@@ -21,6 +21,7 @@ from ....utils import large_gpu_test
 from ...utils import check_logprobs_close
 
 _LIMIT_IMAGE_PER_PROMPT = 3
+MLLAMA_IMAGE_TOKEN_ID = 128256
 
 LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS, _Backend.FLASH_ATTN]
 
@@ -396,6 +397,64 @@ def test_models_interleaved_images(hf_runner, vllm_runner, image_assets, model,
     )
 
 
+@large_gpu_test(min_gb=48)
+@pytest.mark.core_model
+@pytest.mark.parametrize("model", models)
+@pytest.mark.parametrize("dtype", ["bfloat16"])
+@pytest.mark.parametrize("max_tokens", [32])
+def test_explicit_implicit_prompt(
+    image_assets: _ImageAssets,
+    model: str,
+    dtype: str,
+    max_tokens: int,
+):
+    stop_sign = image_assets[0].pil_image
+    # yapf: disable
+    prompts = [
+        # explicit prompt
+        {
+            "encoder_prompt": {
+                "prompt": "<|image|>",
+                "multi_modal_data": {"image": stop_sign},
+            },
+            "decoder_prompt": {
+                "prompt_token_ids": [128000, 791, 2262, 315, 279, 2217, 220, 128256, 374],  # noqa: E501
+            }
+        },
+        {
+            "encoder_prompt": "Not <|image|>",
+            "decoder_prompt": "The color of the sky is blue but sometimes it can also be",  # noqa: E501
+        },
+        # implicit prompt
+        {
+            "prompt": "<|begin_of_text|>The content of the image <|image|> is",  # noqa: E501
+            "multi_modal_data": {"image": stop_sign},
+        },
+        {
+            "prompt": "The color of the sky is blue but sometimes it can also be",  # noqa: E501
+        },
+    ]
+    # yapf: enable
+    llm = LLM(
+        model=model,
+        dtype=dtype,
+        max_model_len=4096,
+        max_num_seqs=2,
+        tensor_parallel_size=1,
+        enforce_eager=True,
+    )
+    sampling_params = SamplingParams(
+        temperature=0,
+        max_tokens=max_tokens,
+    )
+    outputs = llm.generate(prompts, sampling_params)
+    n_prompts = len(prompts)
+    explicit_outputs = outputs[:n_prompts // 2]
+    implicit_outputs = outputs[n_prompts // 2:]
+    for exp_output, imp_output in zip(explicit_outputs, implicit_outputs):
+        assert exp_output.outputs[0].text == imp_output.outputs[0].text
+
+
 @large_gpu_test(min_gb=48)
 @pytest.mark.core_model
 @pytest.mark.parametrize("model", models)
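A quick cross-check of the token ids used above (not part of the diff): the explicit decoder prompt_token_ids are intended to be exactly the tokenization of the implicit prompt text, with 128000 being <|begin_of_text|> and 128256 the <|image|> placeholder, which is why the test expects the explicit and implicit forms to generate identical continuations. A minimal sketch, assuming access to the gated meta-llama/Llama-3.2-11B-Vision-Instruct tokenizer:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-11B-Vision-Instruct")
ids = tok.encode("<|begin_of_text|>The content of the image <|image|> is",
                 add_special_tokens=False)
print(ids)  # expected to equal [128000, 791, 2262, 315, 279, 2217, 220, 128256, 374]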
@@ -458,6 +517,10 @@ def test_regression(vllm_runner, image_assets, model, dtype, max_tokens,
                               images=images)
 
 
+class DummyModel:
+    image_token_id = MLLAMA_IMAGE_TOKEN_ID
+
+
 @pytest.mark.core_model
 @pytest.mark.parametrize(
     "input_indices_and_output",
@@ -499,7 +562,7 @@ def test_get_cross_attention_mask(input_indices_and_output) -> None:
         use_cuda_graph=False,
     )
 
-    dummy: dict[str, str] = {}
+    dummy = DummyModel()
 
     cross_attention_mask, kv_range_for_decode = MllamaForConditionalGeneration\
         .get_cross_attention_mask(dummy,
@@ -556,7 +619,7 @@ def test_get_full_text_row_masked_out_mask(input_indices) -> None:
         use_cuda_graph=False,
    )
 
-    dummy: dict[str, str] = {}
+    dummy = DummyModel()
 
     full_text_row_masked_out_mask = MllamaForConditionalGeneration\
         .get_full_text_row_masked_out_mask(dummy,
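The switch from a bare dict to DummyModel() in the two mask tests above tracks a change to the model code later in this commit: get_cross_attention_mask() now reads self.image_token_id (set from config.image_token_index) rather than the removed module constant MLLAMA_IMAGE_TOKEN_ID, so the stand-in object passed as `self` must expose that attribute. A minimal sketch of the pattern, with the id hard-coded purely for illustration:

class DummyModel:
    image_token_id = 128256  # the Mllama <|image|> token id


dummy = DummyModel()
# The tests call the unbound methods with `dummy` standing in for `self`,
# so only the attributes those methods actually touch need to exist.
assert dummy.image_token_id == 128256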
@@ -85,6 +85,14 @@ def _test_processing_correctness(
         partial(random_audio, rng, min_len=512, max_len=1024, sr=16000),
     }
 
+    tokenizer_encode_kwargs = {}
+    if model_config.hf_config.model_type == "mllama":
+        # For Mllama, the tokenizer always adds bos_token at the beginning of
+        # the prompt by default, which makes the hf_processor output incorrect
+        # token ids. So we need to use `add_special_tokens=False` here and let
+        # the processor add the bos_token instead.
+        tokenizer_encode_kwargs = {"add_special_tokens": False}
+
     for batch_idx in range(num_batches):
         mm_data = {
             k:
@@ -122,7 +130,7 @@ def _test_processing_correctness(
             f"Failed ({batch_idx=}, {prompt=}, {mm_data=})")
 
         baseline_tokenized_result = baseline_processor.apply(
-            tokenizer.encode(prompt),
+            tokenizer.encode(prompt, **tokenizer_encode_kwargs),
             mm_data=mm_data,
             hf_processor_mm_kwargs={},
         )
@@ -131,7 +139,7 @@ def _test_processing_correctness(
             f"Failed ({batch_idx=}, {prompt=}, {mm_data=})")
 
         cached_tokenized_result = cached_processor.apply(
-            tokenizer.encode(prompt),
+            tokenizer.encode(prompt, **tokenizer_encode_kwargs),
             mm_data=mm_data,
             hf_processor_mm_kwargs={},
         )
@@ -155,6 +163,7 @@ def _test_processing_correctness(
     "llava-hf/llava-v1.6-mistral-7b-hf",
     "llava-hf/LLaVA-NeXT-Video-7B-hf",
     "llava-hf/llava-onevision-qwen2-0.5b-ov-hf",
+    "meta-llama/Llama-3.2-11B-Vision-Instruct",
     "TIGER-Lab/Mantis-8B-siglip-llama3",
    "mistral-community/pixtral-12b",
    "openbmb/MiniCPM-o-2_6",
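A minimal, hedged illustration of the bos_token duplication that `add_special_tokens=False` avoids (assumes access to the gated meta-llama/Llama-3.2-11B-Vision-Instruct tokenizer):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-11B-Vision-Instruct")
prompt = "<|image|>What is the content of this image?"
# Default encoding prepends <|begin_of_text|>; the HF processor then adds its
# own bos as well, so the baseline token ids would carry it twice.
print(tok.encode(prompt)[0])                            # 128000 (<|begin_of_text|>)
print(tok.encode(prompt, add_special_tokens=False)[0])  # 128256 (<|image|>)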
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import asyncio
-from typing import List, Mapping, Optional, Union
+from typing import List, Mapping, Optional, Tuple, Union, cast
 
 from typing_extensions import assert_never
 
@@ -9,7 +9,8 @@ from vllm.config import ModelConfig
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
-from vllm.multimodal.inputs import MultiModalDataDict, MultiModalInputs
+from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs,
+                                    MultiModalInputs)
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
 
@@ -495,6 +496,51 @@ class InputPreprocessor:
             decoder=decoder_inputs,
         )
 
+    def _separate_enc_dec_inputs_from_mm_processor_outputs(
+        self,
+        inputs: SingletonInputs,
+        decoder_inputs_to_override: Optional[SingletonInputs] = None,
+    ) -> Tuple[SingletonInputs, SingletonInputs]:
+        """
+        For encoder/decoder models only:
+        Separate Encoder/Decoder inputs from a MultiModalEncDecInputs
+        """
+        encoder_inputs: SingletonInputs
+        decoder_inputs: SingletonInputs
+        if inputs["type"] == "multimodal":
+            # Multimodal data inputs
+            assert ("encoder_prompt" in inputs
+                    and "encoder_prompt_token_ids" in inputs)
+            inputs = cast(MultiModalEncDecInputs, inputs)
+            encoder_inputs = token_inputs(
+                prompt=inputs["encoder_prompt"],
+                prompt_token_ids=inputs["encoder_prompt_token_ids"],
+            )
+            if decoder_inputs_to_override is not None:
+                decoder_inputs = MultiModalInputs(
+                    type="multimodal",
+                    prompt=decoder_inputs_to_override.get("prompt", ""),
+                    prompt_token_ids=decoder_inputs_to_override[
+                        "prompt_token_ids"],
+                    mm_kwargs=inputs["mm_kwargs"],
+                    mm_placeholders=inputs["mm_placeholders"],
+                )
+            else:
+                decoder_inputs = MultiModalInputs(
+                    type="multimodal",
+                    prompt=inputs["prompt"],
+                    prompt_token_ids=inputs["prompt_token_ids"],
+                    mm_kwargs=inputs["mm_kwargs"],
+                    mm_placeholders=inputs["mm_placeholders"],
+                )
+        elif inputs["type"] == "token":
+            # Text-only inputs
+            encoder_inputs = token_inputs(prompt="", prompt_token_ids=[])
+            decoder_inputs = decoder_inputs_to_override or inputs
+        else:
+            assert_never(inputs)  # type: ignore[arg-type]
+        return encoder_inputs, decoder_inputs
+
     def _process_encoder_decoder_prompt(
         self,
         prompt: PromptType,
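A condensed, self-contained sketch of the splitting rule implemented above, using plain dicts in place of vLLM's SingletonInputs/MultiModalInputs types (illustrative only, not the real API):

from typing import Optional, Tuple


def split_enc_dec(inputs: dict,
                  decoder_override: Optional[dict] = None) -> Tuple[dict, dict]:
    if inputs["type"] == "multimodal":
        # Encoder side keeps only the processed encoder prompt.
        encoder = {
            "type": "token",
            "prompt": inputs["encoder_prompt"],
            "prompt_token_ids": inputs["encoder_prompt_token_ids"],
        }
        # Decoder side keeps the (possibly overridden) prompt plus the
        # multimodal kwargs and placeholder ranges.
        base = decoder_override if decoder_override is not None else inputs
        decoder = {
            "type": "multimodal",
            "prompt": base.get("prompt", ""),
            "prompt_token_ids": base["prompt_token_ids"],
            "mm_kwargs": inputs["mm_kwargs"],
            "mm_placeholders": inputs["mm_placeholders"],
        }
    else:
        # Text-only input: empty encoder prompt, decoder passed through.
        encoder = {"type": "token", "prompt": "", "prompt_token_ids": []}
        decoder = decoder_override if decoder_override is not None else inputs
    return encoder, decoder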
@@ -539,7 +585,6 @@ class InputPreprocessor:
                 prompt["encoder_prompt"],
                 request_id=request_id,
             )
-
             if (decoder_input := prompt["decoder_prompt"]) is None:
                 decoder_inputs = None
             else:
@@ -547,11 +592,26 @@ class InputPreprocessor:
                     decoder_input,
                     request_id=request_id,
                 )
+                # For multimodal model, override decoder prompt from processor
+                # with explicit decoder prompt.
+                if self.model_config.is_multimodal_model and (
+                        self._can_process_multimodal()):
+                    encoder_inputs, decoder_inputs = (
+                        self._separate_enc_dec_inputs_from_mm_processor_outputs(
+                            encoder_inputs, decoder_inputs))
         else:
-            encoder_inputs = self._prompt_to_llm_inputs(
+            inputs = self._prompt_to_llm_inputs(
                 prompt,
                 request_id=request_id,
             )
+            if self.model_config.is_multimodal_model and (
+                    self._can_process_multimodal()):
+                # Encoder-Decoder Multimodal model
+                encoder_inputs, decoder_inputs = (
+                    self._separate_enc_dec_inputs_from_mm_processor_outputs(
+                        inputs))
+            else:
+                encoder_inputs = inputs
 
             decoder_inputs = None
 
@@ -583,11 +643,27 @@ class InputPreprocessor:
 
             encoder_inputs, decoder_inputs = await asyncio.gather(
                 encoder_task, decoder_task)
+
+            # For multimodal model, override decoder prompt from processor
+            # with explicit decoder prompt.
+            if self.model_config.is_multimodal_model and (
+                    self._can_process_multimodal()):
+                encoder_inputs, decoder_inputs = (
+                    self._separate_enc_dec_inputs_from_mm_processor_outputs(
+                        encoder_inputs, decoder_inputs))
         else:
-            encoder_inputs = await self._prompt_to_llm_inputs_async(
+            inputs = await self._prompt_to_llm_inputs_async(
                 prompt,
                 request_id=request_id,
             )
+            if self.model_config.is_multimodal_model and (
+                    self._can_process_multimodal()):
+                # Encoder-Decoder Multimodal model
+                encoder_inputs, decoder_inputs = (
+                    self._separate_enc_dec_inputs_from_mm_processor_outputs(
+                        inputs))
+            else:
+                encoder_inputs = inputs
 
             decoder_inputs = None
 
@@ -350,7 +350,8 @@ class InputRegistry:
             )
             processor = mm_registry.create_processor(model_config, tokenizer)
             profiler = MultiModalProfiler(processor)
-            dummy_data = profiler.get_dummy_data(seq_len)
+            dummy_data = profiler.get_dummy_data(
+                seq_len, is_encoder_data=is_encoder_data)
         else:
             model_cls, _ = get_model_architecture(model_config)
             if is_encoder_data:
@@ -23,14 +23,15 @@ import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
 import transformers.models.mllama.configuration_mllama as config_mllama
-from PIL import Image
+from PIL.Image import Image
 from torch import nn
+from transformers import BatchFeature, MllamaConfig
 from transformers.modeling_outputs import (BaseModelOutput,
                                            CausalLMOutputWithPast)
 from transformers.models.mllama.image_processing_mllama import (
     get_optimal_tiled_canvas)
 from transformers.models.mllama.processing_mllama import (
-    get_cross_attention_token_mask)
+    MllamaProcessor, get_cross_attention_token_mask)
 
 import vllm.distributed.parallel_state as ps
 from vllm.attention import Attention, AttentionMetadata, AttentionType
@@ -38,8 +39,6 @@ from vllm.attention.ops.paged_attn import PagedAttention
 from vllm.attention.selector import _Backend
 from vllm.config import VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.inputs import (INPUT_REGISTRY, DummyData, EncoderDecoderInputs,
-                         InputContext, TokenInputs, token_inputs)
 from vllm.logger import init_logger
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -54,8 +53,13 @@ from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader, maybe_remap_kv_scale_name)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.sequence import SequenceData
-from vllm.utils import is_list_of
+from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
+from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
+                                   MultiModalDataDict, MultiModalDataItems)
+from vllm.multimodal.processing import (BaseProcessingInfo,
+                                        EncDecMultiModalProcessor,
+                                        PromptReplacement)
+from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
 
 from .clip import CLIPMLP
 from .interfaces import SupportsMultiModal
@@ -63,8 +67,6 @@ from .llama import LlamaDecoderLayer, LlamaMLP
 from .utils import maybe_prefix
 
 logger = init_logger(__name__)
-MLLAMA_IMAGE_TOKEN_ID = 128256
-MLLAMA_IMAGE_TOKEN = "<|image|>"
 
 
 class MllamaImagePixelInputs(TypedDict):
@@ -81,25 +83,113 @@ class MllamaImagePixelInputs(TypedDict):
     # TODO: support LlamaImageEmbeddingInputs
 
 
-def _get_num_image_in_last_group(prompt_token_ids: List[int]) -> int:
-    num_images = 0
-    for token_id in prompt_token_ids[::-1]:
-        if token_id == MLLAMA_IMAGE_TOKEN_ID:
-            num_images += 1
-        elif num_images > 0:
-            break
-    return num_images
+def calc_token_per_chunk(image_size: int) -> int:
+    assert image_size % 14 == 0, "chunk size should be multiple of 14"
+    token_per_chunk = (image_size // 14)**2 + 1
+    return token_per_chunk
 
 
-def input_processor_for_mllama(
-    ctx: InputContext,
-    inputs: EncoderDecoderInputs,
-) -> EncoderDecoderInputs:
-    # Example input to processor:
+class MllamaProcessingInfo(BaseProcessingInfo):
+
+    def get_hf_config(self) -> MllamaConfig:
+        return self.ctx.get_hf_config(MllamaConfig)
+
+    def get_hf_processor(self) -> MllamaProcessor:
+        return self.ctx.get_hf_processor(MllamaProcessor)
+
+    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
+        return {"image": None}
+
+    def get_token_per_chunk_from_config(self) -> int:
+        image_size = self.get_hf_config().vision_config.image_size
+        return calc_token_per_chunk(image_size)
+
+    def get_mm_max_tokens_per_item(
+        self,
+        seq_len: int,
+        mm_counts: Mapping[str, int],
+    ) -> Mapping[str, int]:
+        vision_config = self.get_hf_config().vision_config
+        token_per_chunk = self.get_token_per_chunk_from_config()
+        mm_max_tokens = vision_config.max_num_tiles * token_per_chunk
+        return {"image": mm_max_tokens}
+
+    def get_num_tiles_per_image(self, image_height: int,
+                                image_width: int) -> int:
+        vision_config = self.get_hf_config().vision_config
+        max_num_tiles = vision_config.max_num_tiles
+        image_size = vision_config.image_size
+        tiled_height, tiled_width = get_optimal_tiled_canvas(
+            image_height,
+            image_width,
+            max_num_tiles,
+            tile_size=image_size,
+        )
+        num_tiles_height = tiled_height // image_size
+        num_tiles_width = tiled_width // image_size
+        return num_tiles_height * num_tiles_width
+
+    def get_image_size_with_most_features(self) -> ImageSize:
+        vision_config = self.get_hf_config().vision_config
+        image_size = vision_config.image_size
+        max_num_tiles = vision_config.max_num_tiles
+        # Result in the max possible feature size (h:w = 16:1)
+        return ImageSize(height=max_num_tiles * image_size, width=image_size)
+
+
+class MllamaDummyInputsBuilder(BaseDummyInputsBuilder[MllamaProcessingInfo]):
+
+    def get_dummy_processor_inputs(
+        self,
+        seq_len: int,
+        mm_counts: Mapping[str, int],
+    ) -> ProcessorInputs:
+        num_images = mm_counts.get("image", 0)
+
+        target_width, target_height = \
+            self.info.get_image_size_with_most_features()
+
+        mm_data = {
+            "image":
+            self._get_dummy_images(width=target_width,
+                                   height=target_height,
+                                   num_images=num_images)
+        }
+
+        hf_processor = self.info.get_hf_processor()
+        image_token: str = hf_processor.image_token
+
+        return ProcessorInputs(
+            prompt_text=image_token * num_images,
+            mm_data=mm_data,
+        )
+
+
+class MllamaMultiModalProcessor(EncDecMultiModalProcessor[MllamaProcessingInfo]
+                                ):
+
+    def _call_hf_processor(
+        self,
+        prompt: str,
+        mm_data: Mapping[str, object],
+        mm_kwargs: Mapping[str, object],
+    ) -> BatchFeature:
+        tokenizer = self.info.get_tokenizer()
+        if mm_data:
+            num_tiles = [
+                self.info.get_num_tiles_per_image(img.height, img.width)
+                for img in mm_data["images"]
+            ]
+            processed_outputs = super()._call_hf_processor(
+                prompt, mm_data, mm_kwargs)
+            processed_outputs["num_tiles"] = torch.tensor(num_tiles)
+            for k in ('pixel_values', 'aspect_ratio_ids', "aspect_ratio_mask"):
+                processed_outputs[k] = processed_outputs[k].squeeze(0)
+            # Example input to encoder and decoder:
             # {
             #     'encoder': {
             #         'type': 'token',
-            #         'prompt_token_ids': [128000, 128256, 128000, 3923, 374, 279, 2262, 315, 420, 2217, 30], # noqa: E501
+            #         'prompt_token_ids': [128256, 128000, 3923, 374, 279, 2262, 315, 420, 2217, 30], # noqa: E501
             #         'prompt': '<|image|><|begin_of_text|>What is the content of this image?', # noqa: E501
             #         'multi_modal_data': {'image': <PIL.Image.Image image mode=RGB size=1770x1180 at 0x7FDE2C624880>}, # noqa: E501
             #     },
@@ -108,131 +198,76 @@ def input_processor_for_mllama(
             #         'prompt_token_ids': [128000],
             #     },
             # }
+            processed_token_ids = processed_outputs.pop("input_ids")
+            start_idx, end_idx = 0, processed_token_ids.size(1)
+            processed_prompt_text = tokenizer.decode(processed_token_ids[0])
 
-    # move encoder prompt to decoder
-    dec_inputs = TokenInputs(**inputs["encoder"])
-
-    multi_modal_data = dec_inputs.get("multi_modal_data")
-    if multi_modal_data is None or "image" not in multi_modal_data:
-        # text-only
-        return EncoderDecoderInputs(
-            encoder=token_inputs([]),
-            decoder=dec_inputs,
-        )
-
-    image_data = multi_modal_data["image"]
-    if isinstance(image_data, Image.Image):
-        image_data = [image_data]
-
-    assert is_list_of(image_data, Image.Image)
-
-    num_image_tokens = dec_inputs['prompt_token_ids'].count(
-        MLLAMA_IMAGE_TOKEN_ID)
-    if num_image_tokens != len(image_data):
-        raise ValueError(
-            f"The number of image tokens ({num_image_tokens}) must be"
-            f" the same as the number of images ({len(image_data)})")
-
-    # Since only the last group of consecutive images
-    # are attended by the decoded tokens, we only need to
-    # get the number of tiles for those images.
-    num_decode_images = _get_num_image_in_last_group(
-        dec_inputs["prompt_token_ids"])
-
-    hf_config = ctx.model_config.hf_config
-    vision_config = hf_config.vision_config
-
-    num_tiles = 0
-    for image in image_data[::-1]:
-        width, height = image.size
-        tile_size = vision_config.image_size
-        canvas_height, canvas_width = get_optimal_tiled_canvas(
-            image_height=height,
-            image_width=width,
-            max_image_tiles=vision_config.max_num_tiles,
-            tile_size=tile_size,
-        )
-        num_tiles_height = canvas_height // tile_size
-        num_tiles_width = canvas_width // tile_size
-        num_tiles += num_tiles_height * num_tiles_width
-        num_decode_images -= 1
-        if num_decode_images == 0:
-            break
-
-    # Set encoder prompt length based on the number of tiles.
-    # This tells the block manager to allocate correct number
-    # of slots for encoder tokens.
-    assert vision_config.image_size % 14 == 0, \
-        "chunk size should be multiple of 14"
-    token_per_chunk = (vision_config.image_size // 14)**2 + 1
-    num_tokens = num_tiles * token_per_chunk
-
-    # Example output from processor:
-    # {
-    #     'encoder': {
-    #         'type': 'token',
-    #         'prompt_token_ids': [128256, 128256, ..., 128256],
-    #         'prompt': '<|image|><|image|>...<|image|>',
-    #         'multi_modal_data': {'image': <PIL.Image.Image image mode=RGB size=1770x1180 at 0x7FDE2C624880>}, # noqa: E501
-    #     },
-    #     'decoder': {
-    #         'type': 'token',
-    #         'prompt_token_ids': [128000, 128256, 128000, 3923, 374, 279, 2262, 315, 420, 2217, 30], # noqa: E501
-    #         'prompt': '<|image|><|begin_of_text|>What is the content of this image?', # noqa: E501
-    #         'multi_modal_data': {'image': <PIL.Image.Image image mode=RGB size=1770x1180 at 0x7FDE2C624880>}, # noqa: E501
-    #     },
-    # }
-    return EncoderDecoderInputs(
-        encoder=token_inputs(
-            prompt_token_ids=[MLLAMA_IMAGE_TOKEN_ID] * num_tokens,
-            prompt=MLLAMA_IMAGE_TOKEN * num_tokens,
-            multi_modal_data=multi_modal_data,
-        ),
-        decoder=dec_inputs,
-    )
-
-
-def get_max_mllama_image_tokens(ctx: InputContext) -> int:
-    hf_config = ctx.model_config.hf_config
-    token_per_chunk = (hf_config.vision_config.image_size // 14)**2 + 1
-    return hf_config.vision_config.max_num_tiles * token_per_chunk
-
-
-def dummy_decoder_seq_data(seq_len: int, num_images: int):
-    # <|image|> * num_images + 0 * (seq_len - num_images)
-    assert seq_len >= num_images, \
-        "seq_len should be greater than or equal to num_images"
-
-    return SequenceData.from_prompt_token_counts(
-        (MLLAMA_IMAGE_TOKEN_ID, num_images),
-        (0, seq_len - num_images),
-    )
-
-
-def dummy_encoder_seq_data(ctx: InputContext, num_images: int):
-    num_tokens = get_max_mllama_image_tokens(ctx) * num_images
-
-    return SequenceData.from_prompt_token_counts(
-        (MLLAMA_IMAGE_TOKEN_ID, num_tokens))
-
-
-def dummy_image(num_images: int, ):
-    width = height = 1024
-    image = Image.new("RGB", (width, height), color=0)
-    return {"image": image if num_images == 1 else [image] * num_images}
-
-
-def dummy_decoder_data_for_mllama(ctx: InputContext, seq_len: int,
-                                  mm_counts: Mapping[str, int]):
-    num_images = mm_counts["image"]
-    return DummyData(dummy_decoder_seq_data(seq_len, num_images))
-
-
-def dummy_encoder_data_for_mllama(ctx: InputContext, seq_len: int,
-                                  mm_counts: Mapping[str, int]):
-    num_images = mm_counts["image"]
-    return DummyData(dummy_encoder_seq_data(ctx, num_images),
-                     dummy_image(num_images))
+            hf_processor = self.info.get_hf_processor()
+            bos_token = hf_processor.bos_token
+            # Remove the bos_token from the start of the prompt, because we
+            # know it is followed by an image token there.
+            if processed_prompt_text.startswith(bos_token):
+                start_idx += 1
+            # Remove the bos_token from the end of the prompt, because the
+            # text is empty in this case.
+            if processed_prompt_text.endswith(bos_token):
+                end_idx -= 1
+            processed_outputs[
+                "input_ids"] = processed_token_ids[:, start_idx:end_idx]
+        else:
+            processed_outputs = tokenizer(prompt,
+                                          add_special_tokens=False,
+                                          return_tensors="pt")
+        return processed_outputs
+
+    def _get_mm_fields_config(
+        self,
+        hf_inputs: BatchFeature,
+        hf_processor_mm_kwargs: Mapping[str, object],
+    ) -> Mapping[str, MultiModalFieldConfig]:
+        return dict(
+            pixel_values=MultiModalFieldConfig.batched("image"),
+            aspect_ratio_ids=MultiModalFieldConfig.batched("image"),
+            aspect_ratio_mask=MultiModalFieldConfig.batched("image"),
+            num_tiles=MultiModalFieldConfig.batched("image"),
+        )
+
+    def create_encoder_prompt(
+        self,
+        prompt: Union[str, list[int]],
+        mm_data: MultiModalDataDict,
+    ) -> Union[str, list[int]]:
+        data = mm_data.get("image", [])
+        num_images = 1 if isinstance(data, Image) else len(data)
+        image_token_id = self.info.get_hf_config().image_token_index
+        return [image_token_id] * num_images
+
+    def _get_prompt_replacements(
+        self,
+        mm_items: MultiModalDataItems,
+        hf_processor_mm_kwargs: Mapping[str, object],
+        out_mm_kwargs: MultiModalKwargs,
+    ) -> list[PromptReplacement]:
+        token_per_chunk = self.info.get_token_per_chunk_from_config()
+        image_token_id = self.info.get_hf_config().image_token_index
+
+        def get_replacement_mllama(item_idx):
+            images = mm_items.get_items("image", ImageProcessorItems)
+            image_size = images.get_image_size(item_idx)
+            num_tile = self.info.get_num_tiles_per_image(
+                image_height=image_size.height,
+                image_width=image_size.width,
+            )
+            num_tokens = num_tile * token_per_chunk
+            return [image_token_id] * num_tokens
+
+        return [
+            PromptReplacement(
+                modality="image",
+                target=[image_token_id],
+                replacement=get_replacement_mllama,
+            )
+        ]
 
 
 def _prepare_aspect_ratio_attention_mask(
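A worked example of the token accounting implemented above. The numbers below assume the Llama-3.2-11B-Vision values image_size = 560 and max_num_tiles = 4; in the real code they are read from vision_config rather than hard-coded:

image_size = 560        # assumed vision_config.image_size
max_num_tiles = 4       # assumed vision_config.max_num_tiles

token_per_chunk = (image_size // 14) ** 2 + 1    # 40 * 40 patches + 1 = 1601
mm_max_tokens = max_num_tiles * token_per_chunk  # 4 * 1601 = 6404

# An image that tiles into 4 chunks therefore expands the single <|image|>
# placeholder in the encoder prompt into 4 * 1601 = 6404 image tokens.
print(token_per_chunk, mm_max_tokens)  # 1601 6404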
@@ -1107,11 +1142,9 @@ class MllamaForCausalLM(nn.Module):
         return hidden_states
 
 
-@MULTIMODAL_REGISTRY.register_image_input_mapper()
-@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_mllama_image_tokens)
-@INPUT_REGISTRY.register_dummy_data(dummy_decoder_data_for_mllama)
-@INPUT_REGISTRY.register_dummy_encoder_data(dummy_encoder_data_for_mllama)
-@INPUT_REGISTRY.register_input_processor(input_processor_for_mllama)
+@MULTIMODAL_REGISTRY.register_processor(MllamaMultiModalProcessor,
+                                        info=MllamaProcessingInfo,
+                                        dummy_inputs=MllamaDummyInputsBuilder)
 class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
     packed_modules_mapping = {
         "qkv_proj": ["q_proj", "k_proj", "v_proj"],
@@ -1120,7 +1153,7 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
-        config = vllm_config.model_config.hf_config
+        config: MllamaConfig = vllm_config.model_config.hf_config
         quant_config = vllm_config.quant_config
         self.quant_config = quant_config
         self.vocab_size = config.text_config.vocab_size
@@ -1130,6 +1163,7 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
         self.pad_token_id = \
             config.pad_token_id if config.pad_token_id is not None else -1
         self.image_size = config.vision_config.image_size
+        self.image_token_id = config.image_token_index
 
         self.vision_model = MllamaVisionModel(config.vision_config,
                                               quant_config,
@@ -1204,48 +1238,12 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
         if pixel_values is not None:
             assert aspect_ratio_ids is not None
             assert aspect_ratio_mask is not None
-            max_num_images = max([len(x[0]) for x in pixel_values])
-            if max_num_images == 0:
-                raise ValueError("No images provided.")
-            max_num_tiles = max(
-                max([len(x) for x in y[0]]) for y in pixel_values)
-            device = next(self.multi_modal_projector.parameters()).device
-            bsz = len(pixel_values)
-            out_num_tiles = []
-            out_images = torch.zeros(
-                bsz,
-                max_num_images,
-                max_num_tiles,
-                3,
-                self.image_size,
-                self.image_size,
-                dtype=torch.float32,
-                device=device,
-            )
-            out_ar_ids = torch.ones(bsz,
-                                    max_num_images,
-                                    dtype=torch.int64,
-                                    device=device)
-            out_ar_mask = torch.zeros(bsz,
-                                      max_num_images,
-                                      max_num_tiles,
-                                      dtype=torch.int64,
-                                      device=device)
-            for b in range(len(pixel_values)):
-                _num_tiles = []
-                for i in range(len(pixel_values[b][0])):
-                    img = pixel_values[b][0][i]
-                    out_images[b, i, :img.shape[0]] = img
-                    out_ar_ids[b, i] = aspect_ratio_ids[b][0][i]
-                    out_ar_mask[b, i] = aspect_ratio_mask[b][0][i]
-                    _num_tiles.append(img.shape[0])
-                out_num_tiles.append(_num_tiles)
-
             return MllamaImagePixelInputs(
                 type="pixel_values",
-                data=out_images,
-                aspect_ratio_ids=out_ar_ids,
-                aspect_ratio_mask=out_ar_mask,
+                data=pixel_values,
+                aspect_ratio_ids=aspect_ratio_ids,
+                aspect_ratio_mask=aspect_ratio_mask,
             )
 
         if image_embeds is not None:
@@ -1312,7 +1310,7 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
         batch_token_ids.append(token_ids[start:start + seq_len])
         start += seq_len
     sparse_mask = [
-        get_cross_attention_token_mask(t, MLLAMA_IMAGE_TOKEN_ID)
+        get_cross_attention_token_mask(t, self.image_token_id)
         for t in batch_token_ids
     ]
 
@@ -1384,8 +1382,8 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
         # block manager to allocate blocks for those images only.
         # See input_processor_for_mllama() for more details.
         num_tiles_tensor = kwargs.pop("num_tiles")
-        num_tiles = [t[0].tolist() for t in num_tiles_tensor]
-        num_tokens_per_tile = (self.image_size // 14)**2 + 1
+        num_tiles = [t.tolist() for t in num_tiles_tensor]
+        num_tokens_per_tile = calc_token_per_chunk(self.image_size)
         actual_encoder_seq_lens = [
             sum(num_tile) * num_tokens_per_tile for num_tile in num_tiles
         ]
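Continuing the worked numbers from above (image_size = 560 assumed): for a batch whose images were split into [4] tiles and [2, 1] tiles respectively, the encoder sequence lengths come out as:

num_tiles = [[4], [2, 1]]                   # tiles per image, per request
num_tokens_per_tile = (560 // 14) ** 2 + 1  # calc_token_per_chunk(560) == 1601
actual_encoder_seq_lens = [
    sum(num_tile) * num_tokens_per_tile for num_tile in num_tiles
]
print(actual_encoder_seq_lens)  # [6404, 4803]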
@@ -739,3 +739,19 @@ class MultiModalInputs(TypedDict):
     For each modality, information about the placeholder tokens in
     :code:`prompt_token_ids`.
     """
+
+
+class MultiModalEncDecInputs(MultiModalInputs):
+    """
+    Represents the outputs of :class:`vllm.multimodal.EncDecMultiModalProcessor`
+    ready to be passed to vLLM internals.
+    """
+
+    encoder_prompt: str
+    """The processed encoder prompt text."""
+
+    encoder_prompt_token_ids: list[int]
+    """The processed token IDs of the encoder prompt."""
+
+    encoder_token_type_ids: NotRequired[list[int]]
+    """The token type IDs of the encoder prompt."""
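A minimal sketch of a value shaped like this TypedDict, with illustrative numbers for a single-image Mllama request (nothing below is computed; the ids and lengths are examples only):

example_inputs = {
    "type": "multimodal",
    # Decoder side: the original user prompt, tokenized without bos.
    "prompt": "The content of the image <|image|> is",
    "prompt_token_ids": [791, 2262, 315, 279, 2217, 220, 128256, 374],
    # Encoder side: the processed encoder prompt (expanded image tokens).
    "encoder_prompt": "<|image|>" * 6404,
    "encoder_prompt_token_ids": [128256] * 6404,
    # mm_kwargs would carry pixel_values, aspect_ratio_ids, aspect_ratio_mask
    # and num_tiles in the real output; left empty here.
    "mm_kwargs": {},
    "mm_placeholders": {"image": [{"offset": 0, "length": 6404}]},
}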
@@ -20,9 +20,9 @@ from vllm.transformers_utils.tokenizer import (AnyTokenizer, decode_tokens,
 from vllm.utils import LRUCache, flatten_2d_lists, full_groupby
 
 from .hasher import MultiModalHasher
-from .inputs import (MultiModalDataDict, MultiModalFieldConfig,
-                     MultiModalInputs, MultiModalKwargs, MultiModalKwargsItem,
-                     PlaceholderRange)
+from .inputs import (MultiModalDataDict, MultiModalEncDecInputs,
+                     MultiModalFieldConfig, MultiModalInputs, MultiModalKwargs,
+                     MultiModalKwargsItem, PlaceholderRange)
 from .parse import MultiModalDataItems, MultiModalDataParser
 
 if TYPE_CHECKING:
@@ -1293,3 +1293,57 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
             mm_hashes=mm_hashes,
             mm_placeholders=mm_placeholder_ranges,
         )
+
+
+class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
+
+    @abstractmethod
+    def create_encoder_prompt(
+        self,
+        prompt: Union[str, list[int]],
+        mm_data: MultiModalDataDict,
+    ) -> Union[str, list[int]]:
+        """Create input prompt for the encoder."""
+        raise NotImplementedError
+
+    def apply(
+        self,
+        prompt: Union[str, list[int]],
+        mm_data: MultiModalDataDict,
+        hf_processor_mm_kwargs: Mapping[str, object],
+    ) -> MultiModalEncDecInputs:
+        """
+        Process multi-modal inputs to be used in vLLM.
+        The main processing steps are modified to fit encoder-decoder models:
+        1. Create encoder prompt from input prompt text.
+        2. Apply the HF processor on encoder prompt.
+        3. Copy the input prompt text as decoder prompt inputs.
+        """
+        encoder_prompt = self.create_encoder_prompt(prompt, mm_data)
+        encoder_inputs = super().apply(
+            encoder_prompt,
+            mm_data,
+            hf_processor_mm_kwargs,
+        )
+
+        # We assume the decoder prompt text is copied from the original
+        # encoder prompt without extra processing.
+        tokenizer = self.info.get_tokenizer()
+        if isinstance(prompt, str):
+            decoder_prompt = prompt
+            decoder_prompt_ids = encode_tokens(tokenizer,
+                                               prompt,
+                                               add_special_tokens=False)
+        else:
+            decoder_prompt = decode_tokens(tokenizer, prompt)
+            decoder_prompt_ids = prompt
+
+        mm_inputs = MultiModalEncDecInputs(
+            encoder_prompt=encoder_inputs["prompt"],
+            encoder_prompt_token_ids=encoder_inputs["prompt_token_ids"],
+            **encoder_inputs)
+        mm_inputs.update({
+            "prompt": decoder_prompt,
+            "prompt_token_ids": decoder_prompt_ids
+        })
+        return mm_inputs
@@ -144,7 +144,11 @@ class MultiModalProfiler(Generic[_I]):
             hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
         )
 
-    def get_dummy_data(self, seq_len: int) -> DummyData:
+    def get_dummy_data(
+        self,
+        seq_len: int,
+        is_encoder_data: bool = False,
+    ) -> DummyData:
         # Avoid circular import
         from vllm.sequence import SequenceData
 
@@ -183,16 +187,18 @@ class MultiModalProfiler(Generic[_I]):
         total_len = len(prompt_token_ids)
 
         # V0 does not support chunked prefill.
-        if total_len > seq_len and not envs.VLLM_USE_V1:
-            logger.warning(
-                "The context length (%d) of the model is too short "
-                "to hold the multi-modal embeddings in the worst case "
-                "(%d tokens in total, out of which %s are reserved for "
-                "multi-modal embeddings). This may cause certain multi-modal "
-                "inputs to fail during inference, even when the input text is "
-                "short. To avoid this, you should increase `max_model_len`, "
-                "reduce `max_num_seqs`, and/or reduce `mm_counts`.", seq_len,
-                total_len, total_placeholders_by_modality)
+        if (total_len > seq_len and not envs.VLLM_USE_V1) or is_encoder_data:
+            if total_len > seq_len:
+                logger.warning(
+                    "The context length (%d) of the model is too short "
+                    "to hold the multi-modal embeddings in the worst case "
+                    "(%d tokens in total, out of which %s are reserved for "
+                    "multi-modal embeddings). This may cause certain "
+                    "multi-modal inputs to fail during inference, even when "
+                    "the input text is short. To avoid this, you should "
+                    "increase `max_model_len`, reduce `max_num_seqs`, "
+                    "and/or reduce `mm_counts`.", seq_len, total_len,
+                    total_placeholders_by_modality)
 
             return DummyData(
                 seq_data=SequenceData.from_prompt_token_counts((0, seq_len)),