[Bugfix] embed_is_patch for Idefics3 (#15696)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>

parent 3b00ff9138
commit 541d1df486
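What changed, in one picture: the tokenized placeholder that Idefics3 expands per image mixes real patch positions (repeated <image> tokens) with structural tokens such as <fake_token_around_image>, <row_i_col_j> tags and line breaks. The fix threads a boolean embed_is_patch mask through processing so vision features land only on the patch positions. A minimal sketch with made-up token ids (only the mask computation mirrors the diff below):

import torch

IMAGE_TOKEN_ID = 32001  # hypothetical id of <image>
# hypothetical tokenized placeholder for one image:
# <fake_token_around_image> <global-img> <image> <image> <fake_token_around_image>
repl_token_ids = torch.tensor(
    [32000, 32005, IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, 32000])

# True exactly where a vision patch embedding belongs; this mirrors the
# embed_is_patch computation added in _call_hf_processor below.
embed_is_patch = repl_token_ids == IMAGE_TOKEN_ID
print(embed_is_patch)  # tensor([False, False,  True,  True, False])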
vllm/model_executor/models/commandr.py
@@ -24,7 +24,6 @@
 from typing import Iterable, Optional, Set, Tuple, Union
 
 import torch
-import torch.utils.checkpoint
 from torch import nn
 from transformers import CohereConfig
 
vllm/model_executor/models/idefics3.py
@@ -17,16 +17,14 @@
 
 import math
 from collections.abc import Iterable, Mapping, Sequence
-from typing import Dict, List, Literal, Optional, Set, Tuple, TypedDict, Union
+from typing import Dict, Literal, Optional, Set, Tuple, TypedDict, Union
 
 import torch
-import torch.utils.checkpoint
 from torch import nn
 from transformers import (BatchFeature, Idefics3Config, Idefics3ImageProcessor,
                           Idefics3Processor)
 
 from vllm.config import VllmConfig
-from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import ReplicatedLinear
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -35,13 +33,16 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
-from vllm.multimodal.inputs import NestedTensors
-from vllm.multimodal.parse import ImageProcessorItems
+from vllm.multimodal.parse import ImageProcessorItems, ImageSize
+# yapf conflicts with isort for this block
+# yapf: disable
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         BaseProcessingInfo,
                                         MultiModalDataItems,
                                         MultiModalFieldConfig,
-                                        PromptReplacement, PromptUpdate)
+                                        PromptReplacement, PromptUpdate,
+                                        encode_tokens)
+# yapf: enable
 from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
 from vllm.sequence import IntermediateTensors
 
@@ -53,18 +54,28 @@ from .interfaces import MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal
 from .llama import LlamaModel
 from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
                     merge_multimodal_embeddings)
-
-logger = init_logger(__name__)
+from .vision import scatter_patch_features, select_patch_features
 
 
 class Idefics3ImagePixelInputs(TypedDict):
     type: Literal["pixel_values"]
-    data: torch.Tensor
+    pixel_values: torch.Tensor
     """
    Shape: `(batch_size * num_images * num_patches,
     num_channels, height, width)`
     """
-    pixel_attention_mask: Optional[torch.BoolTensor]
+    pixel_attention_mask: torch.Tensor
+
+    num_patches: torch.Tensor
+    """Shape: `(batch_size * num_images)`"""
+
+    embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
+    """
+    A boolean mask indicating which image embeddings correspond
+    to patch tokens.
+
+    Shape: `(batch_size * num_images, num_embeds)`
+    """
 
 
 class Idefics3ImageEmbeddingInputs(TypedDict):
@@ -75,6 +86,14 @@ class Idefics3ImageEmbeddingInputs(TypedDict):
     `hidden_size` must match the hidden size of language model backbone.
     """
 
+    embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
+    """
+    A boolean mask indicating which image embeddings correspond
+    to patch tokens.
+
+    Shape: `(batch_size * num_images, num_embeds)`
+    """
+
 
 ImageInputs = Union[Idefics3ImagePixelInputs, Idefics3ImageEmbeddingInputs]
 
@@ -100,32 +119,14 @@ class Idefics3ProcessingInfo(BaseProcessingInfo):
         seq_len: int,
         mm_counts: Mapping[str, int],
     ) -> Mapping[str, int]:
-        hf_processor = self.get_hf_processor()
-        image_processor: Idefics3ImageProcessor = hf_processor.image_processor
-        grid_w, grid_h = self._get_image_feature_grid_size(
-            image_width=image_processor.size['longest_edge'],
-            image_height=image_processor.size['longest_edge'],
-        )
-        num_image_token = (grid_w * grid_h + 1) * hf_processor.image_seq_len
-        # Calculate non-image-token length
-        # NOTE: <row_1_col_1> and <global-img> are special tokens for SmolVLM
-        # but not for Idefics3, so we tokenize them to get the actual length.
-        tokenizer = self.get_tokenizer()
-        tile_token_len = len(tokenizer.tokenize("<row_1_col_1>"))
-        glob_token_len = len(tokenizer.tokenize(hf_processor.global_image_tag))
-        # linebreak and <fake_token_around_image> always cost 1 token
-        fake_token_len = lb_len = 1
-        non_image_token = (grid_w * grid_h) * (
-            tile_token_len + fake_token_len) + glob_token_len + (
-                grid_h + 1) * lb_len + fake_token_len
-        return {"image": num_image_token + non_image_token}
+        return {"image": self.get_max_image_tokens()}
 
     def _resize_output_size(self,
                             *,
                             height: int,
                             width: int,
                             max_len: Optional[int] = None,
-                            min_len: Optional[int] = 1,
+                            min_len: int = 1,
                             max_size: Optional[int] = None) -> tuple[int, int]:
         # Set default value for max_len if not provided
         max_len = max(height, width) if max_len is None else max_len
@@ -181,10 +182,13 @@ class Idefics3ProcessingInfo(BaseProcessingInfo):
         *,
         image_width: int,
         image_height: int,
-        size: Optional[dict[str, object]] = None,
+        processor: Optional[Idefics3Processor],
     ) -> tuple[int, int]:
-        hf_processor = self.get_hf_processor(size=size)
-        image_processor: Idefics3ImageProcessor = hf_processor.image_processor
+        if processor is None:
+            processor = self.get_hf_processor()
+
+        image_processor: Idefics3ImageProcessor = processor.image_processor
+
         max_image_size = image_processor.max_image_size['longest_edge']
         size = image_processor.size['longest_edge']
         assert size % max_image_size == 0, (
@@ -204,6 +208,105 @@ class Idefics3ProcessingInfo(BaseProcessingInfo):
             grid_h = grid_w = 0
         return grid_w, grid_h
 
+    def get_num_patches(
+        self,
+        *,
+        image_width: int,
+        image_height: int,
+        processor: Optional[Idefics3Processor],
+    ) -> int:
+        grid_w, grid_h = self._get_image_feature_grid_size(
+            image_width=image_width,
+            image_height=image_height,
+            processor=processor,
+        )
+
+        return grid_w * grid_h + 1
+
+    def get_image_repl(
+        self,
+        *,
+        image_width: int,
+        image_height: int,
+        processor: Optional[Idefics3Processor],
+    ) -> str:
+        if processor is None:
+            processor = self.get_hf_processor()
+
+        image_token = processor.image_token.content
+        fake_image_token = processor.fake_image_token.content
+        global_img_token = processor.global_image_tag
+        image_seq_len = processor.image_seq_len
+        grid_placeholder = "<row_{n_h}_col_{n_w}>"
+
+        p_img = image_token * image_seq_len
+        global_img_placeholder = fake_image_token + global_img_token + p_img
+        tile_img_placeholder = fake_image_token + grid_placeholder + p_img
+
+        grid_w, grid_h = self._get_image_feature_grid_size(
+            image_width=image_width,
+            image_height=image_height,
+            processor=processor,
+        )
+        if grid_w == 0 and grid_h == 0:
+            return global_img_placeholder + fake_image_token
+
+        tiles_placeholder = list[str]()
+        for i in range(grid_h):
+            for j in range(grid_w):
+                placeholder_per_tile = tile_img_placeholder.format(n_h=i + 1,
+                                                                   n_w=j + 1)
+                tiles_placeholder.append(placeholder_per_tile)
+                # Add line break if it is the last tile in the row
+                if j == grid_w - 1:
+                    tiles_placeholder.append("\n")
+
+        return "".join([
+            *tiles_placeholder,
+            "\n",
+            global_img_placeholder,
+            fake_image_token,
+        ])
+
+    def get_num_image_tokens(
+        self,
+        *,
+        image_width: int,
+        image_height: int,
+        processor: Optional[Idefics3Processor],
+    ) -> int:
+        tokenizer = self.get_tokenizer()
+        image_repl = self.get_image_repl(
+            image_width=image_width,
+            image_height=image_height,
+            processor=processor,
+        )
+
+        image_repl_tokens = encode_tokens(
+            tokenizer,
+            image_repl,
+            add_special_tokens=False,
+        )
+        return len(image_repl_tokens)
+
+    def get_image_size_with_most_features(self) -> ImageSize:
+        processor = self.get_hf_processor()
+        image_processor: Idefics3ImageProcessor = processor.image_processor
+
+        return ImageSize(
+            width=image_processor.size["longest_edge"],
+            height=image_processor.size["longest_edge"],
+        )
+
+    def get_max_image_tokens(self) -> int:
+        target_width, target_height = self.get_image_size_with_most_features()
+
+        return self.get_num_image_tokens(
+            image_width=target_width,
+            image_height=target_height,
+            processor=None,
+        )
+
 
 class Idefics3DummyInputsBuilder(BaseDummyInputsBuilder[Idefics3ProcessingInfo]
                                 ):
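As a reading aid, a hedged sketch of the placeholder string that get_image_repl above produces, assuming a 2x2 tile grid and image_seq_len shortened to 1 for readability (the real value is much larger); the literal token strings follow the HF Idefics3 processor defaults:

fake = "<fake_token_around_image>"
img = "<image>"
global_tag = "<global-img>"
grid_w = grid_h = 2
p_img = img * 1  # get_image_repl repeats <image> image_seq_len times

tiles = []
for i in range(grid_h):
    for j in range(grid_w):
        tiles.append(fake + f"<row_{i + 1}_col_{j + 1}>" + p_img)
        if j == grid_w - 1:  # line break after the last tile in a row
            tiles.append("\n")

# tile rows, a separating line break, then the global image placeholder
print("".join([*tiles, "\n", fake + global_tag + p_img, fake]))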
@@ -217,7 +320,7 @@ class Idefics3DummyInputsBuilder(BaseDummyInputsBuilder[Idefics3ProcessingInfo]
         hf_processor = self.info.get_hf_processor()
         image_processor: Idefics3ImageProcessor = hf_processor.image_processor
         longest_edge = image_processor.max_image_size['longest_edge']
-        image_token: str = hf_processor.image_token.content
+        image_token = hf_processor.image_token.content
 
         mm_data = {
             "image":
@@ -241,26 +344,61 @@ class Idefics3MultiModalProcessor(
         mm_data: Mapping[str, object],
         mm_kwargs: Mapping[str, object],
     ) -> BatchFeature:
-        if mm_data:
-            processed_outputs = super()._call_hf_processor(
-                prompt, mm_data, mm_kwargs)
-            image_grids = [
-                self.info._get_image_feature_grid_size(
-                    image_width=img.width,
-                    image_height=img.height,
-                    **mm_kwargs,
-                ) for img in mm_data["images"]
-            ]
-            image_patches = list(map(lambda x: math.prod(x) + 1, image_grids))
-            for key in ("pixel_values", "pixel_attention_mask"):
-                data = processed_outputs.pop(key)
-                data = data.flatten(0, 1).split(image_patches)
-                processed_outputs[key] = data
-        else:
-            tokenizer = self.info.get_tokenizer()
-            processed_outputs = tokenizer(prompt,
-                                          add_special_tokens=True,
-                                          return_tensors="pt")
+        # Text-only input not supported in composite processor
+        if not (images := mm_data.get("images", [])):
+            prompt_ids = self.info.get_tokenizer().encode(prompt)
+            prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
+            return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")
+
+        processed_outputs = super()._call_hf_processor(
+            prompt,
+            mm_data,
+            mm_kwargs,
+        )
+
+        parsed_images = (self._get_data_parser().parse_mm_data({
+            "image": images
+        }).get_items("image", ImageProcessorItems))
+        image_sizes = [
+            parsed_images.get_image_size(i) for i in range(len(parsed_images))
+        ]
+        hf_processor = self.info.get_hf_processor(**mm_kwargs)
+
+        image_repl_features = [
+            self.info.get_image_repl(image_width=size.width,
+                                     image_height=size.height,
+                                     processor=hf_processor)
+            for size in image_sizes
+        ]
+
+        tokenizer = self.info.get_tokenizer()
+        image_repls_feature_tokens = [
+            tokenizer.encode(image_repl, add_special_tokens=False)
+            for image_repl in image_repl_features
+        ]
+
+        vocab = tokenizer.get_vocab()
+        image_token_id = vocab[hf_processor.image_token.content]
+
+        embed_is_patch = [
+            torch.tensor(image_repl_tokens) == image_token_id
+            for image_repl_tokens in image_repls_feature_tokens
+        ]
+        processed_outputs["embed_is_patch"] = embed_is_patch
+
+        num_patches = [
+            self.info.get_num_patches(
+                image_width=size.width,
+                image_height=size.height,
+                processor=hf_processor,
+            ) for size in image_sizes
+        ]
+        processed_outputs["num_patches"] = torch.tensor(num_patches)
+
+        # Remove the extra batch dimension
+        processed_outputs["pixel_values"].squeeze_(0)
+        processed_outputs["pixel_attention_mask"].squeeze_(0)
 
         return processed_outputs
 
     def _get_mm_fields_config(
@@ -268,10 +406,16 @@ class Idefics3MultiModalProcessor(
         hf_inputs: BatchFeature,
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, MultiModalFieldConfig]:
+        num_patches = hf_inputs.get("num_patches", torch.empty(0))
+
         return dict(
-            pixel_values=MultiModalFieldConfig.batched("image"),
-            pixel_attention_mask=MultiModalFieldConfig.batched("image"),
+            pixel_values=MultiModalFieldConfig.flat_from_sizes(
+                "image", num_patches),
+            pixel_attention_mask=MultiModalFieldConfig.flat_from_sizes(
+                "image", num_patches),
             image_embeds=MultiModalFieldConfig.batched("image"),
+            num_patches=MultiModalFieldConfig.batched("image"),
+            embed_is_patch=MultiModalFieldConfig.batched("image"),
         )
 
     def _get_prompt_updates(
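A hedged note on why the field configs change above: pixel_values now arrives flattened over all patches of all images, shaped (sum(num_patches_i), 3, H, W), so flat_from_sizes("image", num_patches) lets each image recover its own slice. An illustrative shape check (sizes assumed, 364 being the Idefics3 default tile edge):

import torch

num_patches = torch.tensor([5, 9])  # e.g. a 2x2 grid (4+1) and a 2x4 grid (8+1)
pixel_values = torch.randn(int(num_patches.sum()), 3, 364, 364)

# Recover the per-image patch groups exactly as the model code does later.
per_image = pixel_values.split(num_patches.tolist())
print([p.shape[0] for p in per_image])  # [5, 9]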
@@ -281,42 +425,18 @@ class Idefics3MultiModalProcessor(
         out_mm_kwargs: MultiModalKwargs,
     ) -> Sequence[PromptUpdate]:
         hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
 
-        image_token = hf_processor.image_token.content
-        fake_image_token = hf_processor.fake_image_token.content
-        global_img_token = hf_processor.global_image_tag
-        image_seq_len = hf_processor.image_seq_len
-        grid_placeholder = "<row_{n_h}_col_{n_w}>"
-
-        p_img = image_token * image_seq_len
-        global_img_placeholder = fake_image_token + global_img_token + p_img
-        tile_img_placeholder = fake_image_token + grid_placeholder + p_img
-
         def get_replacement_idefics3(item_idx: int) -> str:
             images = mm_items.get_items("image", ImageProcessorItems)
 
             image_size = images.get_image_size(item_idx)
-            grid_w, grid_h = self.info._get_image_feature_grid_size(
+
+            return self.info.get_image_repl(
                 image_width=image_size.width,
                 image_height=image_size.height,
-                **hf_processor_mm_kwargs,
+                processor=hf_processor,
             )
-            if grid_w == 0 and grid_h == 0:
-                image_placeholder = global_img_placeholder
-            else:
-                tiles_placeholder = list[str]()
-                for i in range(grid_h):
-                    for j in range(grid_w):
-                        placeholder_per_tile = tile_img_placeholder.format(
-                            n_h=i + 1, n_w=j + 1)
-                        tiles_placeholder.append(placeholder_per_tile)
-                        # Add line break if it is the last tile in the row
-                        if j == grid_w - 1:
-                            tiles_placeholder.append("\n")
-
-                image_placeholder = "".join(
-                    [*tiles_placeholder, "\n", global_img_placeholder])
-            return image_placeholder + fake_image_token
 
         return [
             PromptReplacement(
@@ -424,73 +544,13 @@ class Idefics3Model(nn.Module):
             config.vision_config.patch_size)**2) / (config.scale_factor**2))
         self.image_token_id = self.config.image_token_id
 
-    def _validate_pixel_values(
-        self, data: Union[torch.Tensor, List[torch.Tensor]]
-    ) -> Union[torch.Tensor, List[torch.Tensor]]:
-
-        h = w = self.config.vision_config.image_size
-        expected_dims = (3, h, w)
-
-        def _validate_shape(d: torch.Tensor):
-            actual_dims = tuple(d.shape[1:])
-
-            if actual_dims != expected_dims:
-                expected_expr = ("num_patches", *map(str, expected_dims))
-                raise ValueError(
-                    "The expected shape of pixel values per image per batch "
-                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")
-
-        for d in data:
-            _validate_shape(d)
-
-        return data
-
-    def _parse_and_validate_image_input(
-            self, **kwargs: object) -> Optional[ImageInputs]:
-        pixel_values = kwargs.pop("pixel_values", None)
-        image_embeds = kwargs.pop("image_embeds", None)
-        pixel_attention_mask = kwargs.pop("pixel_attention_mask", None)
-
-        if pixel_values is None and image_embeds is None:
-            return None
-
-        if image_embeds is not None:
-            if not isinstance(image_embeds, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of image embeddings. "
-                                 f"Got type: {type(image_embeds)}")
-
-            return Idefics3ImageEmbeddingInputs(
-                type="image_embeds",
-                data=flatten_bn(image_embeds, concat=True),
-            )
-
-        if pixel_values is not None:
-            if not isinstance(pixel_values, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of pixel values. "
-                                 f"Got type: {type(pixel_values)}")
-
-            if isinstance(pixel_values, list):
-                pixel_values = torch.cat(pixel_values, dim=1)
-                pixel_attention_mask = torch.cat(pixel_attention_mask, dim=1)
-            else:
-                pixel_values = flatten_bn(pixel_values)
-                pixel_attention_mask = flatten_bn(pixel_attention_mask)
-
-            return Idefics3ImagePixelInputs(
-                type="pixel_values",
-                data=self._validate_pixel_values(pixel_values),
-                pixel_attention_mask=pixel_attention_mask)
-
-        raise AssertionError("This line should be unreachable.")
-
-    def _image_pixels_to_features(
+    def image_pixels_to_features(
         self,
         pixel_values: torch.Tensor,
-        pixel_attention_mask: Optional[torch.BoolTensor] = None,
-    ) -> NestedTensors:
+        pixel_attention_mask: torch.Tensor,
+    ) -> torch.Tensor:
         # NOTE: we skip the step to select the vision feature layer since
         # this is already done inside the vision tower
-        num_patches = [x.size(0) for x in pixel_values]
         pixel_values = pixel_values.to(
             dtype=self.vision_model.embeddings.patch_embedding.weight.dtype
         )  # fp16 compatibility
@@ -502,17 +562,9 @@ class Idefics3Model(nn.Module):
         pixel_values = pixel_values[real_images_inds].contiguous()
 
         # Handle the vision attention mask
-        if pixel_attention_mask is None:
-            pixel_attention_mask = torch.ones(
-                size=(pixel_values.size(0), pixel_values.size(2),
-                      pixel_values.size(3)),
-                dtype=torch.bool,
-                device=pixel_values.device,
-            )
-        else:
-            # Remove padding images from the mask
-            pixel_attention_mask = pixel_attention_mask[
-                real_images_inds].contiguous()
+        # Remove padding images from the mask
+        pixel_attention_mask = pixel_attention_mask[
+            real_images_inds].contiguous()
 
         patch_size = self.config.vision_config.patch_size
         patches_subgrid = pixel_attention_mask.unfold(dimension=1,
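For context, a hedged sketch of the patch-mask step whose first line is shown above: unfold tiles the per-pixel mask into patch_size x patch_size sub-grids, and a patch takes part in attention iff any pixel in its sub-grid is valid (the dimensions here are toy-sized):

import torch

patch_size = 14
# 2 images whose pixel masks cover a 2x2 grid of patches
pixel_attention_mask = torch.ones(2, 28, 28, dtype=torch.bool)

patches_subgrid = pixel_attention_mask.unfold(1, patch_size, patch_size)
patches_subgrid = patches_subgrid.unfold(2, patch_size, patch_size)
patch_attention_mask = patches_subgrid.sum(dim=(-1, -2)) > 0
print(patch_attention_mask.shape)  # torch.Size([2, 2, 2])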
@@ -529,27 +581,7 @@ class Idefics3Model(nn.Module):
             patch_attention_mask=patch_attention_mask,
         )
 
-        return image_hidden_states.split(num_patches)
-
-    def _process_image_pixels(
-            self, inputs: Idefics3ImagePixelInputs) -> NestedTensors:
-        assert self.vision_model is not None
-
-        pixel_values = inputs["data"]
-        pixel_attention_mask = inputs["pixel_attention_mask"]
-
-        return self._image_pixels_to_features(pixel_values,
-                                              pixel_attention_mask)
-
-    def _process_image_input(self, image_input: ImageInputs) -> torch.Tensor:
-        if image_input["type"] == "image_embeds":
-            return image_input["data"]
-
-        assert self.vision_model is not None
-        image_features = self._process_image_pixels(image_input)
-        num_patches = [x.size(0) for x in image_features]
-        image_features = torch.cat(image_features)
-        return self.connector(image_features).split(num_patches)
+        return image_hidden_states
 
     def get_input_embeddings(
         self,
@@ -616,13 +648,113 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
         self.logits_processor = LogitsProcessor(config.text_config.vocab_size)
         self.sampler = get_sampler()
 
+    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
+        h = w = self.config.vision_config.image_size
+        expected_dims = (3, h, w)
+
+        def _validate_shape(d: torch.Tensor):
+            actual_dims = tuple(d.shape)
+
+            if actual_dims != expected_dims:
+                expected_expr = str(expected_dims)
+                raise ValueError(
+                    "The expected shape of pixel values per image per batch "
+                    f"per patch is {expected_expr}. "
+                    f"You supplied {tuple(d.shape)}.")
+
+        for d in data:
+            _validate_shape(d)
+
+        return data
+
+    def _parse_and_validate_image_input(
+            self, **kwargs: object) -> Optional[ImageInputs]:
+        pixel_values = kwargs.pop("pixel_values", None)
+        image_embeds = kwargs.pop("image_embeds", None)
+
+        if pixel_values is None and image_embeds is None:
+            return None
+
+        embed_is_patch = kwargs.pop("embed_is_patch")
+        if not isinstance(embed_is_patch, (torch.Tensor, list)):
+            raise ValueError("Incorrect type of embed_is_patch. "
+                             f"Got type: {type(embed_is_patch)}")
+
+        embed_is_patch = flatten_bn(embed_is_patch)
+
+        if image_embeds is not None:
+            if not isinstance(image_embeds, (torch.Tensor, list)):
+                raise ValueError("Incorrect type of image embeddings. "
+                                 f"Got type: {type(image_embeds)}")
+
+            return Idefics3ImageEmbeddingInputs(
+                type="image_embeds",
+                data=flatten_bn(image_embeds, concat=True),
+                embed_is_patch=embed_is_patch,
+            )
+
+        if pixel_values is not None:
+            if not isinstance(pixel_values, (torch.Tensor, list)):
+                raise ValueError("Incorrect type of pixel values. "
+                                 f"Got type: {type(pixel_values)}")
+
+            pixel_attention_mask = kwargs.pop("pixel_attention_mask")
+            if not isinstance(pixel_attention_mask, (torch.Tensor, list)):
+                raise ValueError("Incorrect type of pixel_attention_mask. "
+                                 f"Got type: {type(pixel_attention_mask)}")
+
+            num_patches = kwargs.pop("num_patches")
+            if not isinstance(num_patches, (torch.Tensor, list)):
+                raise ValueError("Incorrect type of num_patches. "
+                                 f"Got type: {type(num_patches)}")
+
+            pixel_values = flatten_bn(pixel_values, concat=True)
+            pixel_attention_mask = flatten_bn(pixel_attention_mask,
+                                              concat=True)
+            num_patches = flatten_bn(num_patches, concat=True)
+
+            return Idefics3ImagePixelInputs(
+                type="pixel_values",
+                pixel_values=self._validate_pixel_values(pixel_values),
+                pixel_attention_mask=pixel_attention_mask,
+                num_patches=num_patches,
+                embed_is_patch=embed_is_patch,
+            )
+
+        raise AssertionError("This line should be unreachable.")
+
+    def _process_image_pixels(
+            self, inputs: Idefics3ImagePixelInputs) -> torch.Tensor:
+        pixel_values = inputs["pixel_values"]
+        pixel_attention_mask = inputs["pixel_attention_mask"]
+
+        return self.model.image_pixels_to_features(
+            pixel_values,
+            pixel_attention_mask=pixel_attention_mask,
+        )
+
+    def _process_image_input(self, image_input: ImageInputs) -> torch.Tensor:
+        if image_input["type"] == "image_embeds":
+            return image_input["data"]
+
+        image_features = self._process_image_pixels(image_input)
+        image_features = self.model.connector(image_features)
+
+        num_patches = image_input["num_patches"]
+        return image_features.split(num_patches.tolist())
+
     def get_multimodal_embeddings(
             self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
-        image_input = self.model._parse_and_validate_image_input(**kwargs)
+        image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
             return None
-        vision_embeddings = self.model._process_image_input(image_input)
-        return vision_embeddings
+
+        image_features = self._process_image_input(image_input)
+
+        return scatter_patch_features(
+            image_features,
+            image_input["embed_is_patch"],
+        )
 
     def get_input_embeddings(
         self,
@@ -632,8 +764,11 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
         inputs_embeds = self.model.get_input_embeddings(input_ids)
         if multimodal_embeddings is not None:
             inputs_embeds = merge_multimodal_embeddings(
-                input_ids, inputs_embeds, multimodal_embeddings,
-                self.config.image_token_id)
+                input_ids,
+                inputs_embeds,
+                select_patch_features(multimodal_embeddings),
+                self.config.image_token_id,
+            )
         return inputs_embeds
 
     def forward(
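The scatter/select pair imported from .vision is what actually routes features: scatter_patch_features spreads each image's patch features over every placeholder position using embed_is_patch, and select_patch_features recovers them when merging into the input embeddings. A toy model of that contract (an illustration, not the vllm/model_executor/models/vision.py implementation):

import torch

def scatter_patch_features_toy(feats: torch.Tensor,
                               is_patch: torch.Tensor) -> torch.Tensor:
    # Place patch features at patch positions; pad the rest with NaN.
    out = feats.new_full((is_patch.numel(), feats.shape[-1]), torch.nan)
    out[is_patch] = feats
    return out

def select_patch_features_toy(embeds: torch.Tensor) -> torch.Tensor:
    # Keep only the rows that carry real patch features.
    return embeds[~embeds.isnan().any(dim=-1)]

feats = torch.randn(3, 8)  # 3 patch embeddings of hidden size 8
is_patch = torch.tensor([False, True, True, False, True])
full = scatter_patch_features_toy(feats, is_patch)
assert torch.equal(select_patch_features_toy(full), feats)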
vllm/model_executor/models/mllama.py
@@ -21,7 +21,6 @@ from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union
 import numpy as np
 import torch
 import torch.nn.functional as F
-import torch.utils.checkpoint
 import transformers.models.mllama.configuration_mllama as config_mllama
 from PIL.Image import Image
 from torch import nn
vllm/model_executor/models/qwen2_audio.py
@@ -160,7 +160,7 @@ class Qwen2AudioMultiModalProcessor(
         mm_kwargs: Mapping[str, Any],
     ) -> BatchFeature:
         # Text-only input not supported in composite processor
-        if not mm_data or not mm_data.get("audios", []):
+        if not mm_data.get("audios", []):
             prompt_ids = self.info.get_tokenizer().encode(prompt)
             prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
             return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")
vllm/model_executor/models/ultravox.py
@@ -8,7 +8,6 @@ from functools import cached_property
 from typing import Any, Literal, Optional, Set, Tuple, TypedDict, Union
 
 import torch
-import torch.utils.checkpoint
 from torch import nn
 from torch.nn import functional as F
 from transformers import BatchFeature, ProcessorMixin
@@ -160,7 +159,7 @@ class UltravoxMultiModalProcessor(
         mm_kwargs: Mapping[str, object],
     ) -> BatchFeature:
         # Text-only input not supported in composite processor
-        if not mm_data or not mm_data.get("audios", []):
+        if not mm_data.get("audios", []):
             prompt_ids = self.info.get_tokenizer().encode(
                 prompt, add_special_tokens=False)
             prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
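The Qwen2-Audio and Ultravox hunks simplify the text-only guard; a quick check that the two conditions agree for dict-shaped mm_data (an empty mapping already makes the .get result falsy, so the "not mm_data or" clause was redundant):

for mm_data in ({}, {"audios": []}, {"audios": [b"\x00"]}):
    old = not mm_data or not mm_data.get("audios", [])
    new = not mm_data.get("audios", [])
    assert old == new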