From e6c9053f9ec0b41e9af41def67537a4a3097eeb5 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Thu, 27 Mar 2025 15:45:00 +0800 Subject: [PATCH] [Misc] Clean up `scatter_patch_features` (#15559) Signed-off-by: DarkLight1337 --- vllm/model_executor/models/gemma3_mm.py | 17 ++-- vllm/model_executor/models/internvl.py | 21 ++--- vllm/model_executor/models/llava.py | 22 +++-- vllm/model_executor/models/molmo.py | 105 ++++++++---------------- vllm/model_executor/models/pixtral.py | 18 ++-- vllm/model_executor/models/vision.py | 35 ++++---- 6 files changed, 82 insertions(+), 136 deletions(-) diff --git a/vllm/model_executor/models/gemma3_mm.py b/vllm/model_executor/models/gemma3_mm.py index 63d3ccbf54bc2..9efb57b8c5aa1 100644 --- a/vllm/model_executor/models/gemma3_mm.py +++ b/vllm/model_executor/models/gemma3_mm.py @@ -30,7 +30,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, # yapf: enable from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors -from vllm.utils import flatten_2d_lists from .interfaces import (MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal, SupportsPP) @@ -60,7 +59,7 @@ class Gemma3ImagePixelInputs(TypedDict): A boolean mask indicating which image embeddings correspond to patch tokens. - Shape: `(batch_size, num_images, num_embeds)` + Shape: `(batch_size * num_images, num_embeds)` """ @@ -593,6 +592,7 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, pixel_values = flatten_bn(pixel_values, concat=True) num_crops = flatten_bn(num_crops, concat=True) + embed_is_patch = flatten_bn(embed_is_patch) return Gemma3ImagePixelInputs( type="pixel_values", @@ -635,14 +635,10 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, image_features = self._process_image_input(image_input) - if kwargs.get("v0_path", False): - return image_features - - return flatten_2d_lists( - scatter_patch_features(*args) for args in zip( - image_features, - image_input["embed_is_patch"], - )) + return scatter_patch_features( + image_features, + image_input["embed_is_patch"], + ) def get_input_embeddings( self, @@ -671,7 +667,6 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, # NOTE: In v1, inputs_embeds is always generated at model runner, this # condition is for v0 compatibility. elif inputs_embeds is None: - kwargs.update({"v0_path": True}) vision_embeddings = self.get_multimodal_embeddings(**kwargs) inputs_embeds = self.get_input_embeddings(input_ids, diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index e1aa371610353..0729f4c7d203c 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -35,7 +35,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors from vllm.transformers_utils.tokenizer import AnyTokenizer -from vllm.utils import flatten_2d_lists from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, @@ -66,13 +65,13 @@ class InternVLImagePixelInputs(TypedDict): A boolean mask indicating which image embeddings correspond to patch tokens. 
- Shape: `(batch_size, num_images, num_embeds)` + Shape: `(batch_size * num_images, num_embeds)` """ class InternVLImageEmbeddingInputs(TypedDict): type: Literal["image_embeds"] - data: NestedTensors + data: Union[torch.Tensor, list[torch.Tensor]] """ A tensor of shape `(num_images, total_image_feature_size, hidden_size)` or a list of tensors of shape `(total_image_feature_size, hidden_size)` @@ -867,6 +866,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP): pixel_values_flat = flatten_bn(pixel_values_flat, concat=True) image_num_patches = flatten_bn(image_num_patches, concat=True) + embed_is_patch = flatten_bn(embed_is_patch) return InternVLImagePixelInputs( type="pixel_values", @@ -881,7 +881,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP): def _process_image_input( self, image_input: InternVLImageInputs, - ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]: + ) -> Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor, ...]]: if image_input["type"] == "image_embeds": return image_input["data"] @@ -921,15 +921,13 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP): image_features = self._process_image_input(image_input) - if (kwargs.get("v0_path", False) - or image_input["type"] != "pixel_values"): + if image_input["type"] != "pixel_values": return image_features - return flatten_2d_lists( - scatter_patch_features(*args) for args in zip( - image_features, - image_input["embed_is_patch"], - )) + return scatter_patch_features( + image_features, + image_input["embed_is_patch"], + ) def get_input_embeddings( self, @@ -964,7 +962,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP): # NOTE: In v1, inputs_embeds is always generated at model runner, this # condition is for v0 compatibility. elif inputs_embeds is None: - kwargs.update({"v0_path": True}) vision_embeddings = self.get_multimodal_embeddings(**kwargs) inputs_embeds = self.get_input_embeddings(input_ids, vision_embeddings) diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index d1014067d9d7c..826f04b37547b 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -35,7 +35,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, PromptReplacement, PromptUpdate) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors -from vllm.utils import flatten_2d_lists from .clip import CLIPVisionModel from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP @@ -73,7 +72,7 @@ class PixtralHFImagePixelInputs(TypedDict): A boolean mask indicating which image embeddings correspond to patch tokens. - Shape: `(batch_size, num_images, num_embeds)` + Shape: `(batch_size * num_images, num_embeds)` """ @@ -618,6 +617,8 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): raise ValueError("Incorrect type of embed_is_patch. 
" f"Got type: {type(embed_is_patch)}") + embed_is_patch = flatten_bn(embed_is_patch) + return PixtralHFImagePixelInputs( type="pixel_values_pixtral", pixel_values=flatten_bn(pixel_values), @@ -713,18 +714,16 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): if image_input is None: return None - vision_embeddings = self._process_image_input(image_input) + image_features = self._process_image_input(image_input) - if (kwargs.get("v0_path", False) - or image_input["type"] != "pixel_values_pixtral"): + if image_input["type"] != "pixel_values_pixtral": # The path is used for pixtral (V0 only) and llava (V0/V1) - return vision_embeddings + return image_features - return flatten_2d_lists( - scatter_patch_features(*args) for args in zip( - vision_embeddings, - image_input["embed_is_patch"], - )) + return scatter_patch_features( + image_features, + image_input["embed_is_patch"], + ) def get_input_embeddings( self, @@ -790,7 +789,6 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): # NOTE: In v1, inputs_embeds is always generated at model runner, this # condition is for v0 compatibility. elif inputs_embeds is None: - kwargs.update({"v0_path": True}) vision_embeddings = self.get_multimodal_embeddings(**kwargs) inputs_embeds = self.get_input_embeddings(input_ids, vision_embeddings) diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py index 146d48e522119..9224687d8a5d3 100644 --- a/vllm/model_executor/models/molmo.py +++ b/vllm/model_executor/models/molmo.py @@ -49,7 +49,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, PromptInsertion, PromptUpdate) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors -from vllm.utils import flatten_2d_lists from .interfaces import (MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal, SupportsPP, SupportsQuant) @@ -72,17 +71,17 @@ POOLING_SIZE = 2 class MolmoImageInputs(TypedDict): images: Union[torch.Tensor, list[torch.Tensor]] - """Shape: `(batch_size, num_crops, num_patch, patch_dim)`""" + """Shape: `(batch_size * num_images, num_crops, num_patch, patch_dim)`""" image_masks: Optional[Union[torch.Tensor, list[torch.Tensor]]] - """Shape: `(batch_size, num_crops, num_patch)`""" + """Shape: `(batch_size * num_images, num_crops, num_patch)`""" feat_is_patch: Union[torch.Tensor, list[torch.Tensor]] """ A boolean mask indicating which image features correspond to patch tokens. - Shape: `(batch_size, num_crops, num_patch)` + Shape: `(batch_size * num_images, num_crops, num_patch)` """ embed_is_patch: Union[torch.Tensor, list[torch.Tensor]] @@ -90,7 +89,7 @@ class MolmoImageInputs(TypedDict): A boolean mask indicating which image embeddings correspond to patch tokens. 
- Shape: `(batch_size, num_embeds)` + Shape: `(batch_size * num_images, num_embeds)` """ num_crops: Union[torch.Tensor, list[torch.Tensor]] @@ -696,9 +695,10 @@ class MolmoVisionBackbone(nn.Module, SupportsQuant): return image_features def forward( - self, images: torch.Tensor, image_masks: torch.Tensor - ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - + self, + images: torch.Tensor, + image_masks: torch.Tensor, + ) -> torch.Tensor: # image_features: (batch_size, num_crops(=num_image), num_patch, nximage_emb_dim) # noqa: E501 batch_size, num_image = images.shape[:2] images = images.to(device=self.device, dtype=self.dtype) @@ -1491,6 +1491,8 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, f"Got type: {type(img_patch_id)}") self.img_patch_id = img_patch_id.flatten().unique().item() + embed_is_patch = flatten_bn(embed_is_patch) + return MolmoImageInputs( images=images, image_masks=image_masks, @@ -1502,13 +1504,17 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, def _process_image_input( self, image_input: MolmoImageInputs, - ) -> Union[torch.Tensor, list[torch.Tensor]]: - if isinstance(image_input["images"], list): + ) -> list[torch.Tensor]: + images = image_input["images"] + image_masks = image_input["image_masks"] + feat_is_patch = image_input["feat_is_patch"] + num_crops = image_input["num_crops"] + + if isinstance(images, list): # Call the vision backbone on the whole batch at once - images_flat = flatten_bn(image_input["images"], concat=True) - image_masks_flat = (None if (image_masks := - image_input["image_masks"]) is None - else flatten_bn(image_masks, concat=True)) + images_flat = flatten_bn(images, concat=True) + image_masks_flat = (None if image_masks is None else flatten_bn( + image_masks, concat=True)) image_features_flat = self.vision_backbone( images=images_flat.unsqueeze(0), @@ -1517,63 +1523,19 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, ).squeeze(0) # Reconstruct the batch dimension - image_features = image_features_flat.split( - image_input["num_crops"].sum(-1).tolist()) + num_crops_per_image = [nc.sum().item() for nc in num_crops] + image_features = image_features_flat.split(num_crops_per_image) else: image_features = self.vision_backbone( - images=image_input["images"], - image_masks=image_input["image_masks"], + images=images, + image_masks=image_masks, ) - return image_features - - def _get_mm_embeds( - self, - features: torch.Tensor, # Shape: (num_crop, num_patch, d) - feat_is_patch: torch.Tensor, # Shape: (num_crop, num_patch) - num_crops: torch.Tensor, # Shape: (num_images,) - embed_is_patch: torch.Tensor, # Shape: (num_embeds,) - ) -> tuple[torch.Tensor, ...]: - """ - Scatter the patch features into a contiguous tensor that corresponds - to the embedding tokens defined by the multimodal processor. - - Note: - The original code only considers patch tokens as feature - tokens, but our processor considers all image-related tokens - as feature tokens because the feature tokens need to be - consecutive in `input_ids`. - - Example: - A simplified example for one item in the batch: - - .. 
code-block:: - - Embedding tokens (from HF processor): - [ ] - - embed_is_patch (from HF processor): - [ False True True False True True False False ] - - Encoder outputs (from model): - [ p1 p2 0 p3 p4 0 ] - - feat_is_patch (from HF processor): - [ True True False True True False ] - - The resulting embedding tensor is: - [ nan p1 p2 nan p3 p4 nan nan ] - """ - num_crops_per_image = num_crops.tolist() - feats_per_image = features.split(num_crops_per_image) - f_is_patch_per_image = feat_is_patch.split(num_crops_per_image) - - features = torch.cat([ + # Only the features corresponding to patch tokens are relevant + return [ feats[f_is_patch] - for feats, f_is_patch in zip(feats_per_image, f_is_patch_per_image) - ]) - - return scatter_patch_features(features, embed_is_patch) + for feats, f_is_patch in zip(image_features, feat_is_patch) + ] def get_multimodal_embeddings( self, **kwargs: object) -> Optional[MultiModalEmbeddings]: @@ -1583,13 +1545,10 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, image_features = self._process_image_input(image_input) - return flatten_2d_lists( - self._get_mm_embeds(*args) for args in zip( - image_features, - image_input["feat_is_patch"], - image_input["num_crops"], - image_input["embed_is_patch"], - )) + return scatter_patch_features( + image_features, + image_input["embed_is_patch"], + ) def get_input_embeddings( self, diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index a3ad360961243..da2017c987d4f 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -42,7 +42,6 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors from vllm.transformers_utils.tokenizer import (MistralTokenizer, cached_tokenizer_from_config) -from vllm.utils import flatten_2d_lists from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .utils import (flatten_bn, init_vllm_registered_model, maybe_prefix, @@ -74,7 +73,7 @@ class PixtralImagePixelInputs(TypedDict): A boolean mask indicating which image embeddings correspond to patch tokens. - Shape: `(batch_size, num_images, num_embeds)` + Shape: `(batch_size * num_images, num_embeds)` """ @@ -387,6 +386,8 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, raise ValueError("Incorrect type of embed_is_patch. " f"Got type: {type(embed_is_patch)}") + embed_is_patch = flatten_bn(embed_is_patch) + return PixtralImagePixelInputs( type="pixel_values", images=flatten_bn(images), @@ -428,14 +429,10 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, image_features = self._process_image_input(image_input) - if kwargs.get("v0_path", False): - return image_features - - return flatten_2d_lists( - scatter_patch_features(*args) for args in zip( - image_features, - image_input["embed_is_patch"], - )) + return scatter_patch_features( + image_features, + image_input["embed_is_patch"], + ) def get_input_embeddings( self, @@ -467,7 +464,6 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, # NOTE: In v1, inputs_embeds is always generated at model runner, this # condition is for v0 compatibility. 
elif inputs_embeds is None: - kwargs.update({"v0_path": True}) vision_embeddings = self.get_multimodal_embeddings(**kwargs) inputs_embeds = self.get_input_embeddings(input_ids, vision_embeddings) diff --git a/vllm/model_executor/models/vision.py b/vllm/model_executor/models/vision.py index c91459398308e..db069f8de2a35 100644 --- a/vllm/model_executor/models/vision.py +++ b/vllm/model_executor/models/vision.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 from abc import ABC, abstractmethod +from collections.abc import Sequence from typing import Final, Generic, Optional, Protocol, TypeVar, Union, cast import torch @@ -154,8 +155,8 @@ def resolve_visual_encoder_outputs( def scatter_patch_features( - features: torch.Tensor, - embed_is_patch: Union[torch.Tensor, list[torch.Tensor]], + patches: Union[torch.Tensor, Sequence[torch.Tensor]], + embed_is_patch: Union[torch.Tensor, Sequence[torch.Tensor]], ) -> tuple[torch.Tensor, ...]: """ Scatter the patch features into a contiguous tensor that corresponds @@ -165,8 +166,8 @@ def scatter_patch_features( can be filtered out by :func`select_patch_features`. Args: - features: The patch features, concatenated across each image. - Shape: `(num_patch, feature_depth)` + patches: The patch features for each image. + Shape: `(num_images, , feature_depth)` embed_is_patch: A boolean mask indicating which image embeddings correspond to patch tokens for each image. Shape: `(num_images, num_embeds)` @@ -194,21 +195,21 @@ def scatter_patch_features( The resulting embedding tensor is: [ nan p1 p2 nan p3 p4 nan nan ] """ - num_embeds_per_image = [ - e_is_patch.numel() for e_is_patch in embed_is_patch - ] - if isinstance(embed_is_patch, torch.Tensor): - embed_is_patch_flat = embed_is_patch.view(-1) - else: - embed_is_patch_flat = torch.cat(embed_is_patch) + if len(patches) != len(embed_is_patch): + raise ValueError(f"Inconsistent num_images: {len(patches)=} vs. " + f"{len(embed_is_patch)=}") - embeds_flat = features.new_full( - (sum(num_embeds_per_image), features.shape[-1]), - fill_value=torch.nan, - ) - embeds_flat[embed_is_patch_flat] = features.flatten(0, -2) + def get_embed_one(patches_one: torch.Tensor, e_is_patch: torch.Tensor): + embed_one = patches_one.new_full( + (e_is_patch.shape[0], patches_one.shape[-1]), + fill_value=torch.nan, + ) + embed_one[e_is_patch] = patches_one.flatten(0, -2) + return embed_one - return embeds_flat.split(num_embeds_per_image) + return tuple( + get_embed_one(patches_one, e_is_patch) + for patches_one, e_is_patch in zip(patches, embed_is_patch)) def select_patch_features(
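
A minimal usage sketch of the refactored `scatter_patch_features` helper from `vllm/model_executor/models/vision.py` as changed above (the image count, patch counts, and hidden size below are assumptions made up purely for illustration):

    # Minimal usage sketch (illustrative shapes, not taken from a real model).
    import torch

    from vllm.model_executor.models.vision import scatter_patch_features

    # Two images with 4 and 6 patch features of hidden size 8 (assumed values).
    patches = [torch.randn(4, 8), torch.randn(6, 8)]

    # Per-image masks over the embedding tokens produced by the HF processor;
    # True marks the positions that should receive patch features.
    embed_is_patch = [
        torch.tensor([False, True, True, True, True, False]),
        torch.tensor([False, True, True, True, True, True, True, False]),
    ]

    # Returns one tensor per image, aligned with the embedding tokens; the
    # non-patch positions are filled with NaN so they can later be dropped
    # via select_patch_features().
    embeds = scatter_patch_features(patches, embed_is_patch)
    assert embeds[0].shape == (6, 8) and embeds[1].shape == (8, 8)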