[Bugfix][VLM] Fix incompatibility between #7902 and #7230 (#7948)

Cyrus Leung 2024-08-28 23:11:18 +08:00 committed by GitHub
parent 98c12cffe5
commit ef9baee3c5
10 changed files with 120 additions and 92 deletions

vllm/model_executor/models/blip2.py

@@ -40,13 +40,13 @@ BLIP2_IMAGE_TOKEN_ID = 50265
 class Blip2ImagePixelInputs(TypedDict):
     type: Literal["pixel_values"]
     data: torch.Tensor
-    """Shape: (batch_size, num_channels, height, width)"""
+    """Shape: `(batch_size * num_images, num_channels, height, width)`"""


 class Blip2ImageEmbeddingInputs(TypedDict):
     type: Literal["image_embeds"]
     data: torch.Tensor
-    """Shape: `(batch_size, image_feature_size, hidden_size)`
+    """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`

     `hidden_size` must match the hidden size of language model backbone.
     """

vllm/model_executor/models/chameleon.py

@@ -53,7 +53,7 @@ CHAMELEON_SEP_TOKEN_ID = 8710
 class ChameleonImagePixelInputs(TypedDict):
     type: Literal["pixel_values"]
     data: torch.Tensor
-    """Shape: `(batch_size, num_channels, height, width)`"""
+    """Shape: `(batch_size * num_images, num_channels, height, width)`"""


 def get_max_chameleon_image_tokens(ctx: InputContext):

vllm/model_executor/models/internvl.py

@@ -29,7 +29,7 @@ from vllm.sequence import IntermediateTensors, SamplerOutput
 from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
                    get_clip_num_patches)
 from .interfaces import SupportsMultiModal
-from .utils import (filter_weights, init_vllm_registered_model,
+from .utils import (filter_weights, flatten_bn, init_vllm_registered_model,
                     merge_multimodal_embeddings)

 IMG_START = '<img>'
@@ -42,19 +42,17 @@ IMAGENET_STD = (0.229, 0.224, 0.225)

 class InternVLImagePixelInputs(TypedDict):
     type: Literal["pixel_values"]
-    data: Union[torch.Tensor, List[torch.Tensor]]
+    data: torch.Tensor
     """
-    Shape: `(batch_size, 1 + num_patches, num_channels, height, width)`
-
-    Note that `num_patches` may be different for each batch, in which case
-    the data is passed as a list instead of a batched tensor.
+    Shape:
+    `(batch_size * num_images * (1 + num_patches), num_channels, height, width)`
     """


 class InternVLImageEmbeddingInputs(TypedDict):
     type: Literal["image_embeds"]
-    data: Union[torch.Tensor, List[torch.Tensor]]
-    """Shape: `(batch_size, image_feature_size, hidden_size)`
+    data: torch.Tensor
+    """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`

     `hidden_size` must match the hidden size of language model backbone.
     """
@@ -357,7 +355,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal):
         x = x.permute(0, 2, 1, 3).contiguous()
         return x

-    def extract_feature(self, pixel_values):
+    def extract_feature(self, pixel_values: torch.Tensor) -> torch.Tensor:
         vit_embeds = self.vision_model(pixel_values=pixel_values)
         vit_embeds = vit_embeds[:, 1:, :]
@@ -370,17 +368,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal):
         vit_embeds = self.mlp1(vit_embeds)
         return vit_embeds

-    def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
-        if list(data.shape[1:]) != [2]:
-            raise ValueError(
-                f"The expected image sizes shape is batch dimension plus "
-                f"{[2]}. You supplied {data.shape}.")
-
-        return data
-
-    def _validate_pixel_values(
-        self, data: Union[torch.Tensor, List[torch.Tensor]]
-    ) -> Union[torch.Tensor, List[torch.Tensor]]:
+    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
         h = w = self.config.vision_config.image_size
         expected_dims = (3, h, w)
@@ -389,10 +377,11 @@ class InternVLChatModel(nn.Module, SupportsMultiModal):
             actual_dims = tuple(d.shape)

             if actual_dims != expected_dims:
-                expected_expr = ("num_patches", *map(str, expected_dims))
+                expected_expr = str(expected_dims)
                 raise ValueError(
-                    "The expected shape of pixel values in each batch element "
-                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")
+                    "The expected shape of pixel values per image per batch "
+                    f"per patch is {expected_expr}. "
+                    f"You supplied {tuple(d.shape)}.")

         for d in data:
             _validate_shape(d)
@@ -413,12 +402,9 @@ class InternVLChatModel(nn.Module, SupportsMultiModal):
                 raise ValueError("Incorrect type of image embeddings. "
                                  f"Got type: {type(image_embeds)}")

-            # Flatten the B and N dimensions
-            image_embeds = image_embeds.flatten(0, 2)
-
             return InternVLImageEmbeddingInputs(
                 type="image_embeds",
-                data=image_embeds,
+                data=flatten_bn(image_embeds),
             )

         self.img_context_token_id = image_token_id[0]
@@ -428,12 +414,10 @@ class InternVLChatModel(nn.Module, SupportsMultiModal):
                 raise ValueError("Incorrect type of pixel values. "
                                  f"Got type: {type(pixel_values)}")

-            # Flatten the B and N dimensions
-            pixel_values = pixel_values.flatten(0, 2)
-
             return InternVLImagePixelInputs(
                 type="pixel_values",
-                data=self._validate_pixel_values(pixel_values),
+                data=self._validate_pixel_values(
+                    flatten_bn(pixel_values, concat=True).flatten(0, 1)),
             )

         raise AssertionError("This line should be unreachable.")
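
For InternVL, each prompt contributes pixel values with leading dims `(num_images, 1 + num_patches)`, and `flatten_bn(pixel_values, concat=True).flatten(0, 1)` collapses batch, images, and patches into a single leading dim. A standalone sketch with made-up sizes (`torch.cat` is what `flatten_bn(..., concat=True)` does for list inputs):

    import torch

    # Hypothetical inputs: 2 prompts with different image counts but the same
    # patch count (torch.cat requires the trailing dims to match).
    pixel_values = [
        torch.zeros(1, 5, 3, 448, 448),  # prompt 1: 1 image, 1 + 4 patches
        torch.zeros(2, 5, 3, 448, 448),  # prompt 2: 2 images, 1 + 4 patches
    ]

    # Equivalent to flatten_bn(pixel_values, concat=True).flatten(0, 1):
    flat = torch.cat(pixel_values).flatten(0, 1)
    print(flat.shape)  # torch.Size([15, 3, 448, 448]) = (B*N*(1+P), C, H, W)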

vllm/model_executor/models/llava.py

@@ -30,13 +30,13 @@ from .utils import (filter_weights, init_vllm_registered_model,
 class LlavaImagePixelInputs(TypedDict):
     type: Literal["pixel_values"]
     data: torch.Tensor
-    """Shape: `(batch_size, num_channels, height, width)`"""
+    """Shape: `(batch_size * num_images, num_channels, height, width)`"""


 class LlavaImageEmbeddingInputs(TypedDict):
     type: Literal["image_embeds"]
     data: torch.Tensor
-    """Shape: `(batch_size, image_feature_size, hidden_size)`
+    """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`

     `hidden_size` must match the hidden size of language model backbone.
     """

vllm/model_executor/models/llava_next.py

@@ -29,7 +29,7 @@ from .llava import LlavaMultiModalProjector
 from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                      dummy_seq_data_for_siglip, get_siglip_image_feature_size,
                      get_siglip_patch_grid_length, input_processor_for_siglip)
-from .utils import (filter_weights, init_vllm_registered_model,
+from .utils import (filter_weights, flatten_bn, init_vllm_registered_model,
                     merge_multimodal_embeddings)

 logger = init_logger(__name__)
@@ -47,15 +47,16 @@ class LlavaNextImagePixelInputs(TypedDict):
     type: Literal["pixel_values"]
     data: Union[torch.Tensor, List[torch.Tensor]]
     """
-    Shape: `(batch_size, 1 + num_patches, num_channels, height, width)`
+    Shape:
+    `(batch_size * num_images, 1 + num_patches, num_channels, height, width)`

-    Note that `num_patches` may be different for each batch, in which case
-    the data is passed as a list instead of a batched tensor.
+    Note that `num_patches` may be different per batch and image,
+    in which case the data is passed as a list instead of a batched tensor.
     """

     image_sizes: NotRequired[torch.Tensor]
     """
-    Shape: `(batch_size, 2)`
+    Shape: `(batch_size * num_images, 2)`

     This should be in `(height, width)` format.
     """
@@ -64,7 +65,7 @@ class LlavaNextImagePixelInputs(TypedDict):
 class LlavaNextImageEmbeddingInputs(TypedDict):
     type: Literal["image_embeds"]
     data: torch.Tensor
-    """Shape: `(batch_size, image_feature_size, hidden_size)`
+    """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`

     `hidden_size` must match the hidden size of language model backbone.
     """
@@ -315,10 +316,19 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal):
             torch.empty(config.text_config.hidden_size))

     def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
-        if list(data.shape[1:]) != [2]:
-            raise ValueError(
-                f"The expected image sizes shape is batch dimension plus "
-                f"{[2]}. You supplied {data.shape}.")
+        expected_dims = (2, )
+
+        def _validate_shape(d: torch.Tensor):
+            actual_dims = tuple(d.shape)
+
+            if actual_dims != expected_dims:
+                expected_expr = str(expected_dims)
+                raise ValueError(
+                    f"The expected shape of image sizes per image per batch "
+                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")
+
+        for d in data:
+            _validate_shape(d)

         return data
@@ -335,7 +345,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal):
             if actual_dims != expected_dims:
                 expected_expr = ("num_patches", *map(str, expected_dims))
                 raise ValueError(
-                    "The expected shape of pixel values in each batch element "
+                    "The expected shape of pixel values per image per batch "
                     f"is {expected_expr}. You supplied {tuple(d.shape)}.")

         for d in data:
@@ -357,22 +367,15 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal):
                 raise ValueError("Incorrect type of pixel values. "
                                  f"Got type: {type(pixel_values)}")

-            if not isinstance(image_sizes, torch.Tensor):
+            if not isinstance(image_sizes, (torch.Tensor, list)):
                 raise ValueError("Incorrect type of image sizes. "
                                  f"Got type: {type(image_sizes)}")

-            # Remove the N dimension until multiple images are supported.
-            if isinstance(pixel_values, torch.Tensor):
-                pixel_values = pixel_values.squeeze(1)
-            else:
-                pixel_values = [t.squeeze(0) for t in pixel_values]
-
-            image_sizes = image_sizes.squeeze(1)
-
             return LlavaNextImagePixelInputs(
                 type="pixel_values",
-                data=self._validate_pixel_values(pixel_values),
-                image_sizes=self._validate_image_sizes(image_sizes),
+                data=self._validate_pixel_values(flatten_bn(pixel_values)),
+                image_sizes=self._validate_image_sizes(
+                    flatten_bn(image_sizes, concat=True)),
             )

         if image_embeds is not None:
@@ -380,12 +383,9 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal):
             raise ValueError("Incorrect type of image embeds. "
                              f"Got type: {type(image_embeds)}")

-            # Remove the N dimension until multiple images are supported.
-            image_embeds = image_embeds.squeeze(1)
-
             return LlavaNextImageEmbeddingInputs(
                 type="image_embeds",
-                data=image_embeds,
+                data=flatten_bn(image_embeds),
             )

         raise AssertionError("This line should be unreachable.")
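
LLaVA-NeXT keeps `data` as a possible list because `num_patches` can differ per image, and `flatten_bn(pixel_values)` without `concat` just flattens B and N into one Python list. A standalone sketch with made-up sizes (not part of the commit):

    import torch

    # Hypothetical ragged inputs: patch counts differ, so no single tensor fits.
    pixel_values = [
        torch.zeros(1, 5, 3, 336, 336),  # prompt 1: 1 image, 1 + 4 patches
        torch.zeros(2, 7, 3, 336, 336),  # prompt 2: 2 images, 1 + 6 patches
    ]

    # Equivalent to flatten_bn(pixel_values): one list entry per image.
    flat = [x_n for x_b in pixel_values for x_n in x_b]
    print(len(flat))       # 3
    print(flat[0].shape)   # torch.Size([5, 3, 336, 336])
    print(flat[-1].shape)  # torch.Size([7, 3, 336, 336])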

vllm/model_executor/models/paligemma.py

@@ -34,13 +34,13 @@ _KEYS_TO_MODIFY_MAPPING = {
 class PaliGemmaImagePixelInputs(TypedDict):
     type: Literal["pixel_values"]
     data: torch.Tensor
-    """Shape: (batch_size, num_channels, height, width)"""
+    """Shape: `(batch_size * num_images, num_channels, height, width)`"""


 class PaliGemmaImageEmbeddingInputs(TypedDict):
     type: Literal["image_embeds"]
     data: torch.Tensor
-    """Shape: `(batch_size, image_feature_size, hidden_size)`
+    """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`

     `hidden_size` must match the hidden size of language model backbone.
     """

vllm/model_executor/models/phi3v.py

@@ -44,7 +44,7 @@ from vllm.utils import is_list_of
 from .clip import dummy_image_for_clip, dummy_seq_data_for_clip
 from .interfaces import SupportsMultiModal
-from .utils import merge_multimodal_embeddings
+from .utils import flatten_bn, merge_multimodal_embeddings

 logger = init_logger(__name__)
@@ -75,15 +75,16 @@ class Phi3VImagePixelInputs(TypedDict):
     type: Literal["pixel_values"]
     data: Union[torch.Tensor, List[torch.Tensor]]
     """
-    Shape: `(batch_size, 1 + num_patches, num_channels, height, width)`
+    Shape:
+    `(batch_size * num_images, 1 + num_patches, num_channels, height, width)`

-    Note that `num_patches` may be different for each batch, in which case
-    the data is passed as a list instead of a batched tensor.
+    Note that `num_patches` may be different per batch and image,
+    in which case the data is passed as a list instead of a batched tensor.
     """

     image_sizes: torch.Tensor
     """
-    Shape: `(batch_size, 2)`
+    Shape: `(batch_size * num_images, 2)`

     This should be in `(height, width)` format.
     """
@@ -92,7 +93,7 @@ class Phi3VImagePixelInputs(TypedDict):
 class Phi3VImageEmbeddingInputs(TypedDict):
     type: Literal["image_embeds"]
     data: Union[torch.Tensor, List[torch.Tensor]]
-    """Shape: `(batch_size, image_feature_size, hidden_size)`
+    """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`

     `hidden_size` must match the hidden size of language model backbone.
     """
@@ -511,10 +512,19 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal):
         self.sampler = Sampler()

     def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
-        if list(data.shape[1:]) != [2]:
-            raise ValueError(
-                f"The expected shape of image sizes is batch dimension plus "
-                f"{[2]}. You supplied {tuple(data.shape)}.")
+        expected_dims = (2, )
+
+        def _validate_shape(d: torch.Tensor):
+            actual_dims = tuple(d.shape)
+
+            if actual_dims != expected_dims:
+                expected_expr = str(expected_dims)
+                raise ValueError(
+                    f"The expected shape of image sizes per image per batch "
+                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")
+
+        for d in data:
+            _validate_shape(d)

         return data
@@ -531,7 +541,7 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal):
             if actual_dims != expected_dims:
                 expected_expr = ("num_patches", *map(str, expected_dims))
                 raise ValueError(
-                    "The expected shape of pixel values in each batch element "
+                    "The expected shape of pixel values per image per batch "
                     f"is {expected_expr}. You supplied {tuple(d.shape)}.")

         for d in data:
@@ -556,30 +566,24 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal):
                 raise ValueError("Incorrect type of pixel values. "
                                  f"Got type: {type(pixel_values)}")

-            if not isinstance(image_sizes, torch.Tensor):
+            if not isinstance(image_sizes, (torch.Tensor, list)):
                 raise ValueError("Incorrect type of image sizes. "
                                  f"Got type: {type(image_sizes)}")

-            # Merge the B and N dimensions.
-            if isinstance(pixel_values, torch.Tensor):
-                pixel_values = pixel_values.flatten(0, 1)
-            else:
-                pixel_values = torch.cat(pixel_values)
-
-            image_sizes = image_sizes.flatten(0, 1)
-
             return Phi3VImagePixelInputs(
                 type="pixel_values",
-                data=self._validate_pixel_values(pixel_values),
-                image_sizes=self._validate_image_sizes(image_sizes))
+                data=self._validate_pixel_values(flatten_bn(pixel_values)),
+                image_sizes=self._validate_image_sizes(
+                    flatten_bn(image_sizes, concat=True)))

         if image_embeds is not None:
             if not isinstance(image_embeds, torch.Tensor):
                 raise ValueError("Incorrect type of image embeddings. "
                                  f"Got type: {type(image_embeds)}")

             return Phi3VImageEmbeddingInputs(
                 type="image_embeds",
-                data=image_embeds,
+                data=flatten_bn(image_embeds),
             )

         raise AssertionError("This line should be unreachable.")

vllm/model_executor/models/ultravox.py

@@ -49,7 +49,7 @@ logger = init_logger(__name__)
 class UltravoxAudioFeatureInputs(TypedDict):
     type: Literal["audio_features"]
     data: Union[torch.Tensor, List[torch.Tensor]]
-    """Shape: `(batch_size, 80, M)`"""
+    """Shape: `(batch_size * num_audios, 80, M)`"""


 class UltravoxAudioEmbeddingInputs(TypedDict):

vllm/model_executor/models/utils.py

@@ -1,4 +1,5 @@
-from typing import Dict, Iterable, List, Optional, Protocol, Tuple
+from typing import (Dict, Iterable, List, Literal, Optional, Protocol, Tuple,
+                    Union, overload)

 import numpy as np
 import torch
@@ -55,6 +56,44 @@ def init_vllm_registered_model(
     )


+@overload
+def flatten_bn(x: torch.Tensor) -> torch.Tensor:
+    ...
+
+
+@overload
+def flatten_bn(x: List[torch.Tensor]) -> List[torch.Tensor]:
+    ...
+
+
+@overload
+def flatten_bn(
+    x: Union[List[torch.Tensor], torch.Tensor],
+    *,
+    concat: Literal[True],
+) -> torch.Tensor:
+    ...
+
+
+def flatten_bn(
+    x: Union[List[torch.Tensor], torch.Tensor],
+    *,
+    concat: bool = False,
+) -> Union[List[torch.Tensor], torch.Tensor]:
+    """
+    Flatten the ``B`` and ``N`` dimensions of batched multimodal inputs.
+
+    The input tensor should have shape ``(B, N, ...)``.
+    """
+    if isinstance(x, torch.Tensor):
+        return x.flatten(0, 1)
+
+    if concat:
+        return torch.cat(x)
+
+    return [x_n for x_b in x for x_n in x_b]
+
+
 def _flatten_embeddings(embeddings: NestedTensors) -> torch.Tensor:
     """
     Recursively concatenates NestedTensors along any heterogeneously sized
@@ -93,7 +132,8 @@ def merge_multimodal_embeddings(input_ids: torch.Tensor,
     This updates ``inputs_embeds`` in place.
     """
     mask = (input_ids == placeholder_token_id)
-    num_expected_tokens = mask.sum()
+    num_expected_tokens = mask.sum().item()
+    assert isinstance(num_expected_tokens, int)

     flattened = _flatten_embeddings(multimodal_embeddings)
     *dims, embed_dim = flattened.shape
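
The change to `merge_multimodal_embeddings` above replaces `mask.sum()` (a 0-dim tensor) with `mask.sum().item()` (a plain Python `int`), which the new assert then guarantees. A standalone check, with a made-up placeholder id:

    import torch

    input_ids = torch.tensor([1, 32000, 5, 32000, 7])
    mask = (input_ids == 32000)

    print(mask.sum())           # tensor(2) -- a 0-dim tensor, not an int
    num_expected_tokens = mask.sum().item()
    assert isinstance(num_expected_tokens, int)
    print(num_expected_tokens)  # 2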

vllm/multimodal/base.py

@@ -18,7 +18,7 @@ from vllm.utils import JSONTree, is_list_of, json_map_leaves

 logger = init_logger(__name__)

-NestedTensors = Union[List["NestedTensors"], torch.Tensor]
+NestedTensors = Union[List["NestedTensors"], List[torch.Tensor], torch.Tensor]
 """
 Uses a list instead of a tensor if the dimensions of each element do not match.
 """
@@ -61,7 +61,7 @@ class MultiModalInputs(_MultiModalInputsBase):
         tensors_ = cast(List[torch.Tensor], stacked)
         if any(t.shape != tensors_[0].shape for t in tensors_):
             # The tensors have incompatible shapes and can't be stacked.
-            return stacked
+            return tensors_

         return torch.stack(tensors_)
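
This one-line fix matters because `stacked` is typed as the wide `NestedTensors`, while the heterogeneous branch should return the narrowed `List[torch.Tensor]` arm added to the union above. A standalone sketch of the batching rule (the helper name here is made up; the real logic lives in `MultiModalInputs`):

    import torch
    from typing import List, Union

    def batch_one_key(
            tensors: List[torch.Tensor]) -> Union[torch.Tensor, List[torch.Tensor]]:
        # Mirror of the fixed branch: if shapes disagree, the tensors can't
        # be stacked, so return them as a list instead of the wider value.
        if any(t.shape != tensors[0].shape for t in tensors):
            return tensors
        return torch.stack(tensors)

    print(batch_one_key([torch.zeros(2, 3), torch.zeros(2, 3)]).shape)
    # torch.Size([2, 2, 3])
    ragged = batch_one_key([torch.zeros(2, 3), torch.zeros(4, 3)])
    print(type(ragged), len(ragged))  # <class 'list'> 2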