Mirror of https://git.datalinker.icu/vllm-project/vllm.git
Migrate Idefics3ImagePixelInputs and Idefics3ImageEmbeddingInputs to … (#21683)
Signed-off-by: Benji Beck <benjibeck@meta.com>
parent 75856bc2cb
commit 3ea57a56d9
@@ -18,7 +18,7 @@
 
 import math
 from collections.abc import Iterable, Mapping, Sequence
-from typing import Literal, Optional, TypedDict, Union
+from typing import Annotated, Literal, Optional, Union
 
 import torch
 from torch import nn
@@ -45,6 +45,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
 # yapf: enable
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 # yapf: disable
 from .idefics2_vision_model import (
@@ -56,26 +57,30 @@ from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
                     merge_multimodal_embeddings)
 
 
-class Idefics3ImagePixelInputs(TypedDict):
+class Idefics3ImagePixelInputs(TensorSchema):
+    """
+    Dimensions:
+        - bn: Batch size * number of images
+        - bnp: Batch size * number of images * number of patches
+        - c: Number of channels (3)
+        - h: Height
+        - w: Width
+    """
     type: Literal["pixel_values"]
-    pixel_values: torch.Tensor
-    """
-    Shape: `(batch_size * num_images * num_patches,
-             num_channels, height, width)`
-    """
+    pixel_values: Annotated[torch.Tensor, TensorShape("bnp", 3, "h", "w")]
     pixel_attention_mask: torch.Tensor
-    num_patches: torch.Tensor
-    """Shape: `(batch_size * num_images)`"""
+    num_patches: Annotated[torch.Tensor, TensorShape("bn")]
 
 
-class Idefics3ImageEmbeddingInputs(TypedDict):
+class Idefics3ImageEmbeddingInputs(TensorSchema):
+    """
+    Dimensions:
+        - bn: Batch size * number of images
+        - f: Image feature size
+        - h: Hidden size (must match the hidden size of language model backbone)
+    """
     type: Literal["image_embeds"]
-    data: torch.Tensor
-    """
-    Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
-    `hidden_size` must match the hidden size of language model backbone.
-    """
+    data: Annotated[torch.Tensor, TensorShape("bn", "f", "h")]
 
 
 ImageInputs = Union[Idefics3ImagePixelInputs, Idefics3ImageEmbeddingInputs]
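
Note: after this change, shape checking happens when the inputs object is constructed, because TensorSchema validates each Annotated field against its declared TensorShape; the hand-written _validate_pixel_values helper removed further down becomes redundant. The following is a minimal, self-contained sketch of that idea, not vllm's actual TensorSchema implementation; the PixelInputs class is a hypothetical stand-in for Idefics3ImagePixelInputs above.

from typing import Annotated, Literal, get_type_hints

import torch


class TensorShape:
    # Declares dims per field: str entries are symbolic (bound at runtime),
    # int entries are fixed sizes.
    def __init__(self, *dims):
        self.dims = dims


class TensorSchema:
    # Validates every Annotated tensor field against its TensorShape.
    def __init__(self, *, resolve_bindings=None, **fields):
        bindings = dict(resolve_bindings or {})  # pre-pinned symbolic dims
        hints = get_type_hints(type(self), include_extras=True)
        for name, value in fields.items():
            setattr(self, name, value)
            meta = getattr(hints.get(name), "__metadata__", ())
            shape = next((m for m in meta if isinstance(m, TensorShape)), None)
            if shape is None or not isinstance(value, torch.Tensor):
                continue
            if value.ndim != len(shape.dims):
                raise ValueError(f"{name}: expected {len(shape.dims)} dims, "
                                 f"got {value.ndim}")
            for dim, actual in zip(shape.dims, value.shape):
                # First use of a symbolic dim binds it; later uses must match.
                expected = (bindings.setdefault(dim, actual)
                            if isinstance(dim, str) else dim)
                if actual != expected:
                    raise ValueError(f"{name}: dim {dim!r} expected "
                                     f"{expected}, got {actual}")


class PixelInputs(TensorSchema):
    # Hypothetical stand-in mirroring Idefics3ImagePixelInputs above.
    type: Literal["pixel_values"]
    pixel_values: Annotated[torch.Tensor, TensorShape("bnp", 3, "h", "w")]
    num_patches: Annotated[torch.Tensor, TensorShape("bn")]

The real vllm.utils.tensor_schema is more general; this sketch only captures the dimension-binding behavior the diff relies on.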
@@ -614,25 +619,6 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
         self.lm_head.weight = self.model.text_model.wte.weight
         self.logits_processor = LogitsProcessor(config.text_config.vocab_size)
 
-    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
-        h = w = self.config.vision_config.image_size
-        expected_dims = (3, h, w)
-
-        def _validate_shape(d: torch.Tensor):
-            actual_dims = tuple(d.shape)
-
-            if actual_dims != expected_dims:
-                expected_expr = str(expected_dims)
-                raise ValueError(
-                    "The expected shape of pixel values per image per batch "
-                    f" per patch is {expected_expr}. "
-                    f"You supplied {tuple(d.shape)}.")
-
-        for d in data:
-            _validate_shape(d)
-
-        return data
-
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[ImageInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
@@ -666,16 +652,17 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
                 raise ValueError("Incorrect type of num_patches. "
                                  f"Got type: {type(num_patches)}")
 
-            pixel_values = flatten_bn(pixel_values, concat=True)
-            pixel_attention_mask = flatten_bn(pixel_attention_mask,
-                                              concat=True)
-            num_patches = flatten_bn(num_patches, concat=True)
-
+            expected_h = expected_w = self.config.vision_config.image_size
             return Idefics3ImagePixelInputs(
                 type="pixel_values",
-                pixel_values=self._validate_pixel_values(pixel_values),
-                pixel_attention_mask=pixel_attention_mask,
-                num_patches=num_patches,
+                pixel_values=flatten_bn(pixel_values, concat=True),
+                pixel_attention_mask=flatten_bn(pixel_attention_mask,
+                                                concat=True),
+                num_patches=flatten_bn(num_patches, concat=True),
+                resolve_bindings={
+                    "h": expected_h,
+                    "w": expected_w
+                },
             )
 
         raise AssertionError("This line should be unreachable.")
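
For illustration, continuing the hypothetical sketch above: resolve_bindings pre-pins the symbolic "h" and "w" dims to the configured image size, so an off-size patch fails at construction, which is what made the deleted _validate_pixel_values loop unnecessary (364 is an assumed example image size, not a value taken from this diff).

# Shapes check out: bnp=4 patches, c=3, h=w=364; "bn" binds to 2.
inputs = PixelInputs(pixel_values=torch.zeros(4, 3, 364, 364),
                     num_patches=torch.tensor([2, 2]),
                     resolve_bindings={"h": 364, "w": 364})

# Wrong height: "h" was pinned to 364, so construction raises.
try:
    PixelInputs(pixel_values=torch.zeros(4, 3, 224, 364),
                num_patches=torch.tensor([2, 2]),
                resolve_bindings={"h": 364, "w": 364})
except ValueError as exc:
    print(exc)  # pixel_values: dim 'h' expected 364, got 224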