diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py
index 572eabe261930..eb56b0aee6c76 100644
--- a/examples/offline_inference/vision_language.py
+++ b/examples/offline_inference/vision_language.py
@@ -68,7 +68,7 @@ def run_blip2(questions: list[str], modality: str) -> ModelRequestData:
     # See https://huggingface.co/Salesforce/blip2-opt-2.7b/discussions/15#64ff02f3f8cf9e4f5b038262 #noqa
     prompts = [f"Question: {question} Answer:" for question in questions]
     engine_args = EngineArgs(
-        model="Salesforce/blip2-opt-2.7b",
+        model="Salesforce/blip2-opt-6.7b",
         disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
     )
@@ -128,7 +128,8 @@ def run_florence2(questions: list[str], modality: str) -> ModelRequestData:
     engine_args = EngineArgs(
         model="microsoft/Florence-2-large",
         tokenizer="facebook/bart-large",
-        max_num_seqs=8,
+        max_model_len=4096,
+        max_num_seqs=2,
         trust_remote_code=True,
         dtype="bfloat16",
         disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
@@ -511,7 +512,7 @@ def run_mllama(questions: list[str], modality: str) -> ModelRequestData:
     engine_args = EngineArgs(
         model=model_name,
         max_model_len=4096,
-        max_num_seqs=16,
+        max_num_seqs=2,
         disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
     )
@@ -700,7 +701,7 @@ def run_pixtral_hf(questions: list[str], modality: str) -> ModelRequestData:
     # NOTE: Need L40 (or equivalent) to avoid OOM
     engine_args = EngineArgs(
         model=model_name,
-        max_model_len=8192,
+        max_model_len=6144,
         max_num_seqs=2,
         disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
     )
diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py
index 6277a1009ffe4..05e30f855ced2 100644
--- a/tests/distributed/test_pipeline_parallel.py
+++ b/tests/distributed/test_pipeline_parallel.py
@@ -217,7 +217,7 @@ EMBEDDING_MODELS = {  # type: ignore[var-annotated]
 MULTIMODAL_MODELS = {
     # [Decoder-only]
-    "Salesforce/blip2-opt-2.7b": PPTestSettings.fast(),
+    "Salesforce/blip2-opt-6.7b": PPTestSettings.fast(),
     "facebook/chameleon-7b": PPTestSettings.fast(),
     "adept/fuyu-8b": PPTestSettings.fast(),
     "THUDM/glm-4v-9b": PPTestSettings.fast(),
diff --git a/tests/models/decoder_only/vision_language/test_models.py b/tests/models/decoder_only/vision_language/test_models.py
index aa3ac7eea6d0d..7a9158eff94eb 100644
--- a/tests/models/decoder_only/vision_language/test_models.py
+++ b/tests/models/decoder_only/vision_language/test_models.py
@@ -34,8 +34,6 @@ REQUIRES_V0_MODELS = [
     # V1 Test: no way to fall back for head_dim = 80
     # https://github.com/vllm-project/vllm/issues/14524
     "qwen_vl",
-    "h2ovl",
-    "blip2",
     # V1 Test: not enough KV cache space in CI.
     "fuyu",
 ]
@@ -161,7 +159,8 @@ VLM_TEST_SETTINGS = {
         marks=[large_gpu_mark(min_gb=64)],
     ),
     "blip2": VLMTestInfo(
-        models=["Salesforce/blip2-opt-2.7b"],
+        # TODO: Change back to 2.7b once head_dim = 80 is supported
+        models=["Salesforce/blip2-opt-6.7b"],
         test_type=VLMTestType.IMAGE,
         prompt_formatter=lambda img_prompt: f"Question: {img_prompt} Answer:",
         img_idx_to_prompt=lambda idx: "",
@@ -248,7 +247,8 @@ VLM_TEST_SETTINGS = {
     "h2ovl": VLMTestInfo(
         models = [
             "h2oai/h2ovl-mississippi-800m",
-            "h2oai/h2ovl-mississippi-2b",
+            # TODO: Re-enable once head_dim = 80 is supported
+            # "h2oai/h2ovl-mississippi-2b",
         ],
         test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
         prompt_formatter=lambda img_prompt: f"<|prompt|>{img_prompt}<|end|><|answer|>",  # noqa: E501
diff --git a/tests/models/registry.py b/tests/models/registry.py
index 7c8fac08befff..69ebfe4c92415 100644
--- a/tests/models/registry.py
+++ b/tests/models/registry.py
@@ -259,7 +259,8 @@ _CROSS_ENCODER_EXAMPLE_MODELS = {
 _MULTIMODAL_EXAMPLE_MODELS = {
     # [Decoder-only]
     "AriaForConditionalGeneration": _HfExamplesInfo("rhymes-ai/Aria"),
-    "Blip2ForConditionalGeneration": _HfExamplesInfo("Salesforce/blip2-opt-2.7b"),  # noqa: E501
+    "Blip2ForConditionalGeneration": _HfExamplesInfo("Salesforce/blip2-opt-2.7b",  # noqa: E501
+                                                     extras={"6b": "Salesforce/blip2-opt-6.7b"}),  # noqa: E501
     "ChameleonForConditionalGeneration": _HfExamplesInfo("facebook/chameleon-7b"),  # noqa: E501
     "DeepseekVLV2ForCausalLM": _HfExamplesInfo("deepseek-ai/deepseek-vl2-tiny",  # noqa: E501
                                                extras={"fork": "Isotr0py/deepseek-vl2-tiny"},  # noqa: E501
diff --git a/vllm/model_executor/models/florence2.py b/vllm/model_executor/models/florence2.py
index 3883cd4460f50..02535cc5473c7 100644
--- a/vllm/model_executor/models/florence2.py
+++ b/vllm/model_executor/models/florence2.py
@@ -875,7 +875,8 @@ class Florence2MultiModalProcessor(
     Florence2MultiModalProcessor,
     info=Florence2ProcessingInfo,
     dummy_inputs=Florence2DummyInputsBuilder)
-class Florence2ForConditionalGeneration(nn.Module, SupportsMultiModal):
+class Florence2ForConditionalGeneration(nn.Module, SupportsMultiModal,
+                                        SupportsV0Only):

     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py
index a1004cd0ac608..a807b047a1aae 100644
--- a/vllm/model_executor/models/fuyu.py
+++ b/vllm/model_executor/models/fuyu.py
@@ -39,7 +39,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         PromptUpdate, PromptUpdateDetails)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
 from vllm.sequence import IntermediateTensors
-from vllm.utils import flatten_2d_lists

 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
 from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
@@ -66,10 +65,13 @@ class FuyuImagePatchInputs(TypedDict):
     """
     This is used to split the embeddings which has the first two dimensions
     flattened just like `flat_data`.
     """
+
     embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
     """
     A boolean mask indicating which image embeddings correspond
     to patch tokens.
+
+    Shape: `(batch_size * num_images, num_embeds)`
     """
@@ -322,16 +324,18 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[FuyuImagePatchInputs]:
         image_patches = kwargs.pop("image_patches", None)
-        embed_is_patch = kwargs.pop("embed_is_patch", None)

         if image_patches is not None:
             if not isinstance(image_patches, (torch.Tensor, list)):
                 raise ValueError("Incorrect type of image patches. "
                                  f"Got type: {type(image_patches)}")

+            embed_is_patch = kwargs.pop("embed_is_patch")
             if not isinstance(embed_is_patch, (torch.Tensor, list)):
                 raise ValueError("Incorrect type of embed_is_patch. "
                                  f"Got type: {type(embed_is_patch)}")

+            image_patches_flat = flatten_bn(image_patches)
+            embed_is_patch = flatten_bn(embed_is_patch)
             return FuyuImagePatchInputs(
                 type="image_patches",
@@ -351,6 +355,7 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
         assert self.vision_embed_tokens is not None
         vision_embeddings_flat, _ = self.vision_embed_tokens(
             image_patches_flat)
+
         return vision_embeddings_flat.split(patches_per_image, dim=0)
@@ -358,13 +363,13 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
             return None
-        vision_embeddings = self._process_image_input(image_input)
-        #return vision_embeddings
-        return flatten_2d_lists(
-            scatter_patch_features(*args) for args in zip(
-                vision_embeddings,
-                image_input["embed_is_patch"],
-            ))
+
+        image_features = self._process_image_input(image_input)
+
+        return scatter_patch_features(
+            image_features,
+            image_input["embed_is_patch"],
+        )

     def get_input_embeddings(
         self,
diff --git a/vllm/model_executor/models/gemma3_mm.py b/vllm/model_executor/models/gemma3_mm.py
index 9efb57b8c5aa1..bbdea70a7bcfd 100644
--- a/vllm/model_executor/models/gemma3_mm.py
+++ b/vllm/model_executor/models/gemma3_mm.py
@@ -613,7 +613,7 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
     def _process_image_input(
         self,
         image_input: Gemma3ImageInputs,
-    ) -> tuple[torch.Tensor, ...]:
+    ) -> list[torch.Tensor]:
         assert self.vision_tower is not None

         pixel_values = image_input["pixel_values"]
@@ -625,7 +625,9 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
         )
         image_embeds = self.multi_modal_projector(image_features)
-        return image_embeds.split(num_patches.tolist())
+        return [
+            e.flatten(0, 1) for e in image_embeds.split(num_patches.tolist())
+        ]

     def get_multimodal_embeddings(
             self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
diff --git a/vllm/model_executor/models/idefics3.py b/vllm/model_executor/models/idefics3.py
index 327ec4640f03e..da4a44346c32e 100644
--- a/vllm/model_executor/models/idefics3.py
+++ b/vllm/model_executor/models/idefics3.py
@@ -733,7 +733,10 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
             pixel_attention_mask=pixel_attention_mask,
         )

-    def _process_image_input(self, image_input: ImageInputs) -> torch.Tensor:
+    def _process_image_input(
+        self,
+        image_input: ImageInputs,
+    ) -> Union[torch.Tensor, list[torch.Tensor]]:
         if image_input["type"] == "image_embeds":
             return image_input["data"]
@@ -741,7 +744,9 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
         image_features = self.model.connector(image_features)

         num_patches = image_input["num_patches"]
-        return image_features.split(num_patches.tolist())
+        return [
+            e.flatten(0, 1) for e in image_features.split(num_patches.tolist())
+        ]

     def get_multimodal_embeddings(
             self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
diff --git a/vllm/model_executor/models/llava_next_video.py b/vllm/model_executor/models/llava_next_video.py
index 8a5edefb4a0b2..780af72d57201 100644
--- a/vllm/model_executor/models/llava_next_video.py
+++ b/vllm/model_executor/models/llava_next_video.py
@@ -406,20 +406,21 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
                                                h, w)
             stacked_embeddings = self._video_pixels_to_features(
                 self.vision_tower, stacked_pixels)
-            return stacked_embeddings.view(b, num_frames,
-                                           *stacked_embeddings.shape[1:])
+            embeds = stacked_embeddings.view(b, num_frames,
+                                             *stacked_embeddings.shape[1:])

         elif is_list_of(video_pixels, torch.Tensor):
             frames_per_videos = [v.shape[0] for v in video_pixels]
             stacked_pixels = torch.cat(video_pixels, dim=0)
             stacked_embeddings = self._video_pixels_to_features(
                 self.vision_tower, stacked_pixels)
-            return torch.split(stacked_embeddings, frames_per_videos, dim=0)
-
+            embeds = torch.split(stacked_embeddings, frames_per_videos, dim=0)
         else:
             raise ValueError(
                 f"Unsupported type of video input {type(video_pixels)}")

+        return [e.flatten(0, 1) for e in embeds]
+
     def get_multimodal_embeddings(
             self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
         video_input = self._parse_and_validate_video_input(**kwargs)
diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py
index 2c0d37e883b90..5fab9df3f8f99 100644
--- a/vllm/model_executor/models/minicpmv.py
+++ b/vllm/model_executor/models/minicpmv.py
@@ -919,8 +919,11 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):

         image_features_flat = self.get_vision_hidden_states(image_input)

-        # Reconstruct the batch dimension
-        return image_features_flat.split(image_input["num_slices"].tolist())
+        num_slices = image_input["num_slices"]
+        return [
+            e.flatten(0, 1)
+            for e in image_features_flat.split(num_slices.tolist())
+        ]

     def _process_multimodal_inputs(self, modalities: dict):
         # The result multimodal_embeddings is tuple of tensors, with each
diff --git a/vllm/model_executor/models/vision.py b/vllm/model_executor/models/vision.py
index db069f8de2a35..5c21fb2d4ad2e 100644
--- a/vllm/model_executor/models/vision.py
+++ b/vllm/model_executor/models/vision.py
@@ -204,7 +204,7 @@ def scatter_patch_features(
             (e_is_patch.shape[0], patches_one.shape[-1]),
             fill_value=torch.nan,
         )
-        embed_one[e_is_patch] = patches_one.flatten(0, -2)
+        embed_one[e_is_patch] = patches_one
         return embed_one

     return tuple(
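Note: the model-side changes above (Fuyu, Gemma3, Idefics3, LLaVA-NeXT-Video, MiniCPM-V) converge on one convention: `get_multimodal_embeddings` returns a 2D `(num_embeds, hidden_size)` tensor per multimodal item, which is why `scatter_patch_features` no longer flattens its input. Below is a minimal sketch of the `split` + `flatten(0, 1)` pattern, with made-up patch counts and sizes purely for illustration.

import torch

# Toy example: two images with 2 and 3 patches each, 256 embeddings per patch,
# hidden size 64 (illustrative numbers, not taken from any real model).
num_patches = torch.tensor([2, 3])
image_embeds = torch.zeros(int(num_patches.sum()), 256, 64)

# .split() alone yields one 3D (num_patches_i, seq_len, hidden) tensor per image;
# flatten(0, 1) merges the first two dims so each item becomes a single 2D tensor.
per_item = [e.flatten(0, 1) for e in image_embeds.split(num_patches.tolist())]
assert [tuple(e.shape) for e in per_item] == [(512, 64), (768, 64)]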
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 74f3124e3c779..c7374cc3d3306 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -41,6 +41,8 @@ from vllm.v1.utils import bind_kv_cache
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin

+from .utils import sanity_check_mm_encoder_outputs
+
 if TYPE_CHECKING:
     import xgrammar as xgr
@@ -867,6 +869,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             curr_group_outputs = self.model.get_multimodal_embeddings(
                 **batched_mm_inputs)

+            sanity_check_mm_encoder_outputs(
+                curr_group_outputs,
+                expected_num_items=len(grouped_mm_inputs),
+            )
+
             for output in curr_group_outputs:
                 encoder_outputs.append(output)
@@ -1490,12 +1497,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             # Run multimodal encoder.
             dummy_encoder_outputs = self.model.get_multimodal_embeddings(
                 **batched_dummy_mm_inputs)
-            assert len(dummy_encoder_outputs) == max_num_mm_items, (
-                "Expected dimension 0 of encoder outputs to match the number "
-                f"of multimodal data items: {max_num_mm_items}, got "
-                f"{len(dummy_encoder_outputs)=} instead. This is most likely "
-                "due to the 'get_multimodal_embeddings' method of the model "
-                "not implemented correctly.")
+
+            sanity_check_mm_encoder_outputs(
+                dummy_encoder_outputs,
+                expected_num_items=max_num_mm_items,
+            )

             # Cache the dummy encoder outputs.
             self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))
diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py
index ea5a17016eb6b..8f6a54892a4e6 100644
--- a/vllm/v1/worker/tpu_model_runner.py
+++ b/vllm/v1/worker/tpu_model_runner.py
@@ -37,6 +37,8 @@ from vllm.v1.sample.tpu.sampler import Sampler as TPUSampler
 from vllm.v1.utils import bind_kv_cache
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch

+from .utils import sanity_check_mm_encoder_outputs
+
 if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput
@@ -512,6 +514,11 @@ class TPUModelRunner:
             curr_group_outputs = self.model.get_multimodal_embeddings(
                 **batched_mm_inputs)

+            sanity_check_mm_encoder_outputs(
+                curr_group_outputs,
+                expected_num_items=len(grouped_mm_inputs),
+            )
+
             for output in curr_group_outputs:
                 encoder_outputs.append(output)
diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py
new file mode 100644
index 0000000000000..b1d3aa7cd8afb
--- /dev/null
+++ b/vllm/v1/worker/utils.py
@@ -0,0 +1,29 @@
+# SPDX-License-Identifier: Apache-2.0
+import torch
+
+
+def sanity_check_mm_encoder_outputs(
+    mm_embeddings: object,
+    expected_num_items: int,
+) -> None:
+    """
+    Perform sanity checks for the result of
+    :meth:`vllm.model_executor.models.SupportsMultiModal.get_multimodal_embeddings`.
+    """
+    assert isinstance(mm_embeddings, (list, tuple, torch.Tensor)), (
+        "Expected multimodal embeddings to be a list/tuple of 2D tensors, "
+        f"or a single 3D tensor, but got {type(mm_embeddings)} "
+        "instead. This is most likely due to incorrect implementation "
+        "of the model's `get_multimodal_embeddings` method.")
+
+    assert len(mm_embeddings) == expected_num_items, (
+        "Expected number of multimodal embeddings to match number of "
+        f"input items: {expected_num_items}, but got {len(mm_embeddings)=} "
+        "instead. This is most likely due to incorrect implementation "
+        "of the model's `get_multimodal_embeddings` method.")
+
+    assert all(e.ndim == 2 for e in mm_embeddings), (
+        "Expected multimodal embeddings to be a sequence of 2D tensors, "
+        f"but got tensors with shapes {[e.shape for e in mm_embeddings]} "
+        "instead. This is most likely due to incorrect implementation "
+        "of the model's `get_multimodal_embeddings` method.")
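For reference, a small usage sketch of the new helper (assuming vLLM at this revision is importable): it shows the invariant the V1 GPU and TPU runners now assert after calling `get_multimodal_embeddings`. The embedding sizes below are placeholders.

import torch

from vllm.v1.worker.utils import sanity_check_mm_encoder_outputs

# Two multimodal items, each already flattened to (num_embeds, hidden_size).
mm_embeddings = [torch.zeros(576, 4096), torch.zeros(144, 4096)]
sanity_check_mm_encoder_outputs(mm_embeddings, expected_num_items=2)

# Returning a single concatenated tensor for all items, or per-item tensors that
# are not 2D, would trip the length or ndim assertions defined above.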