[Core][VLM] Stack multimodal tensors to represent multiple images within each prompt (#7902)

Peter Salas 2024-08-27 18:53:56 -07:00 committed by GitHub
parent 9c71c97ae2
commit fab5f53e2d
15 changed files with 214 additions and 60 deletions

View File

@@ -45,8 +45,6 @@ Base Classes
 .. autodata:: vllm.multimodal.NestedTensors
 
-.. autodata:: vllm.multimodal.BatchedTensors
-
 .. autodata:: vllm.multimodal.BatchedTensorInputs
 
 .. autoclass:: vllm.multimodal.MultiModalDataBuiltins

View File

@@ -0,0 +1,83 @@
+import torch
+
+from vllm.multimodal.base import MultiModalInputs, NestedTensors
+
+
+def assert_nested_tensors_equal(expected: NestedTensors,
+                                actual: NestedTensors):
+    assert type(expected) == type(actual)
+    if isinstance(expected, torch.Tensor):
+        assert torch.equal(expected, actual)
+    else:
+        for expected_item, actual_item in zip(expected, actual):
+            assert_nested_tensors_equal(expected_item, actual_item)
+
+
+def assert_multimodal_inputs_equal(expected: MultiModalInputs,
+                                   actual: MultiModalInputs):
+    assert set(expected.keys()) == set(actual.keys())
+    for key in expected:
+        assert_nested_tensors_equal(expected[key], actual[key])
+
+
+def test_multimodal_input_batch_single_tensor():
+    t = torch.rand([1, 2])
+    result = MultiModalInputs.batch([{"image": t}])
+    assert_multimodal_inputs_equal(result, {"image": t.unsqueeze(0)})
+
+
+def test_multimodal_input_batch_multiple_tensors():
+    a = torch.rand([1, 1, 2])
+    b = torch.rand([1, 1, 2])
+    c = torch.rand([1, 1, 2])
+    result = MultiModalInputs.batch([{"image": a}, {"image": b}, {"image": c}])
+    assert_multimodal_inputs_equal(result, {"image": torch.stack([a, b, c])})
+
+
+def test_multimodal_input_batch_multiple_heterogeneous_tensors():
+    a = torch.rand([1, 2, 2])
+    b = torch.rand([1, 3, 2])
+    c = torch.rand([1, 4, 2])
+    result = MultiModalInputs.batch([{"image": a}, {"image": b}, {"image": c}])
+    assert_multimodal_inputs_equal(result, {"image": [a, b, c]})
+
+
+def test_multimodal_input_batch_nested_tensors():
+    a = torch.rand([2, 3])
+    b = torch.rand([2, 3])
+    c = torch.rand([2, 3])
+    result = MultiModalInputs.batch([{
+        "image": [a]
+    }, {
+        "image": [b]
+    }, {
+        "image": [c]
+    }])
+    assert_multimodal_inputs_equal(result, {
+        "image":
+        torch.stack([a.unsqueeze(0),
+                     b.unsqueeze(0),
+                     c.unsqueeze(0)])
+    })
+
+
+def test_multimodal_input_batch_heterogeneous_lists():
+    a = torch.rand([1, 2, 3])
+    b = torch.rand([1, 2, 3])
+    c = torch.rand([1, 2, 3])
+    result = MultiModalInputs.batch([{"image": [a, b]}, {"image": [c]}])
+    assert_multimodal_inputs_equal(
+        result,
+        {"image": [torch.stack([a, b]), c.unsqueeze(0)]})
+
+
+def test_multimodal_input_batch_multiple_batchable_lists():
+    a = torch.rand([1, 2, 3])
+    b = torch.rand([1, 2, 3])
+    c = torch.rand([1, 2, 3])
+    d = torch.rand([1, 2, 3])
+    result = MultiModalInputs.batch([{"image": [a, b]}, {"image": [c, d]}])
+    assert_multimodal_inputs_equal(
+        result,
+        {"image": torch.stack([torch.stack([a, b]),
+                               torch.stack([c, d])])})
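
The behavior these tests pin down can be summarized by a small standalone sketch (independent of vLLM; try_stack is an illustrative name, not the library API): per-key items from each prompt are stacked into one tensor only when every entry ends up with the same shape, otherwise the batch stays a nested list.

    from typing import List, Union

    import torch

    NestedTensors = Union[List["NestedTensors"], torch.Tensor]

    def try_stack(nested: NestedTensors) -> NestedTensors:
        if isinstance(nested, torch.Tensor):
            return nested
        stacked = [try_stack(item) for item in nested]
        if any(isinstance(item, list) for item in stacked):
            return stacked  # a child could not be stacked, so neither can we
        if any(t.shape != stacked[0].shape for t in stacked):
            return stacked  # shapes differ across prompts: keep the list
        return torch.stack(stacked)

    print(try_stack([torch.rand(1, 2), torch.rand(1, 2)]).shape)  # torch.Size([2, 1, 2])
    print(type(try_stack([torch.rand(1, 2), torch.rand(1, 3)])))  # <class 'list'>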

View File

@@ -555,6 +555,9 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal):
                 raise ValueError("Incorrect type of pixel values. "
                                  f"Got type: {type(pixel_values)}")
 
+            # Remove the N dimension until multiple images are supported.
+            pixel_values = pixel_values.squeeze(1)
+
             return Blip2ImagePixelInputs(
                 type="pixel_values",
                 data=self._validate_pixel_values(pixel_values),
@@ -564,6 +567,10 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal):
            if not isinstance(image_embeds, torch.Tensor):
                 raise ValueError("Incorrect type of image embeddings. "
                                  f"Got type: {type(image_embeds)}")
+
+            # Remove the N dimension until multiple images are supported.
+            image_embeds = image_embeds.squeeze(1)
+
             return Blip2ImageEmbeddingInputs(
                 type="image_embeds",
                 data=image_embeds,
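
Here and in the Chameleon, Fuyu, LLaVA, and PaliGemma hunks below the pattern is the same: the input mapper now emits a leading N axis (images per prompt), so after batching a single-image model receives (B, N, ...) with N == 1 and squeezes it back out. A toy illustration with made-up shapes:

    import torch

    # Two prompts, one image each, hypothetical 3x224x224 pixel values.
    per_prompt = torch.rand(1, 3, 224, 224)           # (N, C, H, W) from the mapper
    batched = torch.stack([per_prompt, per_prompt])   # (B, N, C, H, W) after MultiModalInputs.batch
    pixel_values = batched.squeeze(1)                 # (B, C, H, W), the layout the model expects
    print(pixel_values.shape)                         # torch.Size([2, 3, 224, 224])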

View File

@@ -946,6 +946,9 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal):
                 raise ValueError("Incorrect type of pixel values. "
                                  f"Got type: {type(pixel_values)}")
 
+            # Remove the N dimension until multiple images are supported.
+            pixel_values = pixel_values.squeeze(1)
+
             return ChameleonImagePixelInputs(
                 type="pixel_values",
                 data=self._validate_pixel_values(pixel_values),

View File

@@ -249,6 +249,9 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal):
         image_patches = kwargs.pop("image_patches", None)
 
         if isinstance(image_patches, torch.Tensor):
+            # Remove the N dimension until multiple images are supported.
+            image_patches = image_patches.squeeze(1)
+
             expected_feature_size = self.image_feature_size
             if image_patches.size(-1) != expected_feature_size:
                 raise ValueError(

View File

@@ -244,6 +244,8 @@ def input_mapper_for_internvl(ctx: InputContext, data: object):
                                           min_num,
                                           max_num,
                                           use_thumbnail=use_thumbnail)
+        # Add an N dimension for number of images per prompt (currently 1).
+        data = data.unsqueeze(0)
     model_config = ctx.model_config
     tokenizer = cached_get_tokenizer(model_config.tokenizer,
                                      trust_remote_code=True)
@@ -410,6 +412,10 @@ class InternVLChatModel(nn.Module, SupportsMultiModal):
            if not isinstance(image_embeds, torch.Tensor):
                 raise ValueError("Incorrect type of image embeddings. "
                                  f"Got type: {type(image_embeds)}")
+
+            # Flatten the B and N dimensions
+            image_embeds = image_embeds.flatten(0, 2)
+
             return InternVLImageEmbeddingInputs(
                 type="image_embeds",
                 data=image_embeds,
@@ -422,6 +428,9 @@ class InternVLChatModel(nn.Module, SupportsMultiModal):
                 raise ValueError("Incorrect type of pixel values. "
                                  f"Got type: {type(pixel_values)}")
 
+            # Flatten the B and N dimensions
+            pixel_values = pixel_values.flatten(0, 2)
+
             return InternVLImagePixelInputs(
                 type="pixel_values",
                 data=self._validate_pixel_values(pixel_values),
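
For tiled inputs like InternVL's, the mapper adds the N axis with unsqueeze(0) and the model folds the leading axes back down before the vision encoder. A rough sketch with assumed shapes (the tile count and 448x448 resolution are illustrative only):

    import torch

    per_prompt = torch.rand(7, 3, 448, 448)        # (tiles, C, H, W) from the image mapper
    with_n = per_prompt.unsqueeze(0)               # (N=1, tiles, C, H, W)
    batched = torch.stack([with_n, with_n])        # (B=2, N, tiles, C, H, W) after batching
    flat = batched.flatten(0, 2)                   # fold B, N, and tiles into one axis
    print(flat.shape)                              # torch.Size([14, 3, 448, 448])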

View File

@@ -232,6 +232,10 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal):
            if not isinstance(pixel_values, torch.Tensor):
                 raise ValueError("Incorrect type of pixel values. "
                                  f"Got type: {type(pixel_values)}")
+
+            # Remove the N dimension until multiple images are supported.
+            pixel_values = pixel_values.squeeze(1)
+
             return LlavaImagePixelInputs(
                 type="pixel_values",
                 data=self._validate_pixel_values(pixel_values),
@@ -241,6 +245,10 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal):
            if not isinstance(image_embeds, torch.Tensor):
                 raise ValueError("Incorrect type of image embeddings. "
                                  f"Got type: {type(image_embeds)}")
+
+            # Remove the N dimension until multiple images are supported.
+            image_embeds = image_embeds.squeeze(1)
+
             return LlavaImageEmbeddingInputs(
                 type="image_embeds",
                 data=image_embeds,

View File

@@ -361,6 +361,14 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal):
                 raise ValueError("Incorrect type of image sizes. "
                                  f"Got type: {type(image_sizes)}")
 
+            # Remove the N dimension until multiple images are supported.
+            if isinstance(pixel_values, torch.Tensor):
+                pixel_values = pixel_values.squeeze(1)
+            else:
+                pixel_values = [t.squeeze(0) for t in pixel_values]
+
+            image_sizes = image_sizes.squeeze(1)
+
             return LlavaNextImagePixelInputs(
                 type="pixel_values",
                 data=self._validate_pixel_values(pixel_values),
@@ -372,6 +380,9 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal):
            if not isinstance(image_embeds, torch.Tensor):
                 raise ValueError("Incorrect type of image embeds. "
                                  f"Got type: {type(image_embeds)}")
+
+            # Remove the N dimension until multiple images are supported.
+            image_embeds = image_embeds.squeeze(1)
             return LlavaNextImageEmbeddingInputs(
                 type="image_embeds",
                 data=image_embeds,

View File

@@ -594,9 +594,14 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal):
         pixel_values_flat: List[torch.Tensor] = []
         tgt_sizes_flat: List[torch.Tensor] = []
-        for b in range(len(pixel_values)):
-            pixel_values_flat += pixel_values[b]
-            tgt_sizes_flat += tgt_sizes[b]
+        for pixel_b, tgt_b in zip(pixel_values, tgt_sizes):
+            if len(pixel_b) != len(tgt_b):
+                raise ValueError("Inconsistent N lengths, found: "
+                                 f"{len(pixel_b)} vs {len(tgt_b)}")
+
+            for pixel_n, tgt_n in zip(pixel_b, tgt_b):
+                pixel_values_flat += pixel_n
+                tgt_sizes_flat += tgt_n
 
         # NOTE: Input IDs does not contain image tokens during memory profiling,
         # so we allow it to be empty
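
The MiniCPM-V change generalizes the old per-prompt flattening to the new doubly nested layout implied by the loop (prompts, then images, then per-image slice tensors), validating that the pixel and target-size lists stay in step. A minimal sketch with invented slice shapes and sizes:

    import torch

    # Hypothetical nesting: prompt 0 has one image, prompt 1 has two.
    pixel_values = [[[torch.rand(3, 14, 14)]],
                    [[torch.rand(3, 14, 14)], [torch.rand(3, 14, 14)]]]
    tgt_sizes = [[torch.tensor([[1, 1]])],
                 [torch.tensor([[1, 1]]), torch.tensor([[1, 1]])]]

    pixel_values_flat, tgt_sizes_flat = [], []
    for pixel_b, tgt_b in zip(pixel_values, tgt_sizes):
        if len(pixel_b) != len(tgt_b):
            raise ValueError(f"Inconsistent N lengths: {len(pixel_b)} vs {len(tgt_b)}")
        for pixel_n, tgt_n in zip(pixel_b, tgt_b):
            pixel_values_flat += pixel_n   # extend with this image's slice tensors
            tgt_sizes_flat += tgt_n        # extend with the matching size rows
    print(len(pixel_values_flat), len(tgt_sizes_flat))   # 3 3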

View File

@@ -185,6 +185,10 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal):
            if not isinstance(pixel_values, torch.Tensor):
                 raise ValueError("Incorrect type of pixel values. "
                                  f"Got type: {type(pixel_values)}")
+
+            # Remove the N dimension until multiple images are supported.
+            pixel_values = pixel_values.squeeze(1)
+
             return PaliGemmaImagePixelInputs(
                 type="pixel_values",
                 data=self._validate_pixel_values(pixel_values),
@@ -194,6 +198,10 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal):
            if not isinstance(image_embeds, torch.Tensor):
                 raise ValueError("Incorrect type of image embeddings. "
                                  f"Got type: {type(image_embeds)}")
+
+            # Remove the N dimension until multiple images are supported.
+            image_embeds = image_embeds.squeeze(1)
+
             return PaliGemmaImageEmbeddingInputs(
                 type="image_embeds",
                 data=image_embeds,

View File

@@ -560,6 +560,14 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal):
                 raise ValueError("Incorrect type of image sizes. "
                                  f"Got type: {type(image_sizes)}")
 
+            # Merge the B and N dimensions.
+            if isinstance(pixel_values, torch.Tensor):
+                pixel_values = pixel_values.flatten(0, 1)
+            else:
+                pixel_values = torch.cat(pixel_values)
+
+            image_sizes = image_sizes.flatten(0, 1)
+
             return Phi3VImagePixelInputs(
                 type="pixel_values",
                 data=self._validate_pixel_values(pixel_values),
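
Phi-3-Vision merges the B and N axes instead of squeezing N away. A sketch with assumed shapes (the patch count and 336x336 resolution are placeholders):

    import torch

    # Homogeneous case: 2 prompts x 3 images -> flatten the first two axes.
    pixel_values = torch.rand(2, 3, 5, 3, 336, 336)       # (B, N, patches, C, H, W)
    print(pixel_values.flatten(0, 1).shape)               # torch.Size([6, 5, 3, 336, 336])

    # Heterogeneous case: per-prompt tensors with differing N stay in a list,
    # so concatenating along dim 0 plays the role of flatten(0, 1).
    ragged = [torch.rand(2, 5, 3, 336, 336), torch.rand(1, 5, 3, 336, 336)]
    print(torch.cat(ragged).shape)                        # torch.Size([3, 5, 3, 336, 336])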

View File

@@ -333,6 +333,12 @@ class UltravoxModel(nn.Module, SupportsMultiModal):
                 raise ValueError("Incorrect type of audio features. "
                                  f"Got type: {type(audio_features)}")
 
+            # Remove the N dimension until multiple audios are supported.
+            if isinstance(audio_features, torch.Tensor):
+                audio_features = audio_features.squeeze(1)
+            else:
+                audio_features = [t.squeeze(0) for t in audio_features]
+
             return UltravoxAudioFeatureInputs(type="audio_features",
                                               data=audio_features)
@@ -341,6 +347,9 @@ class UltravoxModel(nn.Module, SupportsMultiModal):
                 raise ValueError("Incorrect type of audio embeds. "
                                  f"Got type: {type(audio_embeds)}")
+
+            # Remove the N dimension until multiple audios are supported.
+            audio_embeds = audio_embeds.squeeze(1)
             return UltravoxAudioEmbeddingInputs(type="audio_embeds",
                                                 data=audio_embeds)
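
When per-prompt features have different lengths, batching leaves them as a list with one (N=1, ...) tensor per prompt, so the list branch squeezes dim 0 rather than dim 1 (the same handling appears in the LLaVA-NeXT hunk above). Illustrative feature shapes:

    import torch

    audio_features = [torch.rand(1, 80, 3000), torch.rand(1, 80, 1500)]   # one entry per prompt
    audio_features = [t.squeeze(0) for t in audio_features]               # drop the N=1 axis
    print([tuple(t.shape) for t in audio_features])                       # [(80, 3000), (80, 1500)]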

View File

@@ -1,5 +1,6 @@
 from typing import Dict, Iterable, List, Optional, Protocol, Tuple
 
+import numpy as np
 import torch
 import torch.nn as nn
 from torch.func import functional_call
@@ -10,7 +11,7 @@ from vllm.config import (CacheConfig, LoRAConfig, MultiModalConfig,
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.loader import build_model
 from vllm.model_executor.models import ModelRegistry
-from vllm.multimodal import BatchedTensors
+from vllm.multimodal.base import NestedTensors
 from vllm.utils import is_pin_memory_available
@@ -54,9 +55,34 @@ def init_vllm_registered_model(
     )
 
 
+def _flatten_embeddings(embeddings: NestedTensors) -> torch.Tensor:
+    """
+    Recursively concatenates NestedTensors along any heterogeneously sized
+    dimensions.
+    """
+
+    if isinstance(embeddings, torch.Tensor):
+        return embeddings
+
+    return torch.cat(tuple(_flatten_embeddings(t) for t in embeddings))
+
+
+def _embedding_count_expression(embeddings: NestedTensors) -> str:
+    """
+    Constructs a debugging representation of the number of embeddings in the
+    NestedTensors.
+    """
+
+    if isinstance(embeddings, torch.Tensor):
+        return " x ".join([str(dim) for dim in embeddings.shape[:-1]])
+
+    return " + ".join(
+        _embedding_count_expression(inner) for inner in embeddings)
+
+
 def merge_multimodal_embeddings(input_ids: torch.Tensor,
                                 inputs_embeds: torch.Tensor,
-                                multimodal_embeddings: BatchedTensors,
+                                multimodal_embeddings: NestedTensors,
                                 placeholder_token_id: int) -> torch.Tensor:
     """
     Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
@@ -69,28 +95,16 @@ def merge_multimodal_embeddings(input_ids: torch.Tensor,
     mask = (input_ids == placeholder_token_id)
     num_expected_tokens = mask.sum()
 
-    if isinstance(multimodal_embeddings, torch.Tensor):
-        batch_size, batch_tokens, *_, embed_dim = multimodal_embeddings.shape
-        total_tokens = batch_size * batch_tokens
-        if num_expected_tokens != total_tokens:
-            expr = f"{batch_size} x {batch_tokens}"
-            raise ValueError(
-                f"Attempted to assign {expr} = {total_tokens} "
-                f"multimodal tokens to {num_expected_tokens} placeholders")
-
-        inputs_embeds[mask] = multimodal_embeddings.view(
-            total_tokens, embed_dim)
-    else:
-        size_per_batch = [t.shape[0] for t in multimodal_embeddings]
-        total_tokens = sum(size_per_batch)
-        if num_expected_tokens != total_tokens:
-            expr = ' + '.join(map(str, size_per_batch))
-            raise ValueError(
-                f"Attempted to assign {expr} = {total_tokens} "
-                f"multimodal tokens to {num_expected_tokens} placeholders")
-
-        inputs_embeds[mask] = torch.cat(multimodal_embeddings)
+    flattened = _flatten_embeddings(multimodal_embeddings)
+    *dims, embed_dim = flattened.shape
+    num_multimodal_embeddings = np.prod(dims)
+    if num_multimodal_embeddings != num_expected_tokens:
+        expr = _embedding_count_expression(multimodal_embeddings)
+        raise ValueError(
+            f"Attempted to assign {expr} = {num_multimodal_embeddings} "
+            f"multimodal tokens to {num_expected_tokens} placeholders")
 
+    inputs_embeds[mask] = flattened.view(num_expected_tokens, embed_dim)
     return inputs_embeds
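
To see what the new helpers buy, here is a toy run-through (shapes invented) of the flatten-count-scatter logic that replaces the old two-branch code; on this flat list, torch.cat and the " + " join stand in for what _flatten_embeddings and _embedding_count_expression would produce:

    import numpy as np
    import torch

    nested = [torch.rand(4, 8), torch.rand(2, 8)]      # two inputs: 4 and 2 embeddings of dim 8
    flattened = torch.cat(nested)                      # (6, 8), as _flatten_embeddings would return
    *dims, embed_dim = flattened.shape
    num_embeddings = int(np.prod(dims))                # 6
    expr = " + ".join(str(t.shape[0]) for t in nested)
    print(f"{expr} = {num_embeddings}")                # 4 + 2 = 6

    input_ids = torch.tensor([1, 0, 0, 0, 0, 2, 0, 0, 3])            # 0 marks an image placeholder
    inputs_embeds = torch.zeros(input_ids.shape[0], embed_dim)
    mask = input_ids == 0
    assert mask.sum() == num_embeddings                               # otherwise raise, as in the diff
    inputs_embeds[mask] = flattened.view(num_embeddings, embed_dim)   # scatter into placeholder rows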

View File

@@ -1,4 +1,4 @@
-from .base import (BatchedTensorInputs, BatchedTensors, MultiModalDataBuiltins,
+from .base import (BatchedTensorInputs, MultiModalDataBuiltins,
                    MultiModalDataDict, MultiModalInputs, MultiModalPlugin,
                    NestedTensors)
 from .registry import MultiModalRegistry
@@ -14,7 +14,6 @@ See also:
 __all__ = [
     "BatchedTensorInputs",
-    "BatchedTensors",
     "MultiModalDataBuiltins",
     "MultiModalDataDict",
     "MultiModalInputs",

View File

@@ -1,9 +1,8 @@
 import sys
 from abc import ABC, abstractmethod
 from collections import UserDict, defaultdict
-from typing import Callable, Dict, List, Mapping, Optional
-from typing import Sequence as GenericSequence
-from typing import Tuple, Type, TypedDict, TypeVar, Union, cast, final
+from typing import (Callable, Dict, List, Mapping, Optional, Tuple, Type,
+                    TypedDict, TypeVar, Union, cast, final)
 
 import numpy as np
 import torch
@@ -15,23 +14,16 @@ from typing_extensions import TypeAlias
 from vllm.config import ModelConfig
 from vllm.inputs import InputContext
 from vllm.logger import init_logger
-from vllm.utils import JSONTree, json_map_leaves
+from vllm.utils import json_map_leaves
 
 logger = init_logger(__name__)
 
-NestedTensors = Union[GenericSequence[torch.Tensor], torch.Tensor]
+NestedTensors = Union[List["NestedTensors"], torch.Tensor]
 """
-Use a list instead of a tensor if the dimensions of each element do not match.
-Currently only supports up to singly nested list of tensors.
+Uses a list instead of a tensor if the dimensions of each element do not match.
 """
 
-BatchedTensors: TypeAlias = JSONTree[torch.Tensor]
-"""
-A nested JSON structure of tensors which have been batched via
-:meth:`MultiModalInputs.batch`.
-"""
-
-BatchedTensorInputs: TypeAlias = Dict[str, JSONTree[torch.Tensor]]
+BatchedTensorInputs: TypeAlias = Dict[str, NestedTensors]
 """
 A dictionary containing nested tensors which have been batched via
 :meth:`MultiModalInputs.batch`.
@@ -54,26 +46,23 @@ class MultiModalInputs(_MultiModalInputsBase):
     """
 
     @staticmethod
-    def _try_concat(tensors: List[NestedTensors]) -> BatchedTensors:
+    def _try_stack(nested_tensors: NestedTensors) -> NestedTensors:
         """
-        If each input tensor in the batch has the same shape, return a single
-        batched tensor; otherwise, return a list of :class:`NestedTensors` with
-        one element per item in the batch.
+        Recursively stacks lists of tensors when they all have the same shape.
         """
-        # may be list rather than tensors
-        if isinstance(tensors[0], list):
-            return [[t for t in tensor[0]]
-                    for tensor in cast(List[List[torch.Tensor]], tensors)]
+        if isinstance(nested_tensors, torch.Tensor):
+            return nested_tensors
 
-        tensors_ = cast(List[torch.Tensor], tensors)
+        stacked = [MultiModalInputs._try_stack(t) for t in nested_tensors]
+        if any(isinstance(t, list) for t in stacked):
+            return stacked
 
-        unbatched_shape = tensors_[0].shape[1:]
+        tensors_ = cast(List[torch.Tensor], stacked)
+        if any(t.shape != tensors_[0].shape for t in tensors_):
+            # The tensors have incompatible shapes and can't be stacked.
+            return tensors_
 
-        for tensor in tensors_:
-            if tensor.shape[1:] != unbatched_shape:
-                return [tensor.squeeze(0) for tensor in tensors_]
-
-        return torch.cat(tensors_, dim=0)
+        return torch.stack(tensors_)
 
     @staticmethod
     def batch(inputs_list: List["MultiModalInputs"]) -> BatchedTensorInputs:
@@ -102,7 +91,7 @@ class MultiModalInputs(_MultiModalInputsBase):
             item_lists[k].append(v)
 
         return {
-            k: MultiModalInputs._try_concat(item_list)
+            k: MultiModalInputs._try_stack(item_list)
            for k, item_list in item_lists.items()
        }