From fc0f87768aa5fe253858d0a7ed5f0dbce3f64ba3 Mon Sep 17 00:00:00 2001
From: Isotr0py <2037008807@qq.com>
Date: Mon, 7 Apr 2025 12:07:15 +0800
Subject: [PATCH] [Bugfix] Make dummy encoder prompt padding alternative and
 add missing warnings (#16129)

Signed-off-by: Isotr0py <2037008807@qq.com>
---
 .../multimodal/processing/test_mllama.py  | 71 +++++++++++++++++++
 tests/models/utils.py                     |  3 +
 vllm/model_executor/models/whisper.py     |  4 ++
 vllm/multimodal/processing.py             |  4 ++
 vllm/multimodal/profiling.py              | 30 ++++++--
 5 files changed, 108 insertions(+), 4 deletions(-)
 create mode 100644 tests/models/multimodal/processing/test_mllama.py

diff --git a/tests/models/multimodal/processing/test_mllama.py b/tests/models/multimodal/processing/test_mllama.py
new file mode 100644
index 0000000000000..b89376cf17229
--- /dev/null
+++ b/tests/models/multimodal/processing/test_mllama.py
@@ -0,0 +1,71 @@
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for mllama's multimodal preprocessing and profiling."""
+import pytest
+from transformers import MllamaConfig
+
+from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.profiling import MultiModalProfiler
+
+from ...utils import build_model_context
+
+
+@pytest.mark.parametrize("model_id",
+                         ["meta-llama/Llama-3.2-11B-Vision-Instruct"])
+@pytest.mark.parametrize("max_model_len", [4096, 8192, 25600, 131072])
+@pytest.mark.parametrize("max_num_seqs", [1, 2, 8])
+def test_profiling(
+    model_id: str,
+    max_model_len: int,
+    max_num_seqs: int,
+):
+    # regression test for https://github.com/vllm-project/vllm/issues/13929
+    from vllm.model_executor.models.mllama import calc_token_per_chunk
+
+    model_config_kwargs = {
+        "max_model_len": max_model_len,
+    }
+    ctx = build_model_context(
+        model_id,
+        model_config_kwargs=model_config_kwargs,
+        limit_mm_per_prompt={"image": 1},
+    )
+
+    mm_config = ctx.get_mm_config()
+    processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config)
+    profiler = MultiModalProfiler(processor)
+
+    dummy_encoder_data = profiler.get_encoder_dummy_data(
+        max_model_len,
+        mm_counts=mm_config.limit_per_prompt,
+    )
+    dummy_mm_data = processor.dummy_inputs.get_dummy_processor_inputs(
+        max_model_len,
+        mm_counts=mm_config.limit_per_prompt,
+    )
+
+    hf_config = ctx.get_hf_config(MllamaConfig)
+    image_size = hf_config.vision_config.image_size
+    encoder_seq_lens = [len(dummy_encoder_data.prompt_token_ids)
+                        ] * max_num_seqs
+
+    mm_kwargs = processor.apply(
+        prompt=dummy_mm_data.prompt_text,
+        mm_data=dummy_mm_data.mm_data,
+        hf_processor_mm_kwargs=dict(),
+    )["mm_kwargs"]
+
+    # Get the actual number of encoder tokens for each sample.
+    # Because attn_metadata.encoder_seq_lens only counts the last
+    # group of images for each sample, which is used to cheat the
+    # block manager to allocate blocks for those images only.
+    # See MllamaMultiModalProcessor for more details.
+    num_tiles = [[t] for t in mm_kwargs.pop("num_tiles")]
+    num_tokens_per_tile = calc_token_per_chunk(image_size)
+    actual_encoder_seq_lens = [
+        sum(num_tile) * num_tokens_per_tile for num_tile in num_tiles
+    ]
+
+    # simulate mllama image-present prefill.
+    for actual_len, last_group_len in zip(actual_encoder_seq_lens,
+                                          encoder_seq_lens):
+        assert actual_len >= last_group_len
diff --git a/tests/models/utils.py b/tests/models/utils.py
index 7109169e89966..5407540114b4c 100644
--- a/tests/models/utils.py
+++ b/tests/models/utils.py
@@ -255,6 +255,7 @@ def build_model_context(
     model_id: str,
     task: TaskOption = "auto",
     dtype: Union[str, torch.dtype] = "auto",
+    model_config_kwargs: Optional[dict[str, Any]] = None,
     mm_processor_kwargs: Optional[dict[str, Any]] = None,
     limit_mm_per_prompt: Optional[dict[str, int]] = None,
     disable_mm_preprocessor_cache: bool = True,
@@ -274,6 +275,7 @@
     model_info.check_available_online(on_fail="skip")
     model_info.check_transformers_version(on_fail="skip")

+    model_config_kwargs = model_config_kwargs or {}
     model_config = ModelConfig(
         model_id,
         task=task,
@@ -286,5 +288,6 @@
         limit_mm_per_prompt=limit_mm_per_prompt,
         disable_mm_preprocessor_cache=disable_mm_preprocessor_cache,
         hf_overrides=model_info.hf_overrides,
+        **model_config_kwargs,
     )
     return InputContext(model_config)
diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py
index eb6404922c6d0..e83abbe8b2527 100644
--- a/vllm/model_executor/models/whisper.py
+++ b/vllm/model_executor/models/whisper.py
@@ -580,6 +580,10 @@ class WhisperMultiModalProcessor(
         feature_extractor = self.info.get_feature_extractor()
         return MultiModalDataParser(target_sr=feature_extractor.sampling_rate)

+    @property
+    def pad_dummy_encoder_prompt(self) -> bool:
+        return True
+
     def create_encoder_prompt(
         self,
         prompt: Union[str, list[int]],
diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py
index c8864c33fe372..00c0f87b0b237 100644
--- a/vllm/multimodal/processing.py
+++ b/vllm/multimodal/processing.py
@@ -1654,6 +1654,10 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
         """
         raise NotImplementedError

+    @property
+    def pad_dummy_encoder_prompt(self) -> bool:
+        return False
+
     def create_decoder_prompt(
         self,
         prompt: Union[str, list[int]],
diff --git a/vllm/multimodal/profiling.py b/vllm/multimodal/profiling.py
index 1df9a1f5eba1c..ea58ba699f373 100644
--- a/vllm/multimodal/profiling.py
+++ b/vllm/multimodal/profiling.py
@@ -15,7 +15,8 @@ from vllm.logger import init_logger
 from .inputs import (MultiModalDataDict, MultiModalEncDecInputs,
                      MultiModalInputs, MultiModalKwargs,
                      MultiModalPlaceholderDict)
-from .processing import BaseMultiModalProcessor, BaseProcessingInfo
+from .processing import (BaseMultiModalProcessor, BaseProcessingInfo,
+                         EncDecMultiModalProcessor)

 logger = init_logger(__name__)

@@ -200,7 +201,10 @@ class MultiModalProfiler(Generic[_I]):
         seq_len: int,
         mm_counts: Optional[Mapping[str, int]] = None,
     ) -> DummyEncoderData:
-        mm_inputs, _ = self.get_and_validate_mm_inputs(seq_len, mm_counts)
+        (
+            mm_inputs,
+            total_placeholders_by_modality,
+        ) = self.get_and_validate_mm_inputs(seq_len, mm_counts)
         mm_inputs = cast(MultiModalEncDecInputs, mm_inputs)

         # For encoder-decoder models, use encoder prompt token ids instead of
@@ -208,8 +212,26 @@
         encoder_prompt_token_ids = mm_inputs["encoder_prompt_token_ids"]

         total_len = len(encoder_prompt_token_ids)
-        num_tokens_to_pad = max(total_len, seq_len) - total_len
-        encoder_prompt_token_ids.extend([0] * num_tokens_to_pad)
+
+        # Encoder-decoder multimodal models only support v0
+        if total_len > seq_len:
+            # `max_num_batched_tokens` is defined by `SchedulerConfig`
+            logger.warning(
+                "The encoder sequence length used for profiling ("
+                "max_num_batched_tokens / max_num_seqs = %d) is too short "
+                "to hold the multi-modal embeddings in the worst case "
+                "(%d tokens in total, out of which %s are reserved for "
+                "multi-modal embeddings). This may cause certain "
+                "multi-modal inputs to fail during inference, even when "
+                "the input text is short. To avoid this, you should "
+                "increase `max_model_len`, reduce `max_num_seqs`, "
+                "and/or reduce `mm_counts`.", seq_len, total_len,
+                total_placeholders_by_modality)
+
+        processor = cast(EncDecMultiModalProcessor, self.processor)
+        if processor.pad_dummy_encoder_prompt:
+            num_tokens_to_pad = max(total_len, seq_len) - total_len
+            encoder_prompt_token_ids.extend([0] * num_tokens_to_pad)

         return DummyEncoderData(encoder_prompt_token_ids)