[Bugfix] embed_is_patch for Idefics3 (#15696)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Cyrus Leung 2025-03-28 23:27:52 +08:00 committed by GitHub
parent 3b00ff9138
commit 541d1df486
5 changed files with 320 additions and 188 deletions

vllm/model_executor/models/commandr.py

@@ -24,7 +24,6 @@
 from typing import Iterable, Optional, Set, Tuple, Union
 
 import torch
-import torch.utils.checkpoint
 from torch import nn
 from transformers import CohereConfig

vllm/model_executor/models/idefics3.py

@@ -17,16 +17,14 @@
 import math
 from collections.abc import Iterable, Mapping, Sequence
-from typing import Dict, List, Literal, Optional, Set, Tuple, TypedDict, Union
+from typing import Dict, Literal, Optional, Set, Tuple, TypedDict, Union
 
 import torch
-import torch.utils.checkpoint
 from torch import nn
 from transformers import (BatchFeature, Idefics3Config, Idefics3ImageProcessor,
                           Idefics3Processor)
 
 from vllm.config import VllmConfig
-from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import ReplicatedLinear
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -35,13 +33,16 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
-from vllm.multimodal.inputs import NestedTensors
-from vllm.multimodal.parse import ImageProcessorItems
+from vllm.multimodal.parse import ImageProcessorItems, ImageSize
+# yapf conflicts with isort for this block
+# yapf: disable
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         BaseProcessingInfo,
                                         MultiModalDataItems,
                                         MultiModalFieldConfig,
-                                        PromptReplacement, PromptUpdate)
+                                        PromptReplacement, PromptUpdate,
+                                        encode_tokens)
+# yapf: enable
 from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
 from vllm.sequence import IntermediateTensors
@@ -53,18 +54,28 @@ from .interfaces import MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal
 from .llama import LlamaModel
 from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
                     merge_multimodal_embeddings)
-
-logger = init_logger(__name__)
+from .vision import scatter_patch_features, select_patch_features
 
 
 class Idefics3ImagePixelInputs(TypedDict):
     type: Literal["pixel_values"]
-    data: torch.Tensor
+    pixel_values: torch.Tensor
     """
     Shape: `(batch_size * num_images * num_patches,
              num_channels, height, width)`
     """
-    pixel_attention_mask: Optional[torch.BoolTensor]
+    pixel_attention_mask: torch.Tensor
+
+    num_patches: torch.Tensor
+    """Shape: `(batch_size * num_images)`"""
+
+    embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
+    """
+    A boolean mask indicating which image embeddings correspond
+    to patch tokens.
+
+    Shape: `(batch_size * num_images, num_embeds)`
+    """
 
 
 class Idefics3ImageEmbeddingInputs(TypedDict):
@@ -75,6 +86,14 @@ class Idefics3ImageEmbeddingInputs(TypedDict):
     `hidden_size` must match the hidden size of language model backbone.
     """
 
+    embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
+    """
+    A boolean mask indicating which image embeddings correspond
+    to patch tokens.
+
+    Shape: `(batch_size * num_images, num_embeds)`
+    """
+
 
 ImageInputs = Union[Idefics3ImagePixelInputs, Idefics3ImageEmbeddingInputs]
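
> Note on the new `embed_is_patch` field: the placeholder expansion for one image mixes structural tokens (`<fake_token_around_image>`, `<row_i_col_j>`, line breaks, `<global-img>`) with the `<image>` patch tokens that actually receive vision embeddings. The mask records which positions are patch tokens. A minimal sketch with made-up token ids:

```python
import torch

# Hypothetical token ids: 7 stands for <image>; the rest stand for
# structural tokens such as <fake_token_around_image> and <row_1_col_1>.
image_token_id = 7
repl_tokens = torch.tensor([3, 5, 7, 7, 9, 4, 7, 7])  # one image's expansion

embed_is_patch = repl_tokens == image_token_id
print(embed_is_patch.tolist())
# [False, False, True, True, False, False, True, True]
print(int(embed_is_patch.sum()))  # 4 vision embeddings get scattered here
```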
@@ -100,32 +119,14 @@ class Idefics3ProcessingInfo(BaseProcessingInfo):
         seq_len: int,
         mm_counts: Mapping[str, int],
     ) -> Mapping[str, int]:
-        hf_processor = self.get_hf_processor()
-        image_processor: Idefics3ImageProcessor = hf_processor.image_processor
-        grid_w, grid_h = self._get_image_feature_grid_size(
-            image_width=image_processor.size['longest_edge'],
-            image_height=image_processor.size['longest_edge'],
-        )
-        num_image_token = (grid_w * grid_h + 1) * hf_processor.image_seq_len
-
-        # Calculate Non-image-token length
-        # NOTE: <row_1_col_1> and <global-img> are special token for SmolVLM
-        # but not for Idefic3, so we need to tokenize them to get actual length.
-        tokenizer = self.get_tokenizer()
-        tile_token_len = len(tokenizer.tokenize("<row_1_col_1>"))
-        glob_token_len = len(tokenizer.tokenize(hf_processor.global_image_tag))
-
-        # linebreak and <fake_token_around_image> always cost 1 token
-        fake_token_len = lb_len = 1
-        non_image_token = (grid_w * grid_h) * (
-            tile_token_len + fake_token_len) + glob_token_len + (
-                grid_h + 1) * lb_len + fake_token_len
-        return {"image": num_image_token + non_image_token}
+        return {"image": self.get_max_image_tokens()}
 
     def _resize_output_size(self,
                             *,
                             height: int,
                             width: int,
                             max_len: Optional[int] = None,
-                            min_len: Optional[int] = 1,
+                            min_len: int = 1,
                             max_size: Optional[int] = None) -> tuple[int, int]:
         # Set default value for max_len if not provided
         max_len = max(height, width) if max_len is None else max_len
@@ -181,10 +182,13 @@ class Idefics3ProcessingInfo(BaseProcessingInfo):
         *,
         image_width: int,
         image_height: int,
-        size: Optional[dict[str, object]] = None,
+        processor: Optional[Idefics3Processor],
     ) -> tuple[int, int]:
-        hf_processor = self.get_hf_processor(size=size)
-        image_processor: Idefics3ImageProcessor = hf_processor.image_processor
+        if processor is None:
+            processor = self.get_hf_processor()
+
+        image_processor: Idefics3ImageProcessor = processor.image_processor
         max_image_size = image_processor.max_image_size['longest_edge']
         size = image_processor.size['longest_edge']
         assert size % max_image_size == 0, (
@@ -204,6 +208,105 @@ class Idefics3ProcessingInfo(BaseProcessingInfo):
             grid_h = grid_w = 0
         return grid_w, grid_h
 
+    def get_num_patches(
+        self,
+        *,
+        image_width: int,
+        image_height: int,
+        processor: Optional[Idefics3Processor],
+    ) -> int:
+        grid_w, grid_h = self._get_image_feature_grid_size(
+            image_width=image_width,
+            image_height=image_height,
+            processor=processor,
+        )
+
+        return grid_w * grid_h + 1
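
> For intuition, `get_num_patches` mirrors the grid computation above: `grid_w * grid_h` tiles plus one global image. Assuming the stock Idefics3 processor settings (`size["longest_edge"] == 1456`, `max_image_size["longest_edge"] == 364`; these values are assumptions about the checkpoint config, not shown in this diff), the arithmetic works out as:

```python
import math

max_image_size = 364                   # assumed tile edge
resized_height = resized_width = 1456  # a large image after resizing

grid_w = math.ceil(resized_width / max_image_size)   # 4
grid_h = math.ceil(resized_height / max_image_size)  # 4
num_patches = grid_w * grid_h + 1                    # 16 tiles + 1 global = 17

# An image small enough to skip tiling has grid 0x0, leaving only
# the global image: 0 * 0 + 1 == 1 patch.
```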
+    def get_image_repl(
+        self,
+        *,
+        image_width: int,
+        image_height: int,
+        processor: Optional[Idefics3Processor],
+    ) -> str:
+        if processor is None:
+            processor = self.get_hf_processor()
+
+        image_token = processor.image_token.content
+        fake_image_token = processor.fake_image_token.content
+        global_img_token = processor.global_image_tag
+        image_seq_len = processor.image_seq_len
+        grid_placeholder = "<row_{n_h}_col_{n_w}>"
+
+        p_img = image_token * image_seq_len
+        global_img_placeholder = fake_image_token + global_img_token + p_img
+        tile_img_placeholder = fake_image_token + grid_placeholder + p_img
+
+        grid_w, grid_h = self._get_image_feature_grid_size(
+            image_width=image_width,
+            image_height=image_height,
+            processor=processor,
+        )
+        if grid_w == 0 and grid_h == 0:
+            return global_img_placeholder + fake_image_token
+
+        tiles_placeholder = list[str]()
+        for i in range(grid_h):
+            for j in range(grid_w):
+                placeholder_per_tile = tile_img_placeholder.format(n_h=i + 1,
+                                                                   n_w=j + 1)
+                tiles_placeholder.append(placeholder_per_tile)
+
+                # Add line break if it is the last tile in the row
+                if j == grid_w - 1:
+                    tiles_placeholder.append("\n")
+
+        return "".join([
+            *tiles_placeholder,
+            "\n",
+            global_img_placeholder,
+            fake_image_token,
+        ])
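
> The string built here is what the tokenizer later sees. For illustration, with `grid_w = grid_h = 2` and a deliberately tiny `image_seq_len = 2` (the real value is model-specific), the layout is:

```python
fake = "<fake_token_around_image>"
p_img = "<image>" * 2  # image_seq_len patch tokens per tile

# Row-major tiles, a newline after each row, then a blank line,
# the global image, and a closing fake token:
out = (fake + "<row_1_col_1>" + p_img + fake + "<row_1_col_2>" + p_img + "\n" +
       fake + "<row_2_col_1>" + p_img + fake + "<row_2_col_2>" + p_img + "\n" +
       "\n" + fake + "<global-img>" + p_img + fake)
```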
+    def get_num_image_tokens(
+        self,
+        *,
+        image_width: int,
+        image_height: int,
+        processor: Optional[Idefics3Processor],
+    ) -> int:
+        tokenizer = self.get_tokenizer()
+        image_repl = self.get_image_repl(
+            image_width=image_width,
+            image_height=image_height,
+            processor=processor,
+        )
+
+        image_repl_tokens = encode_tokens(
+            tokenizer,
+            image_repl,
+            add_special_tokens=False,
+        )
+
+        return len(image_repl_tokens)
+
+    def get_image_size_with_most_features(self) -> ImageSize:
+        processor = self.get_hf_processor()
+        image_processor: Idefics3ImageProcessor = processor.image_processor
+
+        return ImageSize(
+            width=image_processor.size["longest_edge"],
+            height=image_processor.size["longest_edge"],
+        )
+
+    def get_max_image_tokens(self) -> int:
+        target_width, target_height = self.get_image_size_with_most_features()
+
+        return self.get_num_image_tokens(
+            image_width=target_width,
+            image_height=target_height,
+            processor=None,
+        )
 
 
 class Idefics3DummyInputsBuilder(BaseDummyInputsBuilder[Idefics3ProcessingInfo]
                                  ):
@@ -217,7 +320,7 @@ class Idefics3DummyInputsBuilder(BaseDummyInputsBuilder[Idefics3ProcessingInfo]
         hf_processor = self.info.get_hf_processor()
         image_processor: Idefics3ImageProcessor = hf_processor.image_processor
         longest_edge = image_processor.max_image_size['longest_edge']
-        image_token: str = hf_processor.image_token.content
+        image_token = hf_processor.image_token.content
 
         mm_data = {
             "image":
@@ -241,26 +344,61 @@ class Idefics3MultiModalProcessor(
         mm_data: Mapping[str, object],
         mm_kwargs: Mapping[str, object],
     ) -> BatchFeature:
-        if mm_data:
-            processed_outputs = super()._call_hf_processor(
-                prompt, mm_data, mm_kwargs)
-            image_grids = [
-                self.info._get_image_feature_grid_size(
-                    image_width=img.width,
-                    image_height=img.height,
-                    **mm_kwargs,
-                ) for img in mm_data["images"]
-            ]
-            image_patches = list(map(lambda x: math.prod(x) + 1, image_grids))
-            for key in ("pixel_values", "pixel_attention_mask"):
-                data = processed_outputs.pop(key)
-                data = data.flatten(0, 1).split(image_patches)
-                processed_outputs[key] = data
-        else:
-            tokenizer = self.info.get_tokenizer()
-            processed_outputs = tokenizer(prompt,
-                                          add_special_tokens=True,
-                                          return_tensors="pt")
+        # Text-only input not supported in composite processor
+        if not (images := mm_data.get("images", [])):
+            prompt_ids = self.info.get_tokenizer().encode(prompt)
+            prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
+            return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")
+
+        processed_outputs = super()._call_hf_processor(
+            prompt,
+            mm_data,
+            mm_kwargs,
+        )
+
+        parsed_images = (self._get_data_parser().parse_mm_data({
+            "image": images
+        }).get_items("image", ImageProcessorItems))
+        image_sizes = [
+            parsed_images.get_image_size(i) for i in range(len(parsed_images))
+        ]
+        hf_processor = self.info.get_hf_processor(**mm_kwargs)
+
+        image_repl_features = [
+            self.info.get_image_repl(image_width=size.width,
+                                     image_height=size.height,
+                                     processor=hf_processor)
+            for size in image_sizes
+        ]
+
+        tokenizer = self.info.get_tokenizer()
+        image_repls_feature_tokens = [
+            tokenizer.encode(image_repl, add_special_tokens=False)
+            for image_repl in image_repl_features
+        ]
+
+        vocab = tokenizer.get_vocab()
+        image_token_id = vocab[hf_processor.image_token.content]
+
+        embed_is_patch = [
+            torch.tensor(image_repl_tokens) == image_token_id
+            for image_repl_tokens in image_repls_feature_tokens
+        ]
+        processed_outputs["embed_is_patch"] = embed_is_patch
+
+        num_patches = [
+            self.info.get_num_patches(
+                image_width=size.width,
+                image_height=size.height,
+                processor=hf_processor,
+            ) for size in image_sizes
+        ]
+        processed_outputs["num_patches"] = torch.tensor(num_patches)
+
+        # Remove the extra batch dimension
+        processed_outputs["pixel_values"].squeeze_(0)
+        processed_outputs["pixel_attention_mask"].squeeze_(0)
 
         return processed_outputs
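
> The trailing `squeeze_` calls are shape bookkeeping: with a single prompt, the HF processor is assumed to return `pixel_values` with a leading batch dimension of 1; dropping it leaves a flat patch axis whose length equals `num_patches.sum()`, which `flat_from_sizes` relies on below. Illustrative numbers only:

```python
import torch

# Two images that tile into 17 and 1 patches respectively (illustrative).
num_patches = torch.tensor([17, 1])

# Assumed HF output shape: (1, total_patches, 3, H, W) for one prompt.
pixel_values = torch.randn(1, 18, 3, 364, 364)
pixel_values = pixel_values.squeeze(0)  # (18, 3, 364, 364)

assert pixel_values.shape[0] == int(num_patches.sum())
```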
@@ -268,10 +406,16 @@ class Idefics3MultiModalProcessor(
     def _get_mm_fields_config(
         self,
         hf_inputs: BatchFeature,
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, MultiModalFieldConfig]:
+        num_patches = hf_inputs.get("num_patches", torch.empty(0))
+
         return dict(
-            pixel_values=MultiModalFieldConfig.batched("image"),
-            pixel_attention_mask=MultiModalFieldConfig.batched("image"),
+            pixel_values=MultiModalFieldConfig.flat_from_sizes(
+                "image", num_patches),
+            pixel_attention_mask=MultiModalFieldConfig.flat_from_sizes(
+                "image", num_patches),
             image_embeds=MultiModalFieldConfig.batched("image"),
+            num_patches=MultiModalFieldConfig.batched("image"),
+            embed_is_patch=MultiModalFieldConfig.batched("image"),
         )
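
> `flat_from_sizes("image", num_patches)` declares that the first axis of `pixel_values` (and the attention mask) is a concatenation over images whose per-image sizes come from `num_patches`. Conceptually, the per-image slices are recovered like this (a sketch, not the vLLM internals):

```python
import torch

num_patches = torch.tensor([17, 1])
pixel_values = torch.randn(18, 3, 364, 364)  # flat patch axis: 17 + 1

per_image = pixel_values.split(num_patches.tolist())
assert [p.shape[0] for p in per_image] == [17, 1]
```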
@@ -281,42 +425,18 @@ class Idefics3MultiModalProcessor(
     def _get_prompt_updates(
         self,
         mm_items: MultiModalDataItems,
         hf_processor_mm_kwargs: Mapping[str, object],
         out_mm_kwargs: MultiModalKwargs,
     ) -> Sequence[PromptUpdate]:
         hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
-        image_token = hf_processor.image_token.content
-        fake_image_token = hf_processor.fake_image_token.content
-        global_img_token = hf_processor.global_image_tag
-        image_seq_len = hf_processor.image_seq_len
-        grid_placeholder = "<row_{n_h}_col_{n_w}>"
-
-        p_img = image_token * image_seq_len
-        global_img_placeholder = fake_image_token + global_img_token + p_img
-        tile_img_placeholder = fake_image_token + grid_placeholder + p_img
 
         def get_replacement_idefics3(item_idx: int) -> str:
             images = mm_items.get_items("image", ImageProcessorItems)
             image_size = images.get_image_size(item_idx)
 
-            grid_w, grid_h = self.info._get_image_feature_grid_size(
+            return self.info.get_image_repl(
                 image_width=image_size.width,
                 image_height=image_size.height,
-                **hf_processor_mm_kwargs,
+                processor=hf_processor,
             )
-            if grid_w == 0 and grid_h == 0:
-                image_placeholder = global_img_placeholder
-            else:
-                tiles_placeholder = list[str]()
-                for i in range(grid_h):
-                    for j in range(grid_w):
-                        placeholder_per_tile = tile_img_placeholder.format(
-                            n_h=i + 1, n_w=j + 1)
-                        tiles_placeholder.append(placeholder_per_tile)
-
-                        # Add line break if it is the last tile in the row
-                        if j == grid_w - 1:
-                            tiles_placeholder.append("\n")
-
-                image_placeholder = "".join(
-                    [*tiles_placeholder, "\n", global_img_placeholder])
-
-            return image_placeholder + fake_image_token
 
         return [
             PromptReplacement(
@@ -424,73 +544,13 @@ class Idefics3Model(nn.Module):
                             config.vision_config.patch_size)**2) /
                         (config.scale_factor**2))
         self.image_token_id = self.config.image_token_id
 
-    def _validate_pixel_values(
-        self, data: Union[torch.Tensor, List[torch.Tensor]]
-    ) -> Union[torch.Tensor, List[torch.Tensor]]:
-        h = w = self.config.vision_config.image_size
-        expected_dims = (3, h, w)
-
-        def _validate_shape(d: torch.Tensor):
-            actual_dims = tuple(d.shape[1:])
-
-            if actual_dims != expected_dims:
-                expected_expr = ("num_patches", *map(str, expected_dims))
-                raise ValueError(
-                    "The expected shape of pixel values per image per batch "
-                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")
-
-        for d in data:
-            _validate_shape(d)
-
-        return data
-
-    def _parse_and_validate_image_input(
-            self, **kwargs: object) -> Optional[ImageInputs]:
-        pixel_values = kwargs.pop("pixel_values", None)
-        image_embeds = kwargs.pop("image_embeds", None)
-        pixel_attention_mask = kwargs.pop("pixel_attention_mask", None)
-
-        if pixel_values is None and image_embeds is None:
-            return None
-
-        if image_embeds is not None:
-            if not isinstance(image_embeds, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of image embeddings. "
-                                 f"Got type: {type(image_embeds)}")
-
-            return Idefics3ImageEmbeddingInputs(
-                type="image_embeds",
-                data=flatten_bn(image_embeds, concat=True),
-            )
-
-        if pixel_values is not None:
-            if not isinstance(pixel_values, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of pixel values. "
-                                 f"Got type: {type(pixel_values)}")
-
-            if isinstance(pixel_values, list):
-                pixel_values = torch.cat(pixel_values, dim=1)
-                pixel_attention_mask = torch.cat(pixel_attention_mask, dim=1)
-            else:
-                pixel_values = flatten_bn(pixel_values)
-                pixel_attention_mask = flatten_bn(pixel_attention_mask)
-
-            return Idefics3ImagePixelInputs(
-                type="pixel_values",
-                data=self._validate_pixel_values(pixel_values),
-                pixel_attention_mask=pixel_attention_mask)
-
-        raise AssertionError("This line should be unreachable.")
-
-    def _image_pixels_to_features(
+    def image_pixels_to_features(
         self,
         pixel_values: torch.Tensor,
-        pixel_attention_mask: Optional[torch.BoolTensor] = None,
-    ) -> NestedTensors:
+        pixel_attention_mask: torch.Tensor,
+    ) -> torch.Tensor:
         # NOTE: we skip the step to select the vision feature layer since
         # this is already done inside the vision tower
-        num_patches = [x.size(0) for x in pixel_values]
         pixel_values = pixel_values.to(
             dtype=self.vision_model.embeddings.patch_embedding.weight.dtype
         )  # fp16 compatibility
@@ -502,14 +562,6 @@ class Idefics3Model(nn.Module):
         pixel_values = pixel_values[real_images_inds].contiguous()
 
         # Handle the vision attention mask
-        if pixel_attention_mask is None:
-            pixel_attention_mask = torch.ones(
-                size=(pixel_values.size(0), pixel_values.size(2),
-                      pixel_values.size(3)),
-                dtype=torch.bool,
-                device=pixel_values.device,
-            )
-        else:
-            # Remove padding images from the mask
-            pixel_attention_mask = pixel_attention_mask[
-                real_images_inds].contiguous()
+        # Remove padding images from the mask
+        pixel_attention_mask = pixel_attention_mask[
+            real_images_inds].contiguous()
@@ -529,27 +581,7 @@ class Idefics3Model(nn.Module):
             patch_attention_mask=patch_attention_mask,
         )
 
-        return image_hidden_states.split(num_patches)
-
-    def _process_image_pixels(
-            self, inputs: Idefics3ImagePixelInputs) -> NestedTensors:
-        assert self.vision_model is not None
-
-        pixel_values = inputs["data"]
-        pixel_attention_mask = inputs["pixel_attention_mask"]
-
-        return self._image_pixels_to_features(pixel_values,
-                                              pixel_attention_mask)
-
-    def _process_image_input(self, image_input: ImageInputs) -> torch.Tensor:
-        if image_input["type"] == "image_embeds":
-            return image_input["data"]
-
-        assert self.vision_model is not None
-
-        image_features = self._process_image_pixels(image_input)
-
-        num_patches = [x.size(0) for x in image_features]
-        image_features = torch.cat(image_features)
-
-        return self.connector(image_features).split(num_patches)
+        return image_hidden_states
 
     def get_input_embeddings(
         self,
@@ -616,13 +648,113 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
         self.logits_processor = LogitsProcessor(config.text_config.vocab_size)
         self.sampler = get_sampler()
 
+    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
+        h = w = self.config.vision_config.image_size
+        expected_dims = (3, h, w)
+
+        def _validate_shape(d: torch.Tensor):
+            actual_dims = tuple(d.shape)
+
+            if actual_dims != expected_dims:
+                expected_expr = str(expected_dims)
+                raise ValueError(
+                    "The expected shape of pixel values per image per batch "
+                    f"per patch is {expected_expr}. "
+                    f"You supplied {tuple(d.shape)}.")
+
+        for d in data:
+            _validate_shape(d)
+
+        return data
+
+    def _parse_and_validate_image_input(
+            self, **kwargs: object) -> Optional[ImageInputs]:
+        pixel_values = kwargs.pop("pixel_values", None)
+        image_embeds = kwargs.pop("image_embeds", None)
+
+        if pixel_values is None and image_embeds is None:
+            return None
+
+        embed_is_patch = kwargs.pop("embed_is_patch")
+        if not isinstance(embed_is_patch, (torch.Tensor, list)):
+            raise ValueError("Incorrect type of embed_is_patch. "
+                             f"Got type: {type(embed_is_patch)}")
+
+        embed_is_patch = flatten_bn(embed_is_patch)
+
+        if image_embeds is not None:
+            if not isinstance(image_embeds, (torch.Tensor, list)):
+                raise ValueError("Incorrect type of image embeddings. "
+                                 f"Got type: {type(image_embeds)}")
+
+            return Idefics3ImageEmbeddingInputs(
+                type="image_embeds",
+                data=flatten_bn(image_embeds, concat=True),
+                embed_is_patch=embed_is_patch,
+            )
+
+        if pixel_values is not None:
+            if not isinstance(pixel_values, (torch.Tensor, list)):
+                raise ValueError("Incorrect type of pixel values. "
+                                 f"Got type: {type(pixel_values)}")
+
+            pixel_attention_mask = kwargs.pop("pixel_attention_mask")
+            if not isinstance(pixel_attention_mask, (torch.Tensor, list)):
+                raise ValueError("Incorrect type of pixel_attention_mask. "
+                                 f"Got type: {type(pixel_attention_mask)}")
+
+            num_patches = kwargs.pop("num_patches")
+            if not isinstance(num_patches, (torch.Tensor, list)):
+                raise ValueError("Incorrect type of num_patches. "
+                                 f"Got type: {type(num_patches)}")
+
+            pixel_values = flatten_bn(pixel_values, concat=True)
+            pixel_attention_mask = flatten_bn(pixel_attention_mask,
+                                              concat=True)
+            num_patches = flatten_bn(num_patches, concat=True)
+
+            return Idefics3ImagePixelInputs(
+                type="pixel_values",
+                pixel_values=self._validate_pixel_values(pixel_values),
+                pixel_attention_mask=pixel_attention_mask,
+                num_patches=num_patches,
+                embed_is_patch=embed_is_patch,
+            )
+
+        raise AssertionError("This line should be unreachable.")
+    def _process_image_pixels(
+            self, inputs: Idefics3ImagePixelInputs) -> torch.Tensor:
+        pixel_values = inputs["pixel_values"]
+        pixel_attention_mask = inputs["pixel_attention_mask"]
+
+        return self.model.image_pixels_to_features(
+            pixel_values,
+            pixel_attention_mask=pixel_attention_mask,
+        )
+
+    def _process_image_input(self, image_input: ImageInputs) -> torch.Tensor:
+        if image_input["type"] == "image_embeds":
+            return image_input["data"]
+
+        image_features = self._process_image_pixels(image_input)
+        image_features = self.model.connector(image_features)
+
+        num_patches = image_input["num_patches"]
+        return image_features.split(num_patches.tolist())
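
> `_process_image_input` now returns one tensor per image by splitting the connector output along the flat patch axis. Illustrative shapes below; the visual-token count per patch and the hidden size are assumptions, not taken from this diff:

```python
import torch

num_patches = torch.tensor([17, 1])          # patches per image
image_features = torch.randn(18, 169, 4096)  # assumed connector output

per_image = image_features.split(num_patches.tolist())
print([tuple(f.shape) for f in per_image])
# [(17, 169, 4096), (1, 169, 4096)]
```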
     def get_multimodal_embeddings(
             self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
-        image_input = self.model._parse_and_validate_image_input(**kwargs)
+        image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
             return None
-        vision_embeddings = self.model._process_image_input(image_input)
-        return vision_embeddings
+
+        image_features = self._process_image_input(image_input)
+
+        return scatter_patch_features(
+            image_features,
+            image_input["embed_is_patch"],
+        )
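
> This is the heart of the bugfix: previously every placeholder position was overwritten with vision features, including the structural tokens around the patches. `scatter_patch_features` and `select_patch_features` (implemented in `.vision`, not shown in this diff) pair up so that only true `<image>` positions are filled. A conceptual sketch of the contract, not the actual vLLM implementation:

```python
import torch

def scatter_patch_features_sketch(features: torch.Tensor,
                                  embed_is_patch: torch.Tensor) -> torch.Tensor:
    # Spread patch features over all placeholder positions,
    # marking non-patch positions with NaN.
    out = features.new_full((embed_is_patch.numel(), features.shape[-1]),
                            torch.nan)
    out[embed_is_patch] = features
    return out

def select_patch_features_sketch(embeds: torch.Tensor) -> torch.Tensor:
    # Recover exactly the rows that hold real patch features.
    return embeds[~embeds.isnan().any(dim=-1)]

embed_is_patch = torch.tensor([False, True, True, False, True])
patch_feats = torch.randn(3, 8)

full = scatter_patch_features_sketch(patch_feats, embed_is_patch)
assert torch.equal(select_patch_features_sketch(full), patch_feats)
```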
     def get_input_embeddings(
         self,
@@ -632,8 +764,11 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
         inputs_embeds = self.model.get_input_embeddings(input_ids)
         if multimodal_embeddings is not None:
             inputs_embeds = merge_multimodal_embeddings(
-                input_ids, inputs_embeds, multimodal_embeddings,
-                self.config.image_token_id)
+                input_ids,
+                inputs_embeds,
+                select_patch_features(multimodal_embeddings),
+                self.config.image_token_id,
+            )
         return inputs_embeds
def forward(

vllm/model_executor/models/mllama.py

@@ -21,7 +21,6 @@ from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union
 import numpy as np
 import torch
 import torch.nn.functional as F
-import torch.utils.checkpoint
 import transformers.models.mllama.configuration_mllama as config_mllama
 from PIL.Image import Image
 from torch import nn

vllm/model_executor/models/qwen2_audio.py

@@ -160,7 +160,7 @@ class Qwen2AudioMultiModalProcessor(
         mm_kwargs: Mapping[str, Any],
     ) -> BatchFeature:
         # Text-only input not supported in composite processor
-        if not mm_data or not mm_data.get("audios", []):
+        if not mm_data.get("audios", []):
             prompt_ids = self.info.get_tokenizer().encode(prompt)
             prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
             return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")
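
> The guard simplification here (and the identical one in Ultravox below) is behavior-preserving: when `mm_data` is empty, `mm_data.get("audios", [])` is already falsy, so the extra `not mm_data or` test was redundant:

```python
mm_data = {}

old = not mm_data or not mm_data.get("audios", [])
new = not mm_data.get("audios", [])
assert old == new  # equivalent for empty and non-empty mm_data alike
```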

vllm/model_executor/models/ultravox.py

@@ -8,7 +8,6 @@ from functools import cached_property
 from typing import Any, Literal, Optional, Set, Tuple, TypedDict, Union
 
 import torch
-import torch.utils.checkpoint
 from torch import nn
 from torch.nn import functional as F
 from transformers import BatchFeature, ProcessorMixin
@@ -160,7 +159,7 @@ class UltravoxMultiModalProcessor(
         mm_kwargs: Mapping[str, object],
     ) -> BatchFeature:
         # Text-only input not supported in composite processor
-        if not mm_data or not mm_data.get("audios", []):
+        if not mm_data.get("audios", []):
             prompt_ids = self.info.get_tokenizer().encode(
                 prompt, add_special_tokens=False)
             prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)