diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py
index bd7ef29e1f63..a1004cd0ac60 100644
--- a/vllm/model_executor/models/fuyu.py
+++ b/vllm/model_executor/models/fuyu.py
@@ -18,7 +18,7 @@
 """ PyTorch Fuyu model."""
 import math
 from collections.abc import Iterable, Mapping, Sequence
-from typing import Literal, Optional, Set, Tuple, TypedDict
+from typing import Literal, Optional, Set, Tuple, TypedDict, Union
 
 import torch
 import torch.nn as nn
@@ -39,10 +39,12 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         PromptUpdate, PromptUpdateDetails)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
 from vllm.sequence import IntermediateTensors
+from vllm.utils import flatten_2d_lists
 
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
 from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
                     merge_multimodal_embeddings)
+from .vision import scatter_patch_features, select_patch_features
 
 # Cannot find the following 2 numbers from hf config.
 _IMAGE_TOKEN_ID = 71011
@@ -64,6 +66,11 @@ class FuyuImagePatchInputs(TypedDict):
     This is used to split the embeddings which has the first two dimensions
     flattened just like `flat_data`.
     """
+    embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
+    """
+    A boolean mask indicating which image embeddings correspond
+    to patch tokens.
+    """
 
 
 class FuyuProcessingInfo(BaseProcessingInfo):
@@ -183,6 +190,19 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
 
             processed_outputs["image_patches"] = image_patches[0]
 
+        # Get the patch grid size for each image.
+        embed_is_patch = []
+        for image in images:
+            ncols, nrows = self.info.get_image_feature_grid_size(
+                image_width=image.width,
+                image_height=image.height,
+            )
+
+            mask = torch.tensor(([True] * ncols + [False]) * nrows)
+            embed_is_patch.append(mask)
+
+        processed_outputs["embed_is_patch"] = embed_is_patch
+
         return processed_outputs
 
     def _apply_hf_processor_tokens_only(
@@ -202,7 +222,8 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
         hf_inputs: BatchFeature,
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, MultiModalFieldConfig]:
-        return dict(image_patches=MultiModalFieldConfig.batched("image"))
+        return dict(image_patches=MultiModalFieldConfig.batched("image"),
+                    embed_is_patch=MultiModalFieldConfig.batched("image"))
 
     def _get_prompt_updates(
         self,
@@ -301,11 +322,15 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[FuyuImagePatchInputs]:
         image_patches = kwargs.pop("image_patches", None)
+        embed_is_patch = kwargs.pop("embed_is_patch", None)
         if image_patches is not None:
             if not isinstance(image_patches, (torch.Tensor, list)):
                 raise ValueError("Incorrect type of image patches. "
                                  f"Got type: {type(image_patches)}")
 
+            if not isinstance(embed_is_patch, (torch.Tensor, list)):
+                raise ValueError("Incorrect type of embed_is_patch. "
+                                 f"Got type: {type(embed_is_patch)}")
+
             image_patches_flat = flatten_bn(image_patches)
 
             return FuyuImagePatchInputs(
@@ -313,6 +338,7 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
                 flat_data=self._validate_pixel_values(
                     flatten_bn(image_patches_flat, concat=True)),
                 patches_per_image=[x.size(0) for x in image_patches_flat],
+                embed_is_patch=embed_is_patch,
             )
 
         return None
@@ -333,7 +359,11 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
         if image_input is None:
             return None
         vision_embeddings = self._process_image_input(image_input)
-        return vision_embeddings
+        return flatten_2d_lists(
+            scatter_patch_features(*args) for args in zip(
+                vision_embeddings,
+                image_input["embed_is_patch"],
+            ))
 
     def get_input_embeddings(
         self,
@@ -343,8 +373,8 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
         inputs_embeds = self.language_model.get_input_embeddings(input_ids)
         if multimodal_embeddings is not None:
             inputs_embeds = merge_multimodal_embeddings(
-                input_ids, inputs_embeds, multimodal_embeddings,
-                _IMAGE_TOKEN_ID)
+                input_ids, inputs_embeds,
+                select_patch_features(multimodal_embeddings), _IMAGE_TOKEN_ID)
        return inputs_embeds
 
     def forward(