[Bugfix] Always apply MM processor even when no MM items are passed (#26240)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Authored by Cyrus Leung on 2025-10-05 18:10:20 +08:00 (committed by GitHub)
parent 432e1cbc23
commit b7e8e4e6be
6 changed files with 102 additions and 30 deletions
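
In short: for models that support multimodal inputs, the preprocessor now always routes prompts through the multimodal processor, even when no multimodal items are attached, while text-only models explicitly reject any `multi_modal_data`. A minimal sketch of the resulting behavior, adapted from the new tests added below (the model IDs and the sep-token check are taken from those tests; treat this as illustrative rather than exhaustive):

from vllm.config import ModelConfig
from vllm.inputs.preprocess import InputPreprocessor
from vllm.transformers_utils.tokenizer import init_tokenizer_from_configs

# Multimodal model: even a bare text prompt now goes through the HF processor,
# which for Chameleon inserts the sep token into the token IDs.
mm_config = ModelConfig(model="facebook/chameleon-7b")
mm_tokenizer = init_tokenizer_from_configs(mm_config)
mm_preprocessor = InputPreprocessor(mm_config, mm_tokenizer)
processed = mm_preprocessor.preprocess("")
assert mm_tokenizer.vocab[mm_tokenizer.sep_token] in processed["prompt_token_ids"]

# Text-only model: passing multi_modal_data (even a dummy entry) is rejected.
text_config = ModelConfig(model="facebook/opt-125m")
text_tokenizer = init_tokenizer_from_configs(text_config)
text_preprocessor = InputPreprocessor(text_config, text_tokenizer)
text_preprocessor.preprocess({"prompt": "", "multi_modal_data": {"dummy": []}})
# -> ValueError: This model does not support multimodal inputs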

View File

@@ -46,7 +46,6 @@ from vllm.connections import global_http_connection
 from vllm.distributed import (cleanup_dist_env_and_memory,
                               init_distributed_environment,
                               initialize_model_parallel)
-from vllm.inputs import TextPrompt
 from vllm.logger import init_logger
 from vllm.logprobs import Logprob
 from vllm.multimodal.utils import fetch_image
@@ -760,17 +759,24 @@ class VllmRunner:
         images: Optional[PromptImageInput] = None,
         videos: Optional[PromptVideoInput] = None,
         audios: Optional[PromptAudioInput] = None,
-    ) -> list[TextPrompt]:
+    ) -> list[dict[str, Any]]:
         if any(x is not None and len(x) != len(prompts)
                for x in [images, videos, audios]):
             raise ValueError(
                 "All non-None multimodal inputs must have the same length as "
                 "prompts")
-        inputs = []
+        inputs = list[dict[str, Any]]()
         for i, prompt in enumerate(prompts):
-            multi_modal_data = {}
+            prompt_dict = dict[str, Any]()
+            if isinstance(prompt, str):
+                prompt_dict["prompt"] = prompt
+            elif isinstance(prompt, list):
+                prompt_dict["prompt_token_ids"] = prompt
+            else:
+                prompt_dict["prompt_embeds"] = prompt
+            multi_modal_data = dict[str, Any]()
             if images is not None and (image := images[i]) is not None:
                 multi_modal_data["image"] = image
             if videos is not None and (video := videos[i]) is not None:
@@ -778,17 +784,10 @@ class VllmRunner:
             if audios is not None and (audio := audios[i]) is not None:
                 multi_modal_data["audio"] = audio
-            text_prompt_kwargs: dict[str, Any] = {
-                "multi_modal_data": multi_modal_data or None
-            }
-            if isinstance(prompt, str):
-                text_prompt_kwargs["prompt"] = prompt
-            elif isinstance(prompt, list):
-                text_prompt_kwargs["prompt_token_ids"] = prompt
-            else:
-                text_prompt_kwargs["prompt_embeds"] = prompt
+            if multi_modal_data:
+                prompt_dict["multi_modal_data"] = multi_modal_data
-            inputs.append(TextPrompt(**text_prompt_kwargs))
+            inputs.append(prompt_dict)
         return inputs
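
For reference, here is a simplified, self-contained restatement of the reworked helper above; the name `build_prompt_dicts` and the images-only signature are illustrative, not part of the test suite. The key change is that each prompt becomes a plain dict and `multi_modal_data` is only attached when there is actual multimodal data, instead of always constructing a `TextPrompt` with `multi_modal_data=None`.

from typing import Any, Optional, Union

def build_prompt_dicts(
    prompts: list[Union[str, list[int]]],
    images: Optional[list[Any]] = None,
) -> list[dict[str, Any]]:
    inputs: list[dict[str, Any]] = []
    for i, prompt in enumerate(prompts):
        # Text prompts and pre-tokenized prompts use different keys.
        prompt_dict: dict[str, Any] = (
            {"prompt": prompt} if isinstance(prompt, str)
            else {"prompt_token_ids": prompt})

        multi_modal_data: dict[str, Any] = {}
        if images is not None and (image := images[i]) is not None:
            multi_modal_data["image"] = image

        # Only attach the key when there is something to attach.
        if multi_modal_data:
            prompt_dict["multi_modal_data"] = multi_modal_data

        inputs.append(prompt_dict)
    return inputs

# build_prompt_dicts(["Hello"])            -> [{"prompt": "Hello"}]
# build_prompt_dicts(["Hi"], images=[img]) -> [{"prompt": "Hi", "multi_modal_data": {"image": img}}]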

View File

@@ -3,8 +3,11 @@
 import pytest
+from vllm.config import ModelConfig
 from vllm.inputs import zip_enc_dec_prompts
 from vllm.inputs.parse import parse_raw_prompts
+from vllm.inputs.preprocess import InputPreprocessor
+from vllm.transformers_utils.tokenizer import init_tokenizer_from_configs
 pytestmark = pytest.mark.cpu_test
@@ -80,3 +83,50 @@ def test_zip_enc_dec_prompts(mm_processor_kwargs, expected_mm_kwargs):
         assert zipped['encoder_prompt'] == enc
         assert zipped['decoder_prompt'] == dec
         assert zipped['mm_processor_kwargs'] == exp_kwargs
+@pytest.mark.parametrize("model_id", [
+    "facebook/opt-125m",
+])
+@pytest.mark.parametrize("prompt", [
+    {
+        "prompt": "",
+        "multi_modal_data": {
+            "dummy": []
+        },
+    },
+    {
+        "prompt_token_ids": [],
+        "multi_modal_data": {
+            "dummy": []
+        },
+    },
+])
+def test_preprocessor_text_no_mm_inputs(model_id, prompt):
+    model_config = ModelConfig(model=model_id)
+    tokenizer = init_tokenizer_from_configs(model_config)
+    input_preprocessor = InputPreprocessor(model_config, tokenizer)
+    with pytest.raises(ValueError, match="does not support multimodal inputs"):
+        input_preprocessor.preprocess(prompt)
+@pytest.mark.parametrize("model_id", [
+    "facebook/chameleon-7b",
+])
+@pytest.mark.parametrize("prompt", [
+    "",
+    {
+        "prompt_token_ids": []
+    },
+])
+def test_preprocessor_always_mm_code_path(model_id, prompt):
+    model_config = ModelConfig(model=model_id)
+    tokenizer = init_tokenizer_from_configs(model_config)
+    input_preprocessor = InputPreprocessor(model_config, tokenizer)
+    # HF processor adds sep token
+    sep_token_id = tokenizer.vocab[tokenizer.sep_token]
+    processed_inputs = input_preprocessor.preprocess(prompt)
+    assert sep_token_id in processed_inputs["prompt_token_ids"]

View File

@@ -314,15 +314,19 @@ class InputPreprocessor:
             parsed_content["prompt_token_ids"], tokenization_kwargs)
         inputs: Union[TokenInputs, MultiModalInputs]
-        if multi_modal_data := parsed_content.get("multi_modal_data"):
+        if self.model_config.is_multimodal_model:
             inputs = self._process_multimodal(
                 prompt_token_ids,
-                multi_modal_data,
+                parsed_content.get("multi_modal_data", {}),
                 parsed_content.get("mm_processor_kwargs"),
                 tokenization_kwargs=tokenization_kwargs,
                 mm_uuids=mm_uuids,
             )
         else:
+            if parsed_content.get("multi_modal_data"):
+                raise ValueError(
+                    "This model does not support multimodal inputs")
             inputs = token_inputs(prompt_token_ids)
         if cache_salt := parsed_content.get("cache_salt"):
@@ -340,15 +344,19 @@ class InputPreprocessor:
         prompt_text = parsed_content["prompt"]
         inputs: Union[TokenInputs, MultiModalInputs]
-        if multi_modal_data := parsed_content.get("multi_modal_data"):
+        if self.model_config.is_multimodal_model:
             inputs = self._process_multimodal(
                 prompt_text,
-                multi_modal_data,
+                parsed_content.get("multi_modal_data", {}),
                 parsed_content.get("mm_processor_kwargs"),
                 tokenization_kwargs=tokenization_kwargs,
                 mm_uuids=mm_uuids,
             )
         else:
+            if parsed_content.get("multi_modal_data"):
+                raise ValueError(
+                    "This model does not support multimodal inputs")
             prompt_token_ids = self._tokenize_prompt(
                 prompt_text,
                 tokenization_kwargs=tokenization_kwargs,

View File

@@ -507,8 +507,8 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]):
         )
         # Keep the behavior in line with HF processor
-        if token_ids[:2] == tokenizer.encode("<s> <|image|>",
-                                             add_special_tokens=False):
+        if len(mm_prompt_updates) and (token_ids[:2] == tokenizer.encode(
+                "<s> <|image|>", add_special_tokens=False)):
             token_ids = [token_ids[0], *token_ids[2:]]
         placeholders = {
             modality: [

View File

@@ -331,6 +331,7 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
         """
         mm_item_counts = mm_items.get_all_counts()
         self._validate_mm_kwargs(mm_kwargs, mm_item_counts)
+        self._validate_mm_updates(mm_prompt_updates, mm_item_counts)
         use_audio_in_video = False
         if "video" in mm_kwargs:

View File

@@ -1946,6 +1946,24 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
                     "model (usually arising from an inconsistency between "
                     "`_call_hf_processor` and `_get_mm_fields_config`).")
+    def _validate_mm_updates(
+        self,
+        mm_updates: MultiModalPromptUpdates,
+        mm_item_counts: Mapping[str, int],
+    ) -> None:
+        for modality, item_count in mm_item_counts.items():
+            placeholders = mm_updates.get(modality, [])
+            if len(placeholders) != item_count:
+                raise RuntimeError(
+                    f"Expected there to be {item_count} prompt updates "
+                    f"corresponding to {item_count} {modality} items, but "
+                    f"instead found {len(placeholders)} prompt updates! "
+                    "This is likely because you forgot to include input "
+                    "placeholder tokens (e.g., `<image>`, `<|image_pad|>`) "
+                    "in the prompt. If the model has a chat template, make "
+                    "sure you have applied it before calling `LLM.generate`.")
     def _validate_mm_placeholders(
         self,
         mm_placeholders: Mapping[str, list[PlaceholderFeaturesInfo]],
@@ -1955,17 +1973,12 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
             placeholders = mm_placeholders.get(modality, [])
             if len(placeholders) != item_count:
+                # NOTE: If you are a model developer, this can also arise from
+                # an inconsistency between `_call_hf_processor` and
+                # `_get_mm_fields_config` implementations
                 raise RuntimeError(
-                    f"Expected there to be {item_count} prompt updates "
+                    f"Expected there to be {item_count} prompt placeholders "
                     f"corresponding to {item_count} {modality} items, but "
-                    f"instead found {len(placeholders)} prompt updates! "
-                    "This is likely because you forgot to include input "
-                    "placeholder tokens (e.g., `<image>`, `<|image_pad|>`) "
-                    "in the prompt. If the model has a chat template, make "
-                    "sure you have applied it before calling `LLM.generate`.")
+                    f"instead found {len(placeholders)} prompt placeholders! "
+                    "Make sure the implementation of `_call_hf_processor` and "
+                    "`_get_mm_fields_config` are consistent with each other.")
     def _maybe_apply_prompt_updates(
         self,
@@ -1977,6 +1990,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
     ) -> tuple[list[int], Mapping[str, list[PlaceholderFeaturesInfo]]]:
         mm_item_counts = mm_items.get_all_counts()
         self._validate_mm_kwargs(mm_kwargs, mm_item_counts)
+        self._validate_mm_updates(mm_prompt_updates, mm_item_counts)
         if is_update_applied:
             mm_placeholders = self._find_mm_placeholders(
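
Downstream, the effect of the new `_validate_mm_updates` check is to fail fast with an explicit message when multimodal items arrive without matching prompt updates, which the message attributes to missing placeholder tokens. A hypothetical request of that shape (model ID reused from the tests above; the exact exception site and wording may differ in practice):

from PIL import Image

from vllm import LLM

llm = LLM(model="facebook/chameleon-7b")

# The prompt omits the "<image>" placeholder that the chat template would
# normally insert, yet an image item is attached.
llm.generate({
    "prompt": "Describe the picture.",
    "multi_modal_data": {"image": Image.new("RGB", (512, 512))},
})
# Expected to raise a RuntimeError pointing at the missing placeholder token
# rather than failing later with a less descriptive placeholder-count error.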