[Bugfix] Check that number of images matches number of <|image|> tokens with mllama (#13911)

Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>
This commit is contained in:
Travis Johnson 2025-02-27 21:00:45 -07:00 committed by GitHub
parent 6c85da3a18
commit 73e0225ee9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 26 additions and 3 deletions

View File

@@ -479,8 +479,9 @@ def test_regression(vllm_runner, image_assets, model, dtype, max_tokens,
 # Regression tests for https://github.com/vllm-project/vllm/issues/10648
-# Number of image tags is greater than the number of images provided
-prompt = "<|begin_of_text|><|image|><|image|> Compare the two images"  # noqa: E501
+# Number of groups of image tokens is greater than the number of images
+# provided (the whitespace between the tags is necessary)
+prompt = "<|begin_of_text|><|image|> <|image|> Compare the two images"  # noqa: E501
 image = stop_sign
 with pytest.raises(ValueError):
     vllm_model.generate_greedy_logprobs([prompt],

View File

@@ -54,7 +54,8 @@ from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader, maybe_remap_kv_scale_name)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
+from vllm.multimodal.inputs import (MultiModalEncDecInputs,
+                                    MultiModalFieldConfig, MultiModalKwargs)
 from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
                                    MultiModalDataDict, MultiModalDataItems)
 from vllm.multimodal.processing import (BaseProcessingInfo,
@@ -169,6 +170,27 @@ class MllamaDummyInputsBuilder(BaseDummyInputsBuilder[MllamaProcessingInfo]):
class MllamaMultiModalProcessor(EncDecMultiModalProcessor[MllamaProcessingInfo]
                                ):
def apply(
    self,
    prompt: Union[str, list[int]],
    mm_data: MultiModalDataDict,
    hf_processor_mm_kwargs: Mapping[str, object],
) -> MultiModalEncDecInputs:
    """Run the base processor, then validate the image-token count.

    After the superclass builds the encoder/decoder inputs, verify that
    the decoder prompt holds exactly one image token per image supplied
    in ``mm_data``; raise ``ValueError`` on any mismatch.
    """
    processed = super().apply(prompt, mm_data, hf_processor_mm_kwargs)

    # Count occurrences of the model's image token in the decoder prompt.
    image_token_id = self.info.get_hf_config().image_token_index
    num_image_tokens = processed["prompt_token_ids"].count(image_token_id)

    # mm_data may carry a single Image or a sequence of them.
    images = mm_data.get("image", [])
    num_images = 1 if isinstance(images, Image) else len(images)

    if num_image_tokens != num_images:
        raise ValueError(
            f"The number of image tokens ({num_image_tokens}) must be"
            f" the same as the number of images ({num_images})")
    return processed
    def _call_hf_processor(
        self,
        prompt: str,