mirror of https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 20:35:01 +08:00

commit bb103b29bf
parent 248e76c4df

[Bugfix] Added embed_is_patch mask for fuyu model (#15731)

Signed-off-by: Kyle Huang <kylhuang@nvidia.com>

Changed file: vllm/model_executor/models/fuyu.py
@@ -18,7 +18,7 @@
 """ PyTorch Fuyu model."""
 import math
 from collections.abc import Iterable, Mapping, Sequence
-from typing import Literal, Optional, Set, Tuple, TypedDict
+from typing import Literal, Optional, Set, Tuple, TypedDict, Union

 import torch
 import torch.nn as nn
@@ -39,10 +39,12 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         PromptUpdate, PromptUpdateDetails)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
 from vllm.sequence import IntermediateTensors
+from vllm.utils import flatten_2d_lists

 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
 from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
                     merge_multimodal_embeddings)
+from .vision import scatter_patch_features, select_patch_features

 # Cannot find the following 2 numbers from hf config.
 _IMAGE_TOKEN_ID = 71011
@@ -64,6 +66,11 @@ class FuyuImagePatchInputs(TypedDict):
     This is used to split the embeddings which has the first two dimensions
     flattened just like `flat_data`.
     """
+    embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
+    """
+    A boolean mask indicating which image embeddings correspond
+    to patch tokens.
+    """


 class FuyuProcessingInfo(BaseProcessingInfo):
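
Note: per the mask construction later in this diff, each grid row contributes ncols patch embeddings followed by one non-patch (image-newline) embedding, so the mask holds ncols*nrows True entries and nrows False entries per image. A minimal illustration (hypothetical 3x2 grid, not from the commit):

    import torch

    # Hypothetical 3-column x 2-row patch grid: True marks patch
    # embeddings, False marks the per-row newline embedding.
    ncols, nrows = 3, 2
    mask = torch.tensor(([True] * ncols + [False]) * nrows)
    print(mask)
    # tensor([ True,  True,  True, False,  True,  True,  True, False])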
@@ -183,6 +190,19 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):

             processed_outputs["image_patches"] = image_patches[0]

+            # get patch grid size for each image
+            embed_is_patch = []
+            for image in images:
+                ncols, nrows = self.info.get_image_feature_grid_size(
+                    image_width=image.width,
+                    image_height=image.height,
+                )
+
+                mask = torch.tensor(([True] * ncols + [False]) * nrows)
+                embed_is_patch.append(mask)
+
+            processed_outputs["embed_is_patch"] = embed_is_patch
+
         return processed_outputs

     def _apply_hf_processor_tokens_only(
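
Note: a self-contained sketch of the grid-size logic this hunk relies on. It assumes Fuyu's 30x30 patch size, its 1920x1080 target canvas, and downscale-to-fit resizing (taken from the HF Fuyu image processor, not from this diff; get_image_feature_grid_size itself is defined elsewhere in fuyu.py):

    import math
    import torch

    def embed_is_patch_for(image_width: int, image_height: int) -> torch.Tensor:
        # Assumed behavior: downscale (never upscale) to fit the canvas.
        scale = min(1920 / image_width, 1080 / image_height, 1.0)
        width = int(image_width * scale)
        height = int(image_height * scale)
        ncols = math.ceil(width / 30)   # patches per row
        nrows = math.ceil(height / 30)  # rows of patches
        # One trailing False per row for the image-newline embedding.
        return torch.tensor(([True] * ncols + [False]) * nrows)

    mask = embed_is_patch_for(640, 480)  # 22x16 grid -> (22+1)*16 entries
    print(mask.numel())                  # 368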
@@ -202,7 +222,8 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
         hf_inputs: BatchFeature,
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, MultiModalFieldConfig]:
-        return dict(image_patches=MultiModalFieldConfig.batched("image"))
+        return dict(image_patches=MultiModalFieldConfig.batched("image"),
+                    embed_is_patch=MultiModalFieldConfig.batched("image"))

     def _get_prompt_updates(
         self,
@@ -301,11 +322,15 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[FuyuImagePatchInputs]:
         image_patches = kwargs.pop("image_patches", None)
+        embed_is_patch = kwargs.pop("embed_is_patch", None)
         if image_patches is not None:
             if not isinstance(image_patches, (torch.Tensor, list)):
                 raise ValueError("Incorrect type of image patches. "
                                  f"Got type: {type(image_patches)}")

+            if not isinstance(embed_is_patch, (torch.Tensor, list)):
+                raise ValueError("Incorrect type of embed_is_patch. "
+                                 f"Got type: {type(embed_is_patch)}")
             image_patches_flat = flatten_bn(image_patches)

             return FuyuImagePatchInputs(
@@ -313,6 +338,7 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
                 flat_data=self._validate_pixel_values(
                     flatten_bn(image_patches_flat, concat=True)),
                 patches_per_image=[x.size(0) for x in image_patches_flat],
+                embed_is_patch=embed_is_patch,
             )

         return None
@@ -333,7 +359,12 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
         if image_input is None:
             return None
         vision_embeddings = self._process_image_input(image_input)
-        return vision_embeddings
+        #return vision_embeddings
+        return flatten_2d_lists(
+            scatter_patch_features(*args) for args in zip(
+                vision_embeddings,
+                image_input["embed_is_patch"],
+            ))

     def get_input_embeddings(
         self,
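
Note: scatter_patch_features (from .vision) and flatten_2d_lists (from vllm.utils) are existing vLLM helpers whose internals are not shown in this diff. A rough mock of the intended semantics, not vLLM's implementation: each image's patch embeddings are scattered into the full-length embedding sequence at the True positions of its mask, leaving filler in the non-patch slots for select_patch_features() to discard later.

    import torch

    def scatter_patch_features_mock(patches: torch.Tensor,
                                    is_patch: torch.Tensor) -> torch.Tensor:
        # patches: (num_patches, hidden); is_patch: (num_embeds,) bool
        # with exactly num_patches True entries.
        out = patches.new_full((is_patch.numel(), patches.shape[-1]),
                               float("nan"))  # NaN filler in non-patch slots
        out[is_patch] = patches
        return out

    feats = scatter_patch_features_mock(
        torch.randn(6, 4096),
        torch.tensor(([True] * 3 + [False]) * 2))
    print(feats.shape)  # torch.Size([8, 4096])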
@@ -343,8 +374,8 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
         inputs_embeds = self.language_model.get_input_embeddings(input_ids)
         if multimodal_embeddings is not None:
             inputs_embeds = merge_multimodal_embeddings(
-                input_ids, inputs_embeds, multimodal_embeddings,
-                _IMAGE_TOKEN_ID)
+                input_ids, inputs_embeds,
+                select_patch_features(multimodal_embeddings), _IMAGE_TOKEN_ID)
         return inputs_embeds

     def forward(
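
Note: the change above merges only the patch-flagged rows: select_patch_features(multimodal_embeddings) drops the entries masked False in embed_is_patch, so the row count matches the number of _IMAGE_TOKEN_ID placeholders in the prompt. A minimal sketch of the merge semantics (illustrative, not vLLM's actual merge_multimodal_embeddings):

    import torch

    def merge_mock(input_ids: torch.Tensor, inputs_embeds: torch.Tensor,
                   mm_embeds: torch.Tensor,
                   image_token_id: int) -> torch.Tensor:
        # Overwrite embedding rows at image-token positions, in order.
        mask = input_ids == image_token_id
        assert int(mask.sum()) == mm_embeds.shape[0], \
            "placeholder/embedding count mismatch"
        inputs_embeds = inputs_embeds.clone()
        inputs_embeds[mask] = mm_embeds
        return inputs_embeds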