[Bugfix] Added embed_is_patch mask for fuyu model (#15731)

Signed-off-by: Kyle Huang <kylhuang@nvidia.com>
This commit is contained in:
kYLe 2025-03-30 05:45:08 -05:00 committed by GitHub
parent 248e76c4df
commit bb103b29bf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -18,7 +18,7 @@
""" PyTorch Fuyu model.""" """ PyTorch Fuyu model."""
import math import math
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from typing import Literal, Optional, Set, Tuple, TypedDict from typing import Literal, Optional, Set, Tuple, TypedDict, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -39,10 +39,12 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
PromptUpdate, PromptUpdateDetails) PromptUpdate, PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import flatten_2d_lists
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix, from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
merge_multimodal_embeddings) merge_multimodal_embeddings)
from .vision import scatter_patch_features, select_patch_features
# Cannot find the following 2 numbers from hf config. # Cannot find the following 2 numbers from hf config.
_IMAGE_TOKEN_ID = 71011 _IMAGE_TOKEN_ID = 71011
@ -64,6 +66,11 @@ class FuyuImagePatchInputs(TypedDict):
This is used to split the embeddings which has the first two dimensions This is used to split the embeddings which has the first two dimensions
flattened just like `flat_data`. flattened just like `flat_data`.
""" """
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
"""
A boolean mask (one per image) over that image's embeddings:
`True` marks embeddings that correspond to patch tokens and
`False` marks those that do not.
"""
class FuyuProcessingInfo(BaseProcessingInfo): class FuyuProcessingInfo(BaseProcessingInfo):
@ -183,6 +190,19 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
processed_outputs["image_patches"] = image_patches[0] processed_outputs["image_patches"] = image_patches[0]
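# Build one boolean mask per image from its patch grid size: each of the
# nrows rows contributes ncols True entries followed by one False entry.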
embed_is_patch = []
for image in images:
ncols, nrows = self.info.get_image_feature_grid_size(
image_width=image.width,
image_height=image.height,
)
mask = torch.tensor(([True] * ncols + [False]) * nrows)
embed_is_patch.append(mask)
processed_outputs["embed_is_patch"] = embed_is_patch
return processed_outputs return processed_outputs
def _apply_hf_processor_tokens_only( def _apply_hf_processor_tokens_only(
@ -202,7 +222,8 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
hf_inputs: BatchFeature, hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]: ) -> Mapping[str, MultiModalFieldConfig]:
return dict(image_patches=MultiModalFieldConfig.batched("image")) return dict(image_patches=MultiModalFieldConfig.batched("image"),
embed_is_patch=MultiModalFieldConfig.batched("image"))
def _get_prompt_updates( def _get_prompt_updates(
self, self,
@ -301,11 +322,15 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
def _parse_and_validate_image_input( def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[FuyuImagePatchInputs]: self, **kwargs: object) -> Optional[FuyuImagePatchInputs]:
image_patches = kwargs.pop("image_patches", None) image_patches = kwargs.pop("image_patches", None)
embed_is_patch = kwargs.pop("embed_is_patch", None)
if image_patches is not None: if image_patches is not None:
if not isinstance(image_patches, (torch.Tensor, list)): if not isinstance(image_patches, (torch.Tensor, list)):
raise ValueError("Incorrect type of image patches. " raise ValueError("Incorrect type of image patches. "
f"Got type: {type(image_patches)}") f"Got type: {type(image_patches)}")
if not isinstance(embed_is_patch, (torch.Tensor, list)):
raise ValueError("Incorrect type of embed_is_patch. "
f"Got type: {type(embed_is_patch)}")
image_patches_flat = flatten_bn(image_patches) image_patches_flat = flatten_bn(image_patches)
return FuyuImagePatchInputs( return FuyuImagePatchInputs(
@ -313,6 +338,7 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
flat_data=self._validate_pixel_values( flat_data=self._validate_pixel_values(
flatten_bn(image_patches_flat, concat=True)), flatten_bn(image_patches_flat, concat=True)),
patches_per_image=[x.size(0) for x in image_patches_flat], patches_per_image=[x.size(0) for x in image_patches_flat],
embed_is_patch=embed_is_patch,
) )
return None return None
@ -333,7 +359,12 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
if image_input is None: if image_input is None:
return None return None
vision_embeddings = self._process_image_input(image_input) vision_embeddings = self._process_image_input(image_input)
return vision_embeddings #return vision_embeddings
return flatten_2d_lists(
scatter_patch_features(*args) for args in zip(
vision_embeddings,
image_input["embed_is_patch"],
))
def get_input_embeddings( def get_input_embeddings(
self, self,
@ -343,8 +374,8 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
inputs_embeds = self.language_model.get_input_embeddings(input_ids) inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None: if multimodal_embeddings is not None:
inputs_embeds = merge_multimodal_embeddings( inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings, input_ids, inputs_embeds,
_IMAGE_TOKEN_ID) select_patch_features(multimodal_embeddings), _IMAGE_TOKEN_ID)
return inputs_embeds return inputs_embeds
def forward( def forward(