[Bugfix] Make dummy encoder prompt padding alternative and add missing warnings (#16129)

Signed-off-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
Isotr0py 2025-04-07 12:07:15 +08:00 committed by GitHub
parent 0a57386721
commit fc0f87768a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 108 additions and 4 deletions

View File

@ -0,0 +1,71 @@
# SPDX-License-Identifier: Apache-2.0
"""Tests for mllama's multimodal preprocessing and profiling."""
import pytest
from transformers import MllamaConfig
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.profiling import MultiModalProfiler
from ...utils import build_model_context
@pytest.mark.parametrize("model_id",
["meta-llama/Llama-3.2-11B-Vision-Instruct"])
@pytest.mark.parametrize("max_model_len", [4096, 8192, 25600, 131072])
@pytest.mark.parametrize("max_num_seqs", [1, 2, 8])
def test_profiling(
model_id: str,
max_model_len: int,
max_num_seqs: int,
):
# regression test for https://github.com/vllm-project/vllm/issues/13929
from vllm.model_executor.models.mllama import calc_token_per_chunk
model_config_kwargs = {
"max_model_len": max_model_len,
}
ctx = build_model_context(
model_id,
model_config_kwargs=model_config_kwargs,
limit_mm_per_prompt={"image": 1},
)
mm_config = ctx.get_mm_config()
processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config)
profiler = MultiModalProfiler(processor)
dummy_encoder_data = profiler.get_encoder_dummy_data(
max_model_len,
mm_counts=mm_config.limit_per_prompt,
)
dummy_mm_data = processor.dummy_inputs.get_dummy_processor_inputs(
max_model_len,
mm_counts=mm_config.limit_per_prompt,
)
hf_config = ctx.get_hf_config(MllamaConfig)
image_size = hf_config.vision_config.image_size
encoder_seq_lens = [len(dummy_encoder_data.prompt_token_ids)
] * max_num_seqs
mm_kwargs = processor.apply(
prompt=dummy_mm_data.prompt_text,
mm_data=dummy_mm_data.mm_data,
hf_processor_mm_kwargs=dict(),
)["mm_kwargs"]
# Get the actual number of encoder tokens for each sample.
# Because attn_metadata.encoder_seq_lens only counts the last
# group of images for each sample, which is used to cheat the
# block manager to allocate blocks for those images only.
# See MllamaMultiModalProcessor for more details.
num_tiles = [[t] for t in mm_kwargs.pop("num_tiles")]
num_tokens_per_tile = calc_token_per_chunk(image_size)
actual_encoder_seq_lens = [
sum(num_tile) * num_tokens_per_tile for num_tile in num_tiles
]
# simulate mllama image-present prefill.
for actual_len, last_group_len in zip(actual_encoder_seq_lens,
encoder_seq_lens):
assert actual_len >= last_group_len

View File

@ -255,6 +255,7 @@ def build_model_context(
model_id: str,
task: TaskOption = "auto",
dtype: Union[str, torch.dtype] = "auto",
model_config_kwargs: Optional[dict[str, Any]] = None,
mm_processor_kwargs: Optional[dict[str, Any]] = None,
limit_mm_per_prompt: Optional[dict[str, int]] = None,
disable_mm_preprocessor_cache: bool = True,
@ -274,6 +275,7 @@ def build_model_context(
model_info.check_available_online(on_fail="skip")
model_info.check_transformers_version(on_fail="skip")
model_config_kwargs = model_config_kwargs or {}
model_config = ModelConfig(
model_id,
task=task,
@ -286,5 +288,6 @@ def build_model_context(
limit_mm_per_prompt=limit_mm_per_prompt,
disable_mm_preprocessor_cache=disable_mm_preprocessor_cache,
hf_overrides=model_info.hf_overrides,
**model_config_kwargs,
)
return InputContext(model_config)

View File

@ -580,6 +580,10 @@ class WhisperMultiModalProcessor(
feature_extractor = self.info.get_feature_extractor()
return MultiModalDataParser(target_sr=feature_extractor.sampling_rate)
@property
def pad_dummy_encoder_prompt(self) -> bool:
return True
def create_encoder_prompt(
self,
prompt: Union[str, list[int]],

View File

@ -1654,6 +1654,10 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
"""
raise NotImplementedError
@property
def pad_dummy_encoder_prompt(self) -> bool:
return False
def create_decoder_prompt(
self,
prompt: Union[str, list[int]],

View File

@ -15,7 +15,8 @@ from vllm.logger import init_logger
from .inputs import (MultiModalDataDict, MultiModalEncDecInputs,
MultiModalInputs, MultiModalKwargs,
MultiModalPlaceholderDict)
from .processing import BaseMultiModalProcessor, BaseProcessingInfo
from .processing import (BaseMultiModalProcessor, BaseProcessingInfo,
EncDecMultiModalProcessor)
logger = init_logger(__name__)
@ -200,7 +201,10 @@ class MultiModalProfiler(Generic[_I]):
seq_len: int,
mm_counts: Optional[Mapping[str, int]] = None,
) -> DummyEncoderData:
mm_inputs, _ = self.get_and_validate_mm_inputs(seq_len, mm_counts)
(
mm_inputs,
total_placeholders_by_modality,
) = self.get_and_validate_mm_inputs(seq_len, mm_counts)
mm_inputs = cast(MultiModalEncDecInputs, mm_inputs)
# For encoder-decoder models, use encoder prompt token ids instead of
@ -208,8 +212,26 @@ class MultiModalProfiler(Generic[_I]):
encoder_prompt_token_ids = mm_inputs["encoder_prompt_token_ids"]
total_len = len(encoder_prompt_token_ids)
num_tokens_to_pad = max(total_len, seq_len) - total_len
encoder_prompt_token_ids.extend([0] * num_tokens_to_pad)
# Encoder-decoder multimodal models only support v0
if total_len > seq_len:
# `max_num_batched_tokens` is defined by `SchedulerConfig`
logger.warning(
"The encoder sequence length used for profiling ("
"max_num_batched_tokens / max_num_seqs = %d) is too short "
"to hold the multi-modal embeddings in the worst case "
"(%d tokens in total, out of which %s are reserved for "
"multi-modal embeddings). This may cause certain "
"multi-modal inputs to fail during inference, even when "
"the input text is short. To avoid this, you should "
"increase `max_model_len`, reduce `max_num_seqs`, "
"and/or reduce `mm_counts`.", seq_len, total_len,
total_placeholders_by_modality)
processor = cast(EncDecMultiModalProcessor, self.processor)
if processor.pad_dummy_encoder_prompt:
num_tokens_to_pad = max(total_len, seq_len) - total_len
encoder_prompt_token_ids.extend([0] * num_tokens_to_pad)
return DummyEncoderData(encoder_prompt_token_ids)