[Bugfix][VLM] Fix incompatibility between #7902 and #7230 (#7948)

Author: Cyrus Leung, 2024-08-28 23:11:18 +08:00 (committed by GitHub)
Commit: ef9baee3c5 (parent 98c12cffe5)
10 changed files with 120 additions and 92 deletions

View File

@@ -40,13 +40,13 @@ BLIP2_IMAGE_TOKEN_ID = 50265
 class Blip2ImagePixelInputs(TypedDict):
     type: Literal["pixel_values"]
     data: torch.Tensor
-    """Shape: (batch_size, num_channels, height, width)"""
+    """Shape: `(batch_size * num_images, num_channels, height, width)`"""


 class Blip2ImageEmbeddingInputs(TypedDict):
     type: Literal["image_embeds"]
     data: torch.Tensor
-    """Shape: `(batch_size, image_feature_size, hidden_size)`
+    """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`

     `hidden_size` must match the hidden size of language model backbone.
     """

View File

@@ -53,7 +53,7 @@ CHAMELEON_SEP_TOKEN_ID = 8710
 class ChameleonImagePixelInputs(TypedDict):
     type: Literal["pixel_values"]
     data: torch.Tensor
-    """Shape: `(batch_size, num_channels, height, width)`"""
+    """Shape: `(batch_size * num_images, num_channels, height, width)`"""


 def get_max_chameleon_image_tokens(ctx: InputContext):

View File

@@ -29,7 +29,7 @@ from vllm.sequence import IntermediateTensors, SamplerOutput
 from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
                    get_clip_num_patches)
 from .interfaces import SupportsMultiModal
-from .utils import (filter_weights, init_vllm_registered_model,
+from .utils import (filter_weights, flatten_bn, init_vllm_registered_model,
                     merge_multimodal_embeddings)

 IMG_START = '<img>'
@@ -42,19 +42,17 @@ IMAGENET_STD = (0.229, 0.224, 0.225)
 class InternVLImagePixelInputs(TypedDict):
     type: Literal["pixel_values"]
-    data: Union[torch.Tensor, List[torch.Tensor]]
+    data: torch.Tensor
     """
-    Shape: `(batch_size, 1 + num_patches, num_channels, height, width)`
-
-    Note that `num_patches` may be different for each batch, in which case
-    the data is passed as a list instead of a batched tensor.
+    Shape:
+    `(batch_size * num_images * (1 + num_patches), num_channels, height, width)`
     """


 class InternVLImageEmbeddingInputs(TypedDict):
     type: Literal["image_embeds"]
-    data: Union[torch.Tensor, List[torch.Tensor]]
-    """Shape: `(batch_size, image_feature_size, hidden_size)`
+    data: torch.Tensor
+    """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`

     `hidden_size` must match the hidden size of language model backbone.
     """
@@ -357,7 +355,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal):
         x = x.permute(0, 2, 1, 3).contiguous()
         return x

-    def extract_feature(self, pixel_values):
+    def extract_feature(self, pixel_values: torch.Tensor) -> torch.Tensor:
         vit_embeds = self.vision_model(pixel_values=pixel_values)
         vit_embeds = vit_embeds[:, 1:, :]
@@ -370,17 +368,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal):
         vit_embeds = self.mlp1(vit_embeds)
         return vit_embeds

-    def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
-        if list(data.shape[1:]) != [2]:
-            raise ValueError(
-                f"The expected image sizes shape is batch dimension plus "
-                f"{[2]}. You supplied {data.shape}.")
-        return data
-
-    def _validate_pixel_values(
-        self, data: Union[torch.Tensor, List[torch.Tensor]]
-    ) -> Union[torch.Tensor, List[torch.Tensor]]:
+    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
         h = w = self.config.vision_config.image_size
         expected_dims = (3, h, w)
@@ -389,10 +377,11 @@ class InternVLChatModel(nn.Module, SupportsMultiModal):
             actual_dims = tuple(d.shape)

             if actual_dims != expected_dims:
-                expected_expr = ("num_patches", *map(str, expected_dims))
+                expected_expr = str(expected_dims)
                 raise ValueError(
-                    "The expected shape of pixel values in each batch element "
-                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")
+                    "The expected shape of pixel values per image per batch "
+                    f" per patch is {expected_expr}. "
+                    f"You supplied {tuple(d.shape)}.")

         for d in data:
             _validate_shape(d)
@@ -413,12 +402,9 @@ class InternVLChatModel(nn.Module, SupportsMultiModal):
                 raise ValueError("Incorrect type of image embeddings. "
                                  f"Got type: {type(image_embeds)}")

-            # Flatten the B and N dimensions
-            image_embeds = image_embeds.flatten(0, 2)
-
             return InternVLImageEmbeddingInputs(
                 type="image_embeds",
-                data=image_embeds,
+                data=flatten_bn(image_embeds),
             )

         self.img_context_token_id = image_token_id[0]
@@ -428,12 +414,10 @@ class InternVLChatModel(nn.Module, SupportsMultiModal):
                 raise ValueError("Incorrect type of pixel values. "
                                  f"Got type: {type(pixel_values)}")

-            # Flatten the B and N dimensions
-            pixel_values = pixel_values.flatten(0, 2)
-
             return InternVLImagePixelInputs(
                 type="pixel_values",
-                data=self._validate_pixel_values(pixel_values),
+                data=self._validate_pixel_values(
+                    flatten_bn(pixel_values, concat=True).flatten(0, 1)),
             )

         raise AssertionError("This line should be unreachable.")
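
InternVL tiles each image into `1 + num_patches` crops, so the incoming pixel values carry three leading dimensions (batch, image, crop). The new parsing code first merges batch and image via `flatten_bn(..., concat=True)` and then folds in the crop dimension with the extra `.flatten(0, 1)`. A rough, self-contained sketch of the equivalent tensor reshaping (dummy sizes, plain torch rather than the vLLM helper):

import torch

# batch_size=2 prompts, num_images=1 image each, tiled into 1 + num_patches = 7
# crops of 448x448 pixels (all numbers are illustrative).
pixel_values = torch.zeros(2, 1, 7, 3, 448, 448)

# flatten_bn(pixel_values, concat=True) on a batched tensor merges B and N...
merged = pixel_values.flatten(0, 1)     # (2 * 1, 7, 3, 448, 448)
# ...and the trailing .flatten(0, 1) folds in the per-image crop dimension,
# matching the new `(batch_size * num_images * (1 + num_patches), ...)` shape.
per_crop = merged.flatten(0, 1)         # (14, 3, 448, 448)
assert per_crop.shape == (2 * 1 * 7, 3, 448, 448)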

View File

@@ -30,13 +30,13 @@ from .utils import (filter_weights, init_vllm_registered_model,
 class LlavaImagePixelInputs(TypedDict):
     type: Literal["pixel_values"]
     data: torch.Tensor
-    """Shape: `(batch_size, num_channels, height, width)`"""
+    """Shape: `(batch_size * num_images, num_channels, height, width)`"""


 class LlavaImageEmbeddingInputs(TypedDict):
     type: Literal["image_embeds"]
     data: torch.Tensor
-    """Shape: `(batch_size, image_feature_size, hidden_size)`
+    """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`

     `hidden_size` must match the hidden size of language model backbone.
     """

View File

@@ -29,7 +29,7 @@ from .llava import LlavaMultiModalProjector
 from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                      dummy_seq_data_for_siglip, get_siglip_image_feature_size,
                      get_siglip_patch_grid_length, input_processor_for_siglip)
-from .utils import (filter_weights, init_vllm_registered_model,
+from .utils import (filter_weights, flatten_bn, init_vllm_registered_model,
                     merge_multimodal_embeddings)

 logger = init_logger(__name__)
@@ -47,15 +47,16 @@ class LlavaNextImagePixelInputs(TypedDict):
     type: Literal["pixel_values"]
     data: Union[torch.Tensor, List[torch.Tensor]]
     """
-    Shape: `(batch_size, 1 + num_patches, num_channels, height, width)`
+    Shape:
+    `(batch_size * num_images, 1 + num_patches, num_channels, height, width)`

-    Note that `num_patches` may be different for each batch, in which case
-    the data is passed as a list instead of a batched tensor.
+    Note that `num_patches` may be different per batch and image,
+    in which case the data is passed as a list instead of a batched tensor.
     """

     image_sizes: NotRequired[torch.Tensor]
     """
-    Shape: `(batch_size, 2)`
+    Shape: `(batch_size * num_images, 2)`

     This should be in `(height, width)` format.
     """
@@ -64,7 +65,7 @@ class LlavaNextImagePixelInputs(TypedDict):
 class LlavaNextImageEmbeddingInputs(TypedDict):
     type: Literal["image_embeds"]
     data: torch.Tensor
-    """Shape: `(batch_size, image_feature_size, hidden_size)`
+    """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`

     `hidden_size` must match the hidden size of language model backbone.
     """
@@ -315,10 +316,19 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal):
             torch.empty(config.text_config.hidden_size))

     def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
-        if list(data.shape[1:]) != [2]:
-            raise ValueError(
-                f"The expected image sizes shape is batch dimension plus "
-                f"{[2]}. You supplied {data.shape}.")
+        expected_dims = (2, )
+
+        def _validate_shape(d: torch.Tensor):
+            actual_dims = tuple(d.shape)
+
+            if actual_dims != expected_dims:
+                expected_expr = str(expected_dims)
+                raise ValueError(
+                    f"The expected shape of image sizes per image per batch "
+                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")
+
+        for d in data:
+            _validate_shape(d)
+
         return data
@@ -335,7 +345,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal):
             if actual_dims != expected_dims:
                 expected_expr = ("num_patches", *map(str, expected_dims))
                 raise ValueError(
-                    "The expected shape of pixel values in each batch element "
+                    "The expected shape of pixel values per image per batch "
                     f"is {expected_expr}. You supplied {tuple(d.shape)}.")

         for d in data:
@@ -357,22 +367,15 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal):
                 raise ValueError("Incorrect type of pixel values. "
                                  f"Got type: {type(pixel_values)}")

-            if not isinstance(image_sizes, torch.Tensor):
+            if not isinstance(image_sizes, (torch.Tensor, list)):
                 raise ValueError("Incorrect type of image sizes. "
                                  f"Got type: {type(image_sizes)}")

-            # Remove the N dimension until multiple images are supported.
-            if isinstance(pixel_values, torch.Tensor):
-                pixel_values = pixel_values.squeeze(1)
-            else:
-                pixel_values = [t.squeeze(0) for t in pixel_values]
-
-            image_sizes = image_sizes.squeeze(1)
-
             return LlavaNextImagePixelInputs(
                 type="pixel_values",
-                data=self._validate_pixel_values(pixel_values),
-                image_sizes=self._validate_image_sizes(image_sizes),
+                data=self._validate_pixel_values(flatten_bn(pixel_values)),
+                image_sizes=self._validate_image_sizes(
+                    flatten_bn(image_sizes, concat=True)),
             )

         if image_embeds is not None:
@@ -380,12 +383,9 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal):
                 raise ValueError("Incorrect type of image embeds. "
                                  f"Got type: {type(image_embeds)}")

-            # Remove the N dimension until multiple images are supported.
-            image_embeds = image_embeds.squeeze(1)
-
             return LlavaNextImageEmbeddingInputs(
                 type="image_embeds",
-                data=image_embeds,
+                data=flatten_bn(image_embeds),
             )

         raise AssertionError("This line should be unreachable.")
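
For LLaVA-NeXT, images in one batch can tile into different numbers of patches, so `pixel_values` may arrive as a list of per-prompt tensors that cannot be stacked. `flatten_bn` without `concat` simply flattens that nesting into a per-image list, while the fixed-shape `image_sizes` are concatenated into a single `(total_images, 2)` tensor. A hedged sketch of both paths, using a local stand-in with the same semantics as the `flatten_bn` definition added later in this diff:

from typing import List, Union

import torch


def flatten_bn(x: Union[torch.Tensor, List[torch.Tensor]], *, concat: bool = False):
    # Local stand-in mirroring the flatten_bn helper introduced by this PR.
    if isinstance(x, torch.Tensor):
        return x.flatten(0, 1)
    if concat:
        return torch.cat(x)
    return [x_n for x_b in x for x_n in x_b]


# Two prompts whose images tile into different patch counts, so the per-prompt
# tensors have incompatible shapes: (num_images, 1 + num_patches, 3, H, W).
pixel_values = [torch.zeros(1, 5, 3, 336, 336), torch.zeros(2, 3, 3, 336, 336)]
per_image = flatten_bn(pixel_values)
assert [t.shape[0] for t in per_image] == [5, 3, 3]   # one entry per image

# image_sizes always have shape (num_images, 2), so they concatenate cleanly.
image_sizes = [torch.tensor([[336, 336]]), torch.tensor([[672, 336], [336, 672]])]
assert flatten_bn(image_sizes, concat=True).shape == (3, 2)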

View File

@@ -34,13 +34,13 @@ _KEYS_TO_MODIFY_MAPPING = {
 class PaliGemmaImagePixelInputs(TypedDict):
     type: Literal["pixel_values"]
     data: torch.Tensor
-    """Shape: (batch_size, num_channels, height, width)"""
+    """Shape: `(batch_size * num_images, num_channels, height, width)`"""


 class PaliGemmaImageEmbeddingInputs(TypedDict):
     type: Literal["image_embeds"]
     data: torch.Tensor
-    """Shape: `(batch_size, image_feature_size, hidden_size)`
+    """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`

     `hidden_size` must match the hidden size of language model backbone.
     """

View File

@@ -44,7 +44,7 @@ from vllm.utils import is_list_of
 from .clip import dummy_image_for_clip, dummy_seq_data_for_clip
 from .interfaces import SupportsMultiModal
-from .utils import merge_multimodal_embeddings
+from .utils import flatten_bn, merge_multimodal_embeddings

 logger = init_logger(__name__)
@@ -75,15 +75,16 @@ class Phi3VImagePixelInputs(TypedDict):
     type: Literal["pixel_values"]
     data: Union[torch.Tensor, List[torch.Tensor]]
     """
-    Shape: `(batch_size, 1 + num_patches, num_channels, height, width)`
+    Shape:
+    `(batch_size * num_images, 1 + num_patches, num_channels, height, width)`

-    Note that `num_patches` may be different for each batch, in which case
-    the data is passed as a list instead of a batched tensor.
+    Note that `num_patches` may be different per batch and image,
+    in which case the data is passed as a list instead of a batched tensor.
     """

     image_sizes: torch.Tensor
     """
-    Shape: `(batch_size, 2)`
+    Shape: `(batch_size * num_images, 2)`

     This should be in `(height, width)` format.
     """
@@ -92,7 +93,7 @@ class Phi3VImagePixelInputs(TypedDict):
 class Phi3VImageEmbeddingInputs(TypedDict):
     type: Literal["image_embeds"]
     data: Union[torch.Tensor, List[torch.Tensor]]
-    """Shape: `(batch_size, image_feature_size, hidden_size)`
+    """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`

     `hidden_size` must match the hidden size of language model backbone.
     """
@@ -511,10 +512,19 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal):
         self.sampler = Sampler()

     def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
-        if list(data.shape[1:]) != [2]:
-            raise ValueError(
-                f"The expected shape of image sizes is batch dimension plus "
-                f"{[2]}. You supplied {tuple(data.shape)}.")
+        expected_dims = (2, )
+
+        def _validate_shape(d: torch.Tensor):
+            actual_dims = tuple(d.shape)
+
+            if actual_dims != expected_dims:
+                expected_expr = str(expected_dims)
+                raise ValueError(
+                    f"The expected shape of image sizes per image per batch "
+                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")
+
+        for d in data:
+            _validate_shape(d)
+
         return data
@@ -531,7 +541,7 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal):
             if actual_dims != expected_dims:
                 expected_expr = ("num_patches", *map(str, expected_dims))
                 raise ValueError(
-                    "The expected shape of pixel values in each batch element "
+                    "The expected shape of pixel values per image per batch "
                     f"is {expected_expr}. You supplied {tuple(d.shape)}.")

         for d in data:
@@ -556,30 +566,24 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal):
                 raise ValueError("Incorrect type of pixel values. "
                                  f"Got type: {type(pixel_values)}")

-            if not isinstance(image_sizes, torch.Tensor):
+            if not isinstance(image_sizes, (torch.Tensor, list)):
                 raise ValueError("Incorrect type of image sizes. "
                                  f"Got type: {type(image_sizes)}")

-            # Merge the B and N dimensions.
-            if isinstance(pixel_values, torch.Tensor):
-                pixel_values = pixel_values.flatten(0, 1)
-            else:
-                pixel_values = torch.cat(pixel_values)
-
-            image_sizes = image_sizes.flatten(0, 1)
-
             return Phi3VImagePixelInputs(
                 type="pixel_values",
-                data=self._validate_pixel_values(pixel_values),
-                image_sizes=self._validate_image_sizes(image_sizes))
+                data=self._validate_pixel_values(flatten_bn(pixel_values)),
+                image_sizes=self._validate_image_sizes(
+                    flatten_bn(image_sizes, concat=True)))

         if image_embeds is not None:
             if not isinstance(image_embeds, torch.Tensor):
                 raise ValueError("Incorrect type of image embeddings. "
                                  f"Got type: {type(image_embeds)}")
             return Phi3VImageEmbeddingInputs(
                 type="image_embeds",
-                data=image_embeds,
+                data=flatten_bn(image_embeds),
             )

         raise AssertionError("This line should be unreachable.")
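
The reworked `_validate_image_sizes` above checks every size entry individually rather than asserting one batch-level shape, so it works whether the sizes were concatenated from one prompt or many. A simplified stand-alone version of that check (not the exact vLLM method):

import torch


def validate_image_sizes(data: torch.Tensor) -> torch.Tensor:
    # Per-image check: each entry must be a (height, width) pair of shape (2,),
    # independent of how many images were flattened into the leading dimension.
    for d in data:
        if tuple(d.shape) != (2, ):
            raise ValueError(f"Expected shape (2,), got {tuple(d.shape)}")
    return data


sizes = torch.tensor([[336, 672], [448, 448], [1008, 336]])  # 3 images in total
validate_image_sizes(sizes)  # passes: every row has shape (2,)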

View File

@@ -49,7 +49,7 @@ logger = init_logger(__name__)
 class UltravoxAudioFeatureInputs(TypedDict):
     type: Literal["audio_features"]
     data: Union[torch.Tensor, List[torch.Tensor]]
-    """Shape: `(batch_size, 80, M)"""
+    """Shape: `(batch_size * num_audios, 80, M)"""


 class UltravoxAudioEmbeddingInputs(TypedDict):

View File

@@ -1,4 +1,5 @@
-from typing import Dict, Iterable, List, Optional, Protocol, Tuple
+from typing import (Dict, Iterable, List, Literal, Optional, Protocol, Tuple,
+                    Union, overload)

 import numpy as np
 import torch
@@ -55,6 +56,44 @@ def init_vllm_registered_model(
     )


+@overload
+def flatten_bn(x: torch.Tensor) -> torch.Tensor:
+    ...
+
+
+@overload
+def flatten_bn(x: List[torch.Tensor]) -> List[torch.Tensor]:
+    ...
+
+
+@overload
+def flatten_bn(
+    x: Union[List[torch.Tensor], torch.Tensor],
+    *,
+    concat: Literal[True],
+) -> torch.Tensor:
+    ...
+
+
+def flatten_bn(
+    x: Union[List[torch.Tensor], torch.Tensor],
+    *,
+    concat: bool = False,
+) -> Union[List[torch.Tensor], torch.Tensor]:
+    """
+    Flatten the ``B`` and ``N`` dimensions of batched multimodal inputs.
+
+    The input tensor should have shape ``(B, N, ...)```.
+    """
+    if isinstance(x, torch.Tensor):
+        return x.flatten(0, 1)
+
+    if concat:
+        return torch.cat(x)
+
+    return [x_n for x_b in x for x_n in x_b]
+
+
 def _flatten_embeddings(embeddings: NestedTensors) -> torch.Tensor:
     """
     Recursively concatenates NestedTensors along any heterogeneously sized
@@ -93,7 +132,8 @@ def merge_multimodal_embeddings(input_ids: torch.Tensor,
     This updates ``inputs_embeds`` in place.
     """
     mask = (input_ids == placeholder_token_id)
-    num_expected_tokens = mask.sum()
+    num_expected_tokens = mask.sum().item()
+    assert isinstance(num_expected_tokens, int)

     flattened = _flatten_embeddings(multimodal_embeddings)
     *dims, embed_dim = flattened.shape
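
`flatten_bn` centralizes the B x N flattening that each model previously did by hand: a batched tensor is flattened along its first two dimensions, a ragged list is flattened into a flat per-item list, and `concat=True` concatenates a list into one tensor. A quick usage sketch (illustrative shapes; the import path is assumed from the relative imports in the model files above):

import torch

from vllm.model_executor.models.utils import flatten_bn  # assumed module path

batched = torch.zeros(2, 3, 4)                   # (B=2, N=3, ...)
ragged = [torch.zeros(1, 4), torch.zeros(3, 4)]  # per-prompt item counts differ

assert flatten_bn(batched).shape == (6, 4)                  # B and N merged
assert [t.shape for t in flatten_bn(ragged)] == [(4,)] * 4  # flat list of items
assert flatten_bn(ragged, concat=True).shape == (4, 4)      # one concatenated tensor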

View File

@@ -18,7 +18,7 @@ from vllm.utils import JSONTree, is_list_of, json_map_leaves
 logger = init_logger(__name__)

-NestedTensors = Union[List["NestedTensors"], torch.Tensor]
+NestedTensors = Union[List["NestedTensors"], List[torch.Tensor], torch.Tensor]
 """
 Uses a list instead of a tensor if the dimensions of each element do not match.
 """
@@ -61,7 +61,7 @@ class MultiModalInputs(_MultiModalInputsBase):
         tensors_ = cast(List[torch.Tensor], stacked)
         if any(t.shape != tensors_[0].shape for t in tensors_):
             # The tensors have incompatible shapes and can't be stacked.
-            return stacked
+            return tensors_

         return torch.stack(tensors_)
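
Widening `NestedTensors` to include `List[torch.Tensor]` and returning `tensors_` keeps `MultiModalInputs.batch` type-consistent when per-item tensors cannot be stacked: callers receive a plain list of tensors, which is now a valid `NestedTensors`, instead of a stacked tensor. A small sketch of the behaviour this covers, using plain torch rather than the vLLM class:

from typing import List, Union

import torch

NestedTensors = Union[List["NestedTensors"], List[torch.Tensor], torch.Tensor]


def batch_one_key(tensors: List[torch.Tensor]) -> NestedTensors:
    # Mirrors the fixed logic: stack only when every shape matches; otherwise
    # hand back the list itself so downstream code can handle ragged inputs.
    if any(t.shape != tensors[0].shape for t in tensors):
        return tensors
    return torch.stack(tensors)


same = batch_one_key([torch.zeros(3, 4), torch.zeros(3, 4)])
assert isinstance(same, torch.Tensor) and same.shape == (2, 3, 4)

ragged = batch_one_key([torch.zeros(3, 4), torch.zeros(5, 4)])
assert isinstance(ragged, list) and len(ragged) == 2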