mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-13 16:27:04 +08:00
[Misc] Clean up scatter_patch_features (#15559)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent
43ed4143c4
commit
e6c9053f9e
@ -30,7 +30,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
# yapf: enable
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils import flatten_2d_lists
|
||||
|
||||
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
|
||||
SupportsMultiModal, SupportsPP)
|
||||
@ -60,7 +59,7 @@ class Gemma3ImagePixelInputs(TypedDict):
|
||||
A boolean mask indicating which image embeddings correspond
|
||||
to patch tokens.
|
||||
|
||||
Shape: `(batch_size, num_images, num_embeds)`
|
||||
Shape: `(batch_size * num_images, num_embeds)`
|
||||
"""
|
||||
|
||||
|
||||
@ -593,6 +592,7 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
|
||||
|
||||
pixel_values = flatten_bn(pixel_values, concat=True)
|
||||
num_crops = flatten_bn(num_crops, concat=True)
|
||||
embed_is_patch = flatten_bn(embed_is_patch)
|
||||
|
||||
return Gemma3ImagePixelInputs(
|
||||
type="pixel_values",
|
||||
@ -635,14 +635,10 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
|
||||
|
||||
image_features = self._process_image_input(image_input)
|
||||
|
||||
if kwargs.get("v0_path", False):
|
||||
return image_features
|
||||
|
||||
return flatten_2d_lists(
|
||||
scatter_patch_features(*args) for args in zip(
|
||||
image_features,
|
||||
image_input["embed_is_patch"],
|
||||
))
|
||||
return scatter_patch_features(
|
||||
image_features,
|
||||
image_input["embed_is_patch"],
|
||||
)
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
@ -671,7 +667,6 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
|
||||
# NOTE: In v1, inputs_embeds is always generated at model runner, this
|
||||
# condition is for v0 compatibility.
|
||||
elif inputs_embeds is None:
|
||||
kwargs.update({"v0_path": True})
|
||||
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
|
||||
|
||||
inputs_embeds = self.get_input_embeddings(input_ids,
|
||||
|
||||
@ -35,7 +35,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.utils import flatten_2d_lists
|
||||
|
||||
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
||||
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
|
||||
@ -66,13 +65,13 @@ class InternVLImagePixelInputs(TypedDict):
|
||||
A boolean mask indicating which image embeddings correspond
|
||||
to patch tokens.
|
||||
|
||||
Shape: `(batch_size, num_images, num_embeds)`
|
||||
Shape: `(batch_size * num_images, num_embeds)`
|
||||
"""
|
||||
|
||||
|
||||
class InternVLImageEmbeddingInputs(TypedDict):
|
||||
type: Literal["image_embeds"]
|
||||
data: NestedTensors
|
||||
data: Union[torch.Tensor, list[torch.Tensor]]
|
||||
"""
|
||||
A tensor of shape `(num_images, total_image_feature_size, hidden_size)`
|
||||
or a list of tensors of shape `(total_image_feature_size, hidden_size)`
|
||||
@ -867,6 +866,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
|
||||
pixel_values_flat = flatten_bn(pixel_values_flat, concat=True)
|
||||
image_num_patches = flatten_bn(image_num_patches, concat=True)
|
||||
embed_is_patch = flatten_bn(embed_is_patch)
|
||||
|
||||
return InternVLImagePixelInputs(
|
||||
type="pixel_values",
|
||||
@ -881,7 +881,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
def _process_image_input(
|
||||
self,
|
||||
image_input: InternVLImageInputs,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
|
||||
) -> Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor, ...]]:
|
||||
if image_input["type"] == "image_embeds":
|
||||
return image_input["data"]
|
||||
|
||||
@ -921,15 +921,13 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
|
||||
image_features = self._process_image_input(image_input)
|
||||
|
||||
if (kwargs.get("v0_path", False)
|
||||
or image_input["type"] != "pixel_values"):
|
||||
if image_input["type"] != "pixel_values":
|
||||
return image_features
|
||||
|
||||
return flatten_2d_lists(
|
||||
scatter_patch_features(*args) for args in zip(
|
||||
image_features,
|
||||
image_input["embed_is_patch"],
|
||||
))
|
||||
return scatter_patch_features(
|
||||
image_features,
|
||||
image_input["embed_is_patch"],
|
||||
)
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
@ -964,7 +962,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
# NOTE: In v1, inputs_embeds is always generated at model runner, this
|
||||
# condition is for v0 compatibility.
|
||||
elif inputs_embeds is None:
|
||||
kwargs.update({"v0_path": True})
|
||||
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
|
||||
inputs_embeds = self.get_input_embeddings(input_ids,
|
||||
vision_embeddings)
|
||||
|
||||
@ -35,7 +35,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
PromptReplacement, PromptUpdate)
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils import flatten_2d_lists
|
||||
|
||||
from .clip import CLIPVisionModel
|
||||
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
||||
@ -73,7 +72,7 @@ class PixtralHFImagePixelInputs(TypedDict):
|
||||
A boolean mask indicating which image embeddings correspond
|
||||
to patch tokens.
|
||||
|
||||
Shape: `(batch_size, num_images, num_embeds)`
|
||||
Shape: `(batch_size * num_images, num_embeds)`
|
||||
"""
|
||||
|
||||
|
||||
@ -618,6 +617,8 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
raise ValueError("Incorrect type of embed_is_patch. "
|
||||
f"Got type: {type(embed_is_patch)}")
|
||||
|
||||
embed_is_patch = flatten_bn(embed_is_patch)
|
||||
|
||||
return PixtralHFImagePixelInputs(
|
||||
type="pixel_values_pixtral",
|
||||
pixel_values=flatten_bn(pixel_values),
|
||||
@ -713,18 +714,16 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
if image_input is None:
|
||||
return None
|
||||
|
||||
vision_embeddings = self._process_image_input(image_input)
|
||||
image_features = self._process_image_input(image_input)
|
||||
|
||||
if (kwargs.get("v0_path", False)
|
||||
or image_input["type"] != "pixel_values_pixtral"):
|
||||
if image_input["type"] != "pixel_values_pixtral":
|
||||
# The path is used for pixtral (V0 only) and llava (V0/V1)
|
||||
return vision_embeddings
|
||||
return image_features
|
||||
|
||||
return flatten_2d_lists(
|
||||
scatter_patch_features(*args) for args in zip(
|
||||
vision_embeddings,
|
||||
image_input["embed_is_patch"],
|
||||
))
|
||||
return scatter_patch_features(
|
||||
image_features,
|
||||
image_input["embed_is_patch"],
|
||||
)
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
@ -790,7 +789,6 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
# NOTE: In v1, inputs_embeds is always generated at model runner, this
|
||||
# condition is for v0 compatibility.
|
||||
elif inputs_embeds is None:
|
||||
kwargs.update({"v0_path": True})
|
||||
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
|
||||
inputs_embeds = self.get_input_embeddings(input_ids,
|
||||
vision_embeddings)
|
||||
|
||||
@ -49,7 +49,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
PromptInsertion, PromptUpdate)
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils import flatten_2d_lists
|
||||
|
||||
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
|
||||
SupportsMultiModal, SupportsPP, SupportsQuant)
|
||||
@ -72,17 +71,17 @@ POOLING_SIZE = 2
|
||||
|
||||
class MolmoImageInputs(TypedDict):
|
||||
images: Union[torch.Tensor, list[torch.Tensor]]
|
||||
"""Shape: `(batch_size, num_crops, num_patch, patch_dim)`"""
|
||||
"""Shape: `(batch_size * num_images, num_crops, num_patch, patch_dim)`"""
|
||||
|
||||
image_masks: Optional[Union[torch.Tensor, list[torch.Tensor]]]
|
||||
"""Shape: `(batch_size, num_crops, num_patch)`"""
|
||||
"""Shape: `(batch_size * num_images, num_crops, num_patch)`"""
|
||||
|
||||
feat_is_patch: Union[torch.Tensor, list[torch.Tensor]]
|
||||
"""
|
||||
A boolean mask indicating which image features correspond
|
||||
to patch tokens.
|
||||
|
||||
Shape: `(batch_size, num_crops, num_patch)`
|
||||
Shape: `(batch_size * num_images, num_crops, num_patch)`
|
||||
"""
|
||||
|
||||
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
|
||||
@ -90,7 +89,7 @@ class MolmoImageInputs(TypedDict):
|
||||
A boolean mask indicating which image embeddings correspond
|
||||
to patch tokens.
|
||||
|
||||
Shape: `(batch_size, num_embeds)`
|
||||
Shape: `(batch_size * num_images, num_embeds)`
|
||||
"""
|
||||
|
||||
num_crops: Union[torch.Tensor, list[torch.Tensor]]
|
||||
@ -696,9 +695,10 @@ class MolmoVisionBackbone(nn.Module, SupportsQuant):
|
||||
return image_features
|
||||
|
||||
def forward(
|
||||
self, images: torch.Tensor, image_masks: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
|
||||
self,
|
||||
images: torch.Tensor,
|
||||
image_masks: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
# image_features: (batch_size, num_crops(=num_image), num_patch, nximage_emb_dim) # noqa: E501
|
||||
batch_size, num_image = images.shape[:2]
|
||||
images = images.to(device=self.device, dtype=self.dtype)
|
||||
@ -1491,6 +1491,8 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
|
||||
f"Got type: {type(img_patch_id)}")
|
||||
self.img_patch_id = img_patch_id.flatten().unique().item()
|
||||
|
||||
embed_is_patch = flatten_bn(embed_is_patch)
|
||||
|
||||
return MolmoImageInputs(
|
||||
images=images,
|
||||
image_masks=image_masks,
|
||||
@ -1502,13 +1504,17 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
|
||||
def _process_image_input(
|
||||
self,
|
||||
image_input: MolmoImageInputs,
|
||||
) -> Union[torch.Tensor, list[torch.Tensor]]:
|
||||
if isinstance(image_input["images"], list):
|
||||
) -> list[torch.Tensor]:
|
||||
images = image_input["images"]
|
||||
image_masks = image_input["image_masks"]
|
||||
feat_is_patch = image_input["feat_is_patch"]
|
||||
num_crops = image_input["num_crops"]
|
||||
|
||||
if isinstance(images, list):
|
||||
# Call the vision backbone on the whole batch at once
|
||||
images_flat = flatten_bn(image_input["images"], concat=True)
|
||||
image_masks_flat = (None if (image_masks :=
|
||||
image_input["image_masks"]) is None
|
||||
else flatten_bn(image_masks, concat=True))
|
||||
images_flat = flatten_bn(images, concat=True)
|
||||
image_masks_flat = (None if image_masks is None else flatten_bn(
|
||||
image_masks, concat=True))
|
||||
|
||||
image_features_flat = self.vision_backbone(
|
||||
images=images_flat.unsqueeze(0),
|
||||
@ -1517,63 +1523,19 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
|
||||
).squeeze(0)
|
||||
|
||||
# Reconstruct the batch dimension
|
||||
image_features = image_features_flat.split(
|
||||
image_input["num_crops"].sum(-1).tolist())
|
||||
num_crops_per_image = [nc.sum().item() for nc in num_crops]
|
||||
image_features = image_features_flat.split(num_crops_per_image)
|
||||
else:
|
||||
image_features = self.vision_backbone(
|
||||
images=image_input["images"],
|
||||
image_masks=image_input["image_masks"],
|
||||
images=images,
|
||||
image_masks=image_masks,
|
||||
)
|
||||
|
||||
return image_features
|
||||
|
||||
def _get_mm_embeds(
|
||||
self,
|
||||
features: torch.Tensor, # Shape: (num_crop, num_patch, d)
|
||||
feat_is_patch: torch.Tensor, # Shape: (num_crop, num_patch)
|
||||
num_crops: torch.Tensor, # Shape: (num_images,)
|
||||
embed_is_patch: torch.Tensor, # Shape: (num_embeds,)
|
||||
) -> tuple[torch.Tensor, ...]:
|
||||
"""
|
||||
Scatter the patch features into a contiguous tensor that corresponds
|
||||
to the embedding tokens defined by the multimodal processor.
|
||||
|
||||
Note:
|
||||
The original code only considers patch tokens as feature
|
||||
tokens, but our processor considers all image-related tokens
|
||||
as feature tokens because the feature tokens need to be
|
||||
consecutive in `input_ids`.
|
||||
|
||||
Example:
|
||||
A simplified example for one item in the batch:
|
||||
|
||||
.. code-block::
|
||||
|
||||
Embedding tokens (from HF processor):
|
||||
[<start> <patch> <patch> <col> <patch> <patch> <col> <end> ]
|
||||
|
||||
embed_is_patch (from HF processor):
|
||||
[ False True True False True True False False ]
|
||||
|
||||
Encoder outputs (from model):
|
||||
[ p1 p2 0 p3 p4 0 ]
|
||||
|
||||
feat_is_patch (from HF processor):
|
||||
[ True True False True True False ]
|
||||
|
||||
The resulting embedding tensor is:
|
||||
[ nan p1 p2 nan p3 p4 nan nan ]
|
||||
"""
|
||||
num_crops_per_image = num_crops.tolist()
|
||||
feats_per_image = features.split(num_crops_per_image)
|
||||
f_is_patch_per_image = feat_is_patch.split(num_crops_per_image)
|
||||
|
||||
features = torch.cat([
|
||||
# Only the features corresponding to patch tokens are relevant
|
||||
return [
|
||||
feats[f_is_patch]
|
||||
for feats, f_is_patch in zip(feats_per_image, f_is_patch_per_image)
|
||||
])
|
||||
|
||||
return scatter_patch_features(features, embed_is_patch)
|
||||
for feats, f_is_patch in zip(image_features, feat_is_patch)
|
||||
]
|
||||
|
||||
def get_multimodal_embeddings(
|
||||
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
|
||||
@ -1583,13 +1545,10 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
|
||||
|
||||
image_features = self._process_image_input(image_input)
|
||||
|
||||
return flatten_2d_lists(
|
||||
self._get_mm_embeds(*args) for args in zip(
|
||||
image_features,
|
||||
image_input["feat_is_patch"],
|
||||
image_input["num_crops"],
|
||||
image_input["embed_is_patch"],
|
||||
))
|
||||
return scatter_patch_features(
|
||||
image_features,
|
||||
image_input["embed_is_patch"],
|
||||
)
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
|
||||
@ -42,7 +42,6 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.transformers_utils.tokenizer import (MistralTokenizer,
|
||||
cached_tokenizer_from_config)
|
||||
from vllm.utils import flatten_2d_lists
|
||||
|
||||
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
||||
from .utils import (flatten_bn, init_vllm_registered_model, maybe_prefix,
|
||||
@ -74,7 +73,7 @@ class PixtralImagePixelInputs(TypedDict):
|
||||
A boolean mask indicating which image embeddings correspond
|
||||
to patch tokens.
|
||||
|
||||
Shape: `(batch_size, num_images, num_embeds)`
|
||||
Shape: `(batch_size * num_images, num_embeds)`
|
||||
"""
|
||||
|
||||
|
||||
@ -387,6 +386,8 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
raise ValueError("Incorrect type of embed_is_patch. "
|
||||
f"Got type: {type(embed_is_patch)}")
|
||||
|
||||
embed_is_patch = flatten_bn(embed_is_patch)
|
||||
|
||||
return PixtralImagePixelInputs(
|
||||
type="pixel_values",
|
||||
images=flatten_bn(images),
|
||||
@ -428,14 +429,10 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
|
||||
image_features = self._process_image_input(image_input)
|
||||
|
||||
if kwargs.get("v0_path", False):
|
||||
return image_features
|
||||
|
||||
return flatten_2d_lists(
|
||||
scatter_patch_features(*args) for args in zip(
|
||||
image_features,
|
||||
image_input["embed_is_patch"],
|
||||
))
|
||||
return scatter_patch_features(
|
||||
image_features,
|
||||
image_input["embed_is_patch"],
|
||||
)
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
@ -467,7 +464,6 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
# NOTE: In v1, inputs_embeds is always generated at model runner, this
|
||||
# condition is for v0 compatibility.
|
||||
elif inputs_embeds is None:
|
||||
kwargs.update({"v0_path": True})
|
||||
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
|
||||
inputs_embeds = self.get_input_embeddings(input_ids,
|
||||
vision_embeddings)
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Sequence
|
||||
from typing import Final, Generic, Optional, Protocol, TypeVar, Union, cast
|
||||
|
||||
import torch
|
||||
@ -154,8 +155,8 @@ def resolve_visual_encoder_outputs(
|
||||
|
||||
|
||||
def scatter_patch_features(
|
||||
features: torch.Tensor,
|
||||
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]],
|
||||
patches: Union[torch.Tensor, Sequence[torch.Tensor]],
|
||||
embed_is_patch: Union[torch.Tensor, Sequence[torch.Tensor]],
|
||||
) -> tuple[torch.Tensor, ...]:
|
||||
"""
|
||||
Scatter the patch features into a contiguous tensor that corresponds
|
||||
@ -165,8 +166,8 @@ def scatter_patch_features(
|
||||
can be filtered out by :func`select_patch_features`.
|
||||
|
||||
Args:
|
||||
features: The patch features, concatenated across each image.
|
||||
Shape: `(num_patch, feature_depth)`
|
||||
patches: The patch features for each image.
|
||||
Shape: `(num_images, <patch_dims>, feature_depth)`
|
||||
embed_is_patch: A boolean mask indicating which image embeddings
|
||||
correspond to patch tokens for each image.
|
||||
Shape: `(num_images, num_embeds)`
|
||||
@ -194,21 +195,21 @@ def scatter_patch_features(
|
||||
The resulting embedding tensor is:
|
||||
[ nan p1 p2 nan p3 p4 nan nan ]
|
||||
"""
|
||||
num_embeds_per_image = [
|
||||
e_is_patch.numel() for e_is_patch in embed_is_patch
|
||||
]
|
||||
if isinstance(embed_is_patch, torch.Tensor):
|
||||
embed_is_patch_flat = embed_is_patch.view(-1)
|
||||
else:
|
||||
embed_is_patch_flat = torch.cat(embed_is_patch)
|
||||
if len(patches) != len(embed_is_patch):
|
||||
raise ValueError(f"Inconsistent num_images: {len(patches)=} vs. "
|
||||
f"{len(embed_is_patch)=}")
|
||||
|
||||
embeds_flat = features.new_full(
|
||||
(sum(num_embeds_per_image), features.shape[-1]),
|
||||
fill_value=torch.nan,
|
||||
)
|
||||
embeds_flat[embed_is_patch_flat] = features.flatten(0, -2)
|
||||
def get_embed_one(patches_one: torch.Tensor, e_is_patch: torch.Tensor):
|
||||
embed_one = patches_one.new_full(
|
||||
(e_is_patch.shape[0], patches_one.shape[-1]),
|
||||
fill_value=torch.nan,
|
||||
)
|
||||
embed_one[e_is_patch] = patches_one.flatten(0, -2)
|
||||
return embed_one
|
||||
|
||||
return embeds_flat.split(num_embeds_per_image)
|
||||
return tuple(
|
||||
get_embed_one(patches_one, e_is_patch)
|
||||
for patches_one, e_is_patch in zip(patches, embed_is_patch))
|
||||
|
||||
|
||||
def select_patch_features(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user