[Misc] Clean up scatter_patch_features (#15559)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung 2025-03-27 15:45:00 +08:00 committed by GitHub
parent 43ed4143c4
commit e6c9053f9e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 82 additions and 136 deletions

View File

@ -30,7 +30,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
# yapf: enable
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.utils import flatten_2d_lists
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP)
@ -60,7 +59,7 @@ class Gemma3ImagePixelInputs(TypedDict):
A boolean mask indicating which image embeddings correspond
to patch tokens.
Shape: `(batch_size, num_images, num_embeds)`
Shape: `(batch_size * num_images, num_embeds)`
"""
@ -593,6 +592,7 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
pixel_values = flatten_bn(pixel_values, concat=True)
num_crops = flatten_bn(num_crops, concat=True)
embed_is_patch = flatten_bn(embed_is_patch)
return Gemma3ImagePixelInputs(
type="pixel_values",
@ -635,14 +635,10 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
image_features = self._process_image_input(image_input)
if kwargs.get("v0_path", False):
return image_features
return flatten_2d_lists(
scatter_patch_features(*args) for args in zip(
image_features,
image_input["embed_is_patch"],
))
return scatter_patch_features(
image_features,
image_input["embed_is_patch"],
)
def get_input_embeddings(
self,
@ -671,7 +667,6 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
# NOTE: In v1, inputs_embeds is always generated at model runner, this
# condition is for v0 compatibility.
elif inputs_embeds is None:
kwargs.update({"v0_path": True})
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,

View File

@ -35,7 +35,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import flatten_2d_lists
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
@ -66,13 +65,13 @@ class InternVLImagePixelInputs(TypedDict):
A boolean mask indicating which image embeddings correspond
to patch tokens.
Shape: `(batch_size, num_images, num_embeds)`
Shape: `(batch_size * num_images, num_embeds)`
"""
class InternVLImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"]
data: NestedTensors
data: Union[torch.Tensor, list[torch.Tensor]]
"""
A tensor of shape `(num_images, total_image_feature_size, hidden_size)`
or a list of tensors of shape `(total_image_feature_size, hidden_size)`
@ -867,6 +866,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
pixel_values_flat = flatten_bn(pixel_values_flat, concat=True)
image_num_patches = flatten_bn(image_num_patches, concat=True)
embed_is_patch = flatten_bn(embed_is_patch)
return InternVLImagePixelInputs(
type="pixel_values",
@ -881,7 +881,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
def _process_image_input(
self,
image_input: InternVLImageInputs,
) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
) -> Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor, ...]]:
if image_input["type"] == "image_embeds":
return image_input["data"]
@ -921,15 +921,13 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
image_features = self._process_image_input(image_input)
if (kwargs.get("v0_path", False)
or image_input["type"] != "pixel_values"):
if image_input["type"] != "pixel_values":
return image_features
return flatten_2d_lists(
scatter_patch_features(*args) for args in zip(
image_features,
image_input["embed_is_patch"],
))
return scatter_patch_features(
image_features,
image_input["embed_is_patch"],
)
def get_input_embeddings(
self,
@ -964,7 +962,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
# NOTE: In v1, inputs_embeds is always generated at model runner, this
# condition is for v0 compatibility.
elif inputs_embeds is None:
kwargs.update({"v0_path": True})
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)

View File

@ -35,7 +35,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
PromptReplacement, PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.utils import flatten_2d_lists
from .clip import CLIPVisionModel
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
@ -73,7 +72,7 @@ class PixtralHFImagePixelInputs(TypedDict):
A boolean mask indicating which image embeddings correspond
to patch tokens.
Shape: `(batch_size, num_images, num_embeds)`
Shape: `(batch_size * num_images, num_embeds)`
"""
@ -618,6 +617,8 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
raise ValueError("Incorrect type of embed_is_patch. "
f"Got type: {type(embed_is_patch)}")
embed_is_patch = flatten_bn(embed_is_patch)
return PixtralHFImagePixelInputs(
type="pixel_values_pixtral",
pixel_values=flatten_bn(pixel_values),
@ -713,18 +714,16 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
if image_input is None:
return None
vision_embeddings = self._process_image_input(image_input)
image_features = self._process_image_input(image_input)
if (kwargs.get("v0_path", False)
or image_input["type"] != "pixel_values_pixtral"):
if image_input["type"] != "pixel_values_pixtral":
# The path is used for pixtral (V0 only) and llava (V0/V1)
return vision_embeddings
return image_features
return flatten_2d_lists(
scatter_patch_features(*args) for args in zip(
vision_embeddings,
image_input["embed_is_patch"],
))
return scatter_patch_features(
image_features,
image_input["embed_is_patch"],
)
def get_input_embeddings(
self,
@ -790,7 +789,6 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
# NOTE: In v1, inputs_embeds is always generated at model runner, this
# condition is for v0 compatibility.
elif inputs_embeds is None:
kwargs.update({"v0_path": True})
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)

View File

@ -49,7 +49,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
PromptInsertion, PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.utils import flatten_2d_lists
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP, SupportsQuant)
@ -72,17 +71,17 @@ POOLING_SIZE = 2
class MolmoImageInputs(TypedDict):
images: Union[torch.Tensor, list[torch.Tensor]]
"""Shape: `(batch_size, num_crops, num_patch, patch_dim)`"""
"""Shape: `(batch_size * num_images, num_crops, num_patch, patch_dim)`"""
image_masks: Optional[Union[torch.Tensor, list[torch.Tensor]]]
"""Shape: `(batch_size, num_crops, num_patch)`"""
"""Shape: `(batch_size * num_images, num_crops, num_patch)`"""
feat_is_patch: Union[torch.Tensor, list[torch.Tensor]]
"""
A boolean mask indicating which image features correspond
to patch tokens.
Shape: `(batch_size, num_crops, num_patch)`
Shape: `(batch_size * num_images, num_crops, num_patch)`
"""
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
@ -90,7 +89,7 @@ class MolmoImageInputs(TypedDict):
A boolean mask indicating which image embeddings correspond
to patch tokens.
Shape: `(batch_size, num_embeds)`
Shape: `(batch_size * num_images, num_embeds)`
"""
num_crops: Union[torch.Tensor, list[torch.Tensor]]
@ -696,9 +695,10 @@ class MolmoVisionBackbone(nn.Module, SupportsQuant):
return image_features
def forward(
self, images: torch.Tensor, image_masks: torch.Tensor
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
self,
images: torch.Tensor,
image_masks: torch.Tensor,
) -> torch.Tensor:
# image_features: (batch_size, num_crops(=num_image), num_patch, nximage_emb_dim) # noqa: E501
batch_size, num_image = images.shape[:2]
images = images.to(device=self.device, dtype=self.dtype)
@ -1491,6 +1491,8 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
f"Got type: {type(img_patch_id)}")
self.img_patch_id = img_patch_id.flatten().unique().item()
embed_is_patch = flatten_bn(embed_is_patch)
return MolmoImageInputs(
images=images,
image_masks=image_masks,
@ -1502,13 +1504,17 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
def _process_image_input(
self,
image_input: MolmoImageInputs,
) -> Union[torch.Tensor, list[torch.Tensor]]:
if isinstance(image_input["images"], list):
) -> list[torch.Tensor]:
images = image_input["images"]
image_masks = image_input["image_masks"]
feat_is_patch = image_input["feat_is_patch"]
num_crops = image_input["num_crops"]
if isinstance(images, list):
# Call the vision backbone on the whole batch at once
images_flat = flatten_bn(image_input["images"], concat=True)
image_masks_flat = (None if (image_masks :=
image_input["image_masks"]) is None
else flatten_bn(image_masks, concat=True))
images_flat = flatten_bn(images, concat=True)
image_masks_flat = (None if image_masks is None else flatten_bn(
image_masks, concat=True))
image_features_flat = self.vision_backbone(
images=images_flat.unsqueeze(0),
@ -1517,63 +1523,19 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
).squeeze(0)
# Reconstruct the batch dimension
image_features = image_features_flat.split(
image_input["num_crops"].sum(-1).tolist())
num_crops_per_image = [nc.sum().item() for nc in num_crops]
image_features = image_features_flat.split(num_crops_per_image)
else:
image_features = self.vision_backbone(
images=image_input["images"],
image_masks=image_input["image_masks"],
images=images,
image_masks=image_masks,
)
return image_features
def _get_mm_embeds(
self,
features: torch.Tensor, # Shape: (num_crop, num_patch, d)
feat_is_patch: torch.Tensor, # Shape: (num_crop, num_patch)
num_crops: torch.Tensor, # Shape: (num_images,)
embed_is_patch: torch.Tensor, # Shape: (num_embeds,)
) -> tuple[torch.Tensor, ...]:
"""
Scatter the patch features into a contiguous tensor that corresponds
to the embedding tokens defined by the multimodal processor.
Note:
The original code only considers patch tokens as feature
tokens, but our processor considers all image-related tokens
as feature tokens because the feature tokens need to be
consecutive in `input_ids`.
Example:
A simplified example for one item in the batch:
.. code-block::
Embedding tokens (from HF processor):
[<start> <patch> <patch> <col> <patch> <patch> <col> <end> ]
embed_is_patch (from HF processor):
[ False True True False True True False False ]
Encoder outputs (from model):
[ p1 p2 0 p3 p4 0 ]
feat_is_patch (from HF processor):
[ True True False True True False ]
The resulting embedding tensor is:
[ nan p1 p2 nan p3 p4 nan nan ]
"""
num_crops_per_image = num_crops.tolist()
feats_per_image = features.split(num_crops_per_image)
f_is_patch_per_image = feat_is_patch.split(num_crops_per_image)
features = torch.cat([
# Only the features corresponding to patch tokens are relevant
return [
feats[f_is_patch]
for feats, f_is_patch in zip(feats_per_image, f_is_patch_per_image)
])
return scatter_patch_features(features, embed_is_patch)
for feats, f_is_patch in zip(image_features, feat_is_patch)
]
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
@ -1583,13 +1545,10 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
image_features = self._process_image_input(image_input)
return flatten_2d_lists(
self._get_mm_embeds(*args) for args in zip(
image_features,
image_input["feat_is_patch"],
image_input["num_crops"],
image_input["embed_is_patch"],
))
return scatter_patch_features(
image_features,
image_input["embed_is_patch"],
)
def get_input_embeddings(
self,

View File

@ -42,7 +42,6 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.tokenizer import (MistralTokenizer,
cached_tokenizer_from_config)
from vllm.utils import flatten_2d_lists
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import (flatten_bn, init_vllm_registered_model, maybe_prefix,
@ -74,7 +73,7 @@ class PixtralImagePixelInputs(TypedDict):
A boolean mask indicating which image embeddings correspond
to patch tokens.
Shape: `(batch_size, num_images, num_embeds)`
Shape: `(batch_size * num_images, num_embeds)`
"""
@ -387,6 +386,8 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
raise ValueError("Incorrect type of embed_is_patch. "
f"Got type: {type(embed_is_patch)}")
embed_is_patch = flatten_bn(embed_is_patch)
return PixtralImagePixelInputs(
type="pixel_values",
images=flatten_bn(images),
@ -428,14 +429,10 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
image_features = self._process_image_input(image_input)
if kwargs.get("v0_path", False):
return image_features
return flatten_2d_lists(
scatter_patch_features(*args) for args in zip(
image_features,
image_input["embed_is_patch"],
))
return scatter_patch_features(
image_features,
image_input["embed_is_patch"],
)
def get_input_embeddings(
self,
@ -467,7 +464,6 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
# NOTE: In v1, inputs_embeds is always generated at model runner, this
# condition is for v0 compatibility.
elif inputs_embeds is None:
kwargs.update({"v0_path": True})
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)

View File

@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
from abc import ABC, abstractmethod
from collections.abc import Sequence
from typing import Final, Generic, Optional, Protocol, TypeVar, Union, cast
import torch
@ -154,8 +155,8 @@ def resolve_visual_encoder_outputs(
def scatter_patch_features(
features: torch.Tensor,
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]],
patches: Union[torch.Tensor, Sequence[torch.Tensor]],
embed_is_patch: Union[torch.Tensor, Sequence[torch.Tensor]],
) -> tuple[torch.Tensor, ...]:
"""
Scatter the patch features into a contiguous tensor that corresponds
@ -165,8 +166,8 @@ def scatter_patch_features(
can be filtered out by :func:`select_patch_features`.
Args:
features: The patch features, concatenated across each image.
Shape: `(num_patch, feature_depth)`
patches: The patch features for each image.
Shape: `(num_images, <patch_dims>, feature_depth)`
embed_is_patch: A boolean mask indicating which image embeddings
correspond to patch tokens for each image.
Shape: `(num_images, num_embeds)`
@ -194,21 +195,21 @@ def scatter_patch_features(
The resulting embedding tensor is:
[ nan p1 p2 nan p3 p4 nan nan ]
"""
num_embeds_per_image = [
e_is_patch.numel() for e_is_patch in embed_is_patch
]
if isinstance(embed_is_patch, torch.Tensor):
embed_is_patch_flat = embed_is_patch.view(-1)
else:
embed_is_patch_flat = torch.cat(embed_is_patch)
if len(patches) != len(embed_is_patch):
raise ValueError(f"Inconsistent num_images: {len(patches)=} vs. "
f"{len(embed_is_patch)=}")
embeds_flat = features.new_full(
(sum(num_embeds_per_image), features.shape[-1]),
fill_value=torch.nan,
)
embeds_flat[embed_is_patch_flat] = features.flatten(0, -2)
def get_embed_one(patches_one: torch.Tensor, e_is_patch: torch.Tensor):
embed_one = patches_one.new_full(
(e_is_patch.shape[0], patches_one.shape[-1]),
fill_value=torch.nan,
)
embed_one[e_is_patch] = patches_one.flatten(0, -2)
return embed_one
return embeds_flat.split(num_embeds_per_image)
return tuple(
get_embed_one(patches_one, e_is_patch)
for patches_one, e_is_patch in zip(patches, embed_is_patch))
def select_patch_features(