[Core][VLM] Stack multimodal tensors to represent multiple images within each prompt (#7902)
This commit is contained in:
parent 9c71c97ae2 · commit fab5f53e2d
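For orientation, the change threads an explicit N dimension (number of multimodal items per prompt) through the input pipeline, so batched inputs become either a stacked (B, N, ...) tensor or a nested list when shapes differ. The snippet below is an illustrative sketch only; the shapes and dummy data are assumptions, not taken from the commit.

import torch

# One prompt with two images, another with one: each prompt keeps an
# explicit N dimension, i.e. per-prompt pixel values have shape (N, C, H, W).
prompt_a = torch.rand(2, 3, 224, 224)  # N = 2
prompt_b = torch.rand(1, 3, 224, 224)  # N = 1

# Prompts with equal N can be stacked into a (B, N, C, H, W) batch ...
stacked = torch.stack([prompt_a, prompt_a])
print(stacked.shape)  # torch.Size([2, 2, 3, 224, 224])

# ... while unequal N has to stay a list of per-prompt tensors.
ragged = [prompt_a, prompt_b]
print([t.shape for t in ragged])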
@@ -45,8 +45,6 @@ Base Classes
 .. autodata:: vllm.multimodal.NestedTensors
 
-.. autodata:: vllm.multimodal.BatchedTensors
-
 .. autodata:: vllm.multimodal.BatchedTensorInputs
 
 .. autoclass:: vllm.multimodal.MultiModalDataBuiltins
tests/multimodal/test_base.py · new file (+83 lines):

import torch

from vllm.multimodal.base import MultiModalInputs, NestedTensors


def assert_nested_tensors_equal(expected: NestedTensors,
                                actual: NestedTensors):
    assert type(expected) == type(actual)
    if isinstance(expected, torch.Tensor):
        assert torch.equal(expected, actual)
    else:
        for expected_item, actual_item in zip(expected, actual):
            assert_nested_tensors_equal(expected_item, actual_item)


def assert_multimodal_inputs_equal(expected: MultiModalInputs,
                                   actual: MultiModalInputs):
    assert set(expected.keys()) == set(actual.keys())
    for key in expected:
        assert_nested_tensors_equal(expected[key], actual[key])


def test_multimodal_input_batch_single_tensor():
    t = torch.rand([1, 2])
    result = MultiModalInputs.batch([{"image": t}])
    assert_multimodal_inputs_equal(result, {"image": t.unsqueeze(0)})


def test_multimodal_input_batch_multiple_tensors():
    a = torch.rand([1, 1, 2])
    b = torch.rand([1, 1, 2])
    c = torch.rand([1, 1, 2])
    result = MultiModalInputs.batch([{"image": a}, {"image": b}, {"image": c}])
    assert_multimodal_inputs_equal(result, {"image": torch.stack([a, b, c])})


def test_multimodal_input_batch_multiple_heterogeneous_tensors():
    a = torch.rand([1, 2, 2])
    b = torch.rand([1, 3, 2])
    c = torch.rand([1, 4, 2])
    result = MultiModalInputs.batch([{"image": a}, {"image": b}, {"image": c}])
    assert_multimodal_inputs_equal(result, {"image": [a, b, c]})


def test_multimodal_input_batch_nested_tensors():
    a = torch.rand([2, 3])
    b = torch.rand([2, 3])
    c = torch.rand([2, 3])
    result = MultiModalInputs.batch([{
        "image": [a]
    }, {
        "image": [b]
    }, {
        "image": [c]
    }])
    assert_multimodal_inputs_equal(result, {
        "image":
        torch.stack([a.unsqueeze(0),
                     b.unsqueeze(0),
                     c.unsqueeze(0)])
    })


def test_multimodal_input_batch_heterogeneous_lists():
    a = torch.rand([1, 2, 3])
    b = torch.rand([1, 2, 3])
    c = torch.rand([1, 2, 3])
    result = MultiModalInputs.batch([{"image": [a, b]}, {"image": [c]}])
    assert_multimodal_inputs_equal(
        result,
        {"image": [torch.stack([a, b]), c.unsqueeze(0)]})


def test_multimodal_input_batch_multiple_batchable_lists():
    a = torch.rand([1, 2, 3])
    b = torch.rand([1, 2, 3])
    c = torch.rand([1, 2, 3])
    d = torch.rand([1, 2, 3])
    result = MultiModalInputs.batch([{"image": [a, b]}, {"image": [c, d]}])
    assert_multimodal_inputs_equal(
        result,
        {"image": torch.stack([torch.stack([a, b]),
                               torch.stack([c, d])])})
@@ -555,6 +555,9 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal):
                 raise ValueError("Incorrect type of pixel values. "
                                  f"Got type: {type(pixel_values)}")
 
+            # Remove the N dimension until multiple images are supported.
+            pixel_values = pixel_values.squeeze(1)
+
             return Blip2ImagePixelInputs(
                 type="pixel_values",
                 data=self._validate_pixel_values(pixel_values),
@@ -564,6 +567,10 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal):
             if not isinstance(image_embeds, torch.Tensor):
                 raise ValueError("Incorrect type of image embeddings. "
                                  f"Got type: {type(image_embeds)}")
+
+            # Remove the N dimension until multiple images are supported.
+            image_embeds = image_embeds.squeeze(1)
+
             return Blip2ImageEmbeddingInputs(
                 type="image_embeds",
                 data=image_embeds,
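The `squeeze(1)` lines added above drop that N dimension again inside models that still accept only one image per prompt. A minimal shape sketch, with sizes invented for illustration:

import torch

# Batched pixel values now arrive as (B, N, C, H, W) with N == 1.
pixel_values = torch.rand(4, 1, 3, 224, 224)

# Removing the N dimension restores the (B, C, H, W) layout the vision
# encoder expects until multi-image support lands.
pixel_values = pixel_values.squeeze(1)
print(pixel_values.shape)  # torch.Size([4, 3, 224, 224])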
@@ -946,6 +946,9 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal):
                 raise ValueError("Incorrect type of pixel values. "
                                  f"Got type: {type(pixel_values)}")
 
+            # Remove the N dimension until multiple images are supported.
+            pixel_values = pixel_values.squeeze(1)
+
             return ChameleonImagePixelInputs(
                 type="pixel_values",
                 data=self._validate_pixel_values(pixel_values),
@@ -249,6 +249,9 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal):
         image_patches = kwargs.pop("image_patches", None)
 
         if isinstance(image_patches, torch.Tensor):
+            # Remove the N dimension until multiple images are supported.
+            image_patches = image_patches.squeeze(1)
+
             expected_feature_size = self.image_feature_size
             if image_patches.size(-1) != expected_feature_size:
                 raise ValueError(
@@ -244,6 +244,8 @@ def input_mapper_for_internvl(ctx: InputContext, data: object):
                                       min_num,
                                       max_num,
                                       use_thumbnail=use_thumbnail)
+        # Add an N dimension for number of images per prompt (currently 1).
+        data = data.unsqueeze(0)
     model_config = ctx.model_config
     tokenizer = cached_get_tokenizer(model_config.tokenizer,
                                      trust_remote_code=True)
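Conversely, the input mapper adds the N dimension on the way in. A minimal sketch of what `data.unsqueeze(0)` does to the mapped tensor; the patch count and resolution here are assumptions:

import torch

# Pixel values for a single tiled image, e.g. (num_patches, C, H, W).
data = torch.rand(7, 3, 448, 448)

# Add an N dimension for the number of images per prompt (currently 1),
# giving (N, num_patches, C, H, W).
data = data.unsqueeze(0)
print(data.shape)  # torch.Size([1, 7, 3, 448, 448])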
@@ -410,6 +412,10 @@ class InternVLChatModel(nn.Module, SupportsMultiModal):
             if not isinstance(image_embeds, torch.Tensor):
                 raise ValueError("Incorrect type of image embeddings. "
                                  f"Got type: {type(image_embeds)}")
+
+            # Flatten the B and N dimensions
+            image_embeds = image_embeds.flatten(0, 2)
+
             return InternVLImageEmbeddingInputs(
                 type="image_embeds",
                 data=image_embeds,
@@ -422,6 +428,9 @@ class InternVLChatModel(nn.Module, SupportsMultiModal):
                 raise ValueError("Incorrect type of pixel values. "
                                  f"Got type: {type(pixel_values)}")
 
+            # Flatten the B and N dimensions
+            pixel_values = pixel_values.flatten(0, 2)
+
             return InternVLImagePixelInputs(
                 type="pixel_values",
                 data=self._validate_pixel_values(pixel_values),
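`Tensor.flatten(start_dim, end_dim)` merges a contiguous range of dimensions, so `flatten(0, 2)` collapses the first three dimensions into one. A shape-only sketch; treating those dimensions as batch, image count, and patch count is an assumption based on the comments above:

import torch

# Assume pixel values shaped (B, N, num_patches, C, H, W).
pixel_values = torch.rand(2, 1, 7, 3, 448, 448)

# flatten(0, 2) merges dims 0..2: (B * N * num_patches, C, H, W).
flat = pixel_values.flatten(0, 2)
print(flat.shape)  # torch.Size([14, 3, 448, 448])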
@@ -232,6 +232,10 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal):
             if not isinstance(pixel_values, torch.Tensor):
                 raise ValueError("Incorrect type of pixel values. "
                                  f"Got type: {type(pixel_values)}")
+
+            # Remove the N dimension until multiple images are supported.
+            pixel_values = pixel_values.squeeze(1)
+
             return LlavaImagePixelInputs(
                 type="pixel_values",
                 data=self._validate_pixel_values(pixel_values),
@@ -241,6 +245,10 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal):
             if not isinstance(image_embeds, torch.Tensor):
                 raise ValueError("Incorrect type of image embeddings. "
                                  f"Got type: {type(image_embeds)}")
+
+            # Remove the N dimension until multiple images are supported.
+            image_embeds = image_embeds.squeeze(1)
+
             return LlavaImageEmbeddingInputs(
                 type="image_embeds",
                 data=image_embeds,
@@ -361,6 +361,14 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal):
                 raise ValueError("Incorrect type of image sizes. "
                                  f"Got type: {type(image_sizes)}")
 
+            # Remove the N dimension until multiple images are supported.
+            if isinstance(pixel_values, torch.Tensor):
+                pixel_values = pixel_values.squeeze(1)
+            else:
+                pixel_values = [t.squeeze(0) for t in pixel_values]
+
+            image_sizes = image_sizes.squeeze(1)
+
             return LlavaNextImagePixelInputs(
                 type="pixel_values",
                 data=self._validate_pixel_values(pixel_values),
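Because LLaVA-NeXT images can yield different patch counts, the batched value may be either a stacked tensor or a list of per-prompt tensors, and the two branches above drop the N dimension differently. A sketch of both cases using a hypothetical `drop_n_dim` helper and invented shapes:

from typing import List, Union

import torch


def drop_n_dim(
    pixel_values: Union[torch.Tensor, List[torch.Tensor]]
) -> Union[torch.Tensor, List[torch.Tensor]]:
    # Stacked case: (B, N, patches, C, H, W) -> (B, patches, C, H, W).
    if isinstance(pixel_values, torch.Tensor):
        return pixel_values.squeeze(1)
    # List case: each element is (N, patches_i, C, H, W) with N == 1.
    return [t.squeeze(0) for t in pixel_values]


same_patches = torch.rand(2, 1, 5, 3, 336, 336)
print(drop_n_dim(same_patches).shape)  # torch.Size([2, 5, 3, 336, 336])

mixed_patches = [torch.rand(1, 5, 3, 336, 336), torch.rand(1, 9, 3, 336, 336)]
print([t.shape for t in drop_n_dim(mixed_patches)])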
@@ -372,6 +380,9 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal):
                 raise ValueError("Incorrect type of image embeds. "
                                  f"Got type: {type(image_embeds)}")
 
+            # Remove the N dimension until multiple images are supported.
+            image_embeds = image_embeds.squeeze(1)
+
             return LlavaNextImageEmbeddingInputs(
                 type="image_embeds",
                 data=image_embeds,
@@ -594,9 +594,14 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal):
 
         pixel_values_flat: List[torch.Tensor] = []
         tgt_sizes_flat: List[torch.Tensor] = []
-        for b in range(len(pixel_values)):
-            pixel_values_flat += pixel_values[b]
-            tgt_sizes_flat += tgt_sizes[b]
+        for pixel_b, tgt_b in zip(pixel_values, tgt_sizes):
+            if len(pixel_b) != len(tgt_b):
+                raise ValueError("Inconsistent N lengths, found: "
+                                 f"{len(pixel_b)} vs {len(tgt_b)}")
+
+            for pixel_n, tgt_n in zip(pixel_b, tgt_b):
+                pixel_values_flat += pixel_n
+                tgt_sizes_flat += tgt_n
 
         # NOTE: Input IDs does not contain image tokens during memory profiling,
         # so we allow it to be empty
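The rewritten loop walks both the batch level (B) and the per-prompt image level (N), flattening the innermost slice tensors into one list and rejecting mismatched lengths early. A standalone sketch with toy data; the list structure and shapes are assumptions:

import torch

# pixel_values[b][n] is a list of slice tensors for image n of prompt b;
# tgt_sizes mirrors that structure.
pixel_values = [[[torch.rand(3, 14, 14)], [torch.rand(3, 14, 14)]]]
tgt_sizes = [[[torch.tensor([14, 14])], [torch.tensor([14, 14])]]]

pixel_values_flat = []
tgt_sizes_flat = []
for pixel_b, tgt_b in zip(pixel_values, tgt_sizes):
    if len(pixel_b) != len(tgt_b):
        raise ValueError("Inconsistent N lengths, found: "
                         f"{len(pixel_b)} vs {len(tgt_b)}")
    for pixel_n, tgt_n in zip(pixel_b, tgt_b):
        pixel_values_flat += pixel_n
        tgt_sizes_flat += tgt_n

print(len(pixel_values_flat), len(tgt_sizes_flat))  # 2 2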
@@ -185,6 +185,10 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal):
             if not isinstance(pixel_values, torch.Tensor):
                 raise ValueError("Incorrect type of pixel values. "
                                  f"Got type: {type(pixel_values)}")
+
+            # Remove the N dimension until multiple images are supported.
+            pixel_values = pixel_values.squeeze(1)
+
             return PaliGemmaImagePixelInputs(
                 type="pixel_values",
                 data=self._validate_pixel_values(pixel_values),
@@ -194,6 +198,10 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal):
             if not isinstance(image_embeds, torch.Tensor):
                 raise ValueError("Incorrect type of image embeddings. "
                                  f"Got type: {type(image_embeds)}")
+
+            # Remove the N dimension until multiple images are supported.
+            image_embeds = image_embeds.squeeze(1)
+
             return PaliGemmaImageEmbeddingInputs(
                 type="image_embeds",
                 data=image_embeds,
@@ -560,6 +560,14 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal):
                 raise ValueError("Incorrect type of image sizes. "
                                  f"Got type: {type(image_sizes)}")
 
+            # Merge the B and N dimensions.
+            if isinstance(pixel_values, torch.Tensor):
+                pixel_values = pixel_values.flatten(0, 1)
+            else:
+                pixel_values = torch.cat(pixel_values)
+
+            image_sizes = image_sizes.flatten(0, 1)
+
             return Phi3VImagePixelInputs(
                 type="pixel_values",
                 data=self._validate_pixel_values(pixel_values),
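Phi-3-Vision keeps every patch tensor but folds the B and N dimensions together: `flatten(0, 1)` for an already-stacked tensor, `torch.cat` for a list of per-prompt tensors. A shape sketch with invented sizes:

import torch

# Stacked case: (B, N, patches, C, H, W) -> (B * N, patches, C, H, W).
stacked = torch.rand(2, 1, 17, 3, 336, 336)
print(stacked.flatten(0, 1).shape)  # torch.Size([2, 17, 3, 336, 336])

# List case: concatenate per-prompt tensors along the leading dimension.
per_prompt = [torch.rand(1, 17, 3, 336, 336), torch.rand(1, 17, 3, 336, 336)]
print(torch.cat(per_prompt).shape)  # torch.Size([2, 17, 3, 336, 336])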
@@ -333,6 +333,12 @@ class UltravoxModel(nn.Module, SupportsMultiModal):
                 raise ValueError("Incorrect type of audio features. "
                                  f"Got type: {type(audio_features)}")
 
+            # Remove the N dimension until multiple audios are supported.
+            if isinstance(audio_features, torch.Tensor):
+                audio_features = audio_features.squeeze(1)
+            else:
+                audio_features = [t.squeeze(0) for t in audio_features]
+
             return UltravoxAudioFeatureInputs(type="audio_features",
                                               data=audio_features)
 
@@ -341,6 +347,9 @@ class UltravoxModel(nn.Module, SupportsMultiModal):
                 raise ValueError("Incorrect type of audio embeds. "
                                  f"Got type: {type(audio_embeds)}")
 
+            # Remove the N dimension until multiple audios are supported.
+            audio_embeds = audio_embeds.squeeze(1)
+
             return UltravoxAudioEmbeddingInputs(type="audio_embeds",
                                                 data=audio_embeds)
 
@@ -1,5 +1,6 @@
 from typing import Dict, Iterable, List, Optional, Protocol, Tuple
 
+import numpy as np
 import torch
 import torch.nn as nn
 from torch.func import functional_call

@@ -10,7 +11,7 @@ from vllm.config import (CacheConfig, LoRAConfig, MultiModalConfig,
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.loader import build_model
 from vllm.model_executor.models import ModelRegistry
-from vllm.multimodal import BatchedTensors
+from vllm.multimodal.base import NestedTensors
 from vllm.utils import is_pin_memory_available
@@ -54,9 +55,34 @@ def init_vllm_registered_model(
     )
 
 
+def _flatten_embeddings(embeddings: NestedTensors) -> torch.Tensor:
+    """
+    Recursively concatenates NestedTensors along any heterogeneously sized
+    dimensions.
+    """
+
+    if isinstance(embeddings, torch.Tensor):
+        return embeddings
+
+    return torch.cat(tuple(_flatten_embeddings(t) for t in embeddings))
+
+
+def _embedding_count_expression(embeddings: NestedTensors) -> str:
+    """
+    Constructs a debugging representation of the number of embeddings in the
+    NestedTensors.
+    """
+
+    if isinstance(embeddings, torch.Tensor):
+        return " x ".join([str(dim) for dim in embeddings.shape[:-1]])
+
+    return " + ".join(
+        _embedding_count_expression(inner) for inner in embeddings)
+
+
 def merge_multimodal_embeddings(input_ids: torch.Tensor,
                                 inputs_embeds: torch.Tensor,
-                                multimodal_embeddings: BatchedTensors,
+                                multimodal_embeddings: NestedTensors,
                                 placeholder_token_id: int) -> torch.Tensor:
     """
     Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the

@@ -69,28 +95,16 @@ def merge_multimodal_embeddings(input_ids: torch.Tensor,
     mask = (input_ids == placeholder_token_id)
     num_expected_tokens = mask.sum()
 
-    if isinstance(multimodal_embeddings, torch.Tensor):
-        batch_size, batch_tokens, *_, embed_dim = multimodal_embeddings.shape
-        total_tokens = batch_size * batch_tokens
-        if num_expected_tokens != total_tokens:
-            expr = f"{batch_size} x {batch_tokens}"
-            raise ValueError(
-                f"Attempted to assign {expr} = {total_tokens} "
-                f"multimodal tokens to {num_expected_tokens} placeholders")
-
-        inputs_embeds[mask] = multimodal_embeddings.view(
-            total_tokens, embed_dim)
-    else:
-        size_per_batch = [t.shape[0] for t in multimodal_embeddings]
-        total_tokens = sum(size_per_batch)
-        if num_expected_tokens != total_tokens:
-            expr = ' + '.join(map(str, size_per_batch))
-            raise ValueError(
-                f"Attempted to assign {expr} = {total_tokens} "
-                f"multimodal tokens to {num_expected_tokens} placeholders")
-
-        inputs_embeds[mask] = torch.cat(multimodal_embeddings)
+    flattened = _flatten_embeddings(multimodal_embeddings)
+    *dims, embed_dim = flattened.shape
+    num_multimodal_embeddings = np.prod(dims)
+    if num_multimodal_embeddings != num_expected_tokens:
+        expr = _embedding_count_expression(multimodal_embeddings)
+        raise ValueError(
+            f"Attempted to assign {expr} = {num_multimodal_embeddings} "
+            f"multimodal tokens to {num_expected_tokens} placeholders")
 
+    inputs_embeds[mask] = flattened.view(num_expected_tokens, embed_dim)
     return inputs_embeds
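To see what the two new helpers produce, the snippet below copies their logic from the hunk above and runs it on a small nested structure; only the toy shapes are invented.

from typing import List, Union

import torch

NestedTensors = Union[List["NestedTensors"], torch.Tensor]


def _flatten_embeddings(embeddings: NestedTensors) -> torch.Tensor:
    # Recursively concatenate along the heterogeneously sized dimension.
    if isinstance(embeddings, torch.Tensor):
        return embeddings
    return torch.cat(tuple(_flatten_embeddings(t) for t in embeddings))


def _embedding_count_expression(embeddings: NestedTensors) -> str:
    # Build a human-readable count such as "2 + 3 + 4" for error messages.
    if isinstance(embeddings, torch.Tensor):
        return " x ".join([str(dim) for dim in embeddings.shape[:-1]])
    return " + ".join(
        _embedding_count_expression(inner) for inner in embeddings)


# Two embeddings of 2 and 3 tokens from one prompt, plus 4 tokens from another.
nested = [[torch.rand(2, 8), torch.rand(3, 8)], torch.rand(4, 8)]
print(_flatten_embeddings(nested).shape)     # torch.Size([9, 8])
print(_embedding_count_expression(nested))   # 2 + 3 + 4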
@@ -1,4 +1,4 @@
-from .base import (BatchedTensorInputs, BatchedTensors, MultiModalDataBuiltins,
+from .base import (BatchedTensorInputs, MultiModalDataBuiltins,
                    MultiModalDataDict, MultiModalInputs, MultiModalPlugin,
                    NestedTensors)
 from .registry import MultiModalRegistry

@@ -14,7 +14,6 @@ See also:
 
 __all__ = [
     "BatchedTensorInputs",
-    "BatchedTensors",
     "MultiModalDataBuiltins",
     "MultiModalDataDict",
     "MultiModalInputs",
@@ -1,9 +1,8 @@
 import sys
 from abc import ABC, abstractmethod
 from collections import UserDict, defaultdict
-from typing import Callable, Dict, List, Mapping, Optional
-from typing import Sequence as GenericSequence
-from typing import Tuple, Type, TypedDict, TypeVar, Union, cast, final
+from typing import (Callable, Dict, List, Mapping, Optional, Tuple, Type,
+                    TypedDict, TypeVar, Union, cast, final)
 
 import numpy as np
 import torch

@@ -15,23 +14,16 @@ from typing_extensions import TypeAlias
 from vllm.config import ModelConfig
 from vllm.inputs import InputContext
 from vllm.logger import init_logger
-from vllm.utils import JSONTree, json_map_leaves
+from vllm.utils import json_map_leaves
 
 logger = init_logger(__name__)
 
-NestedTensors = Union[GenericSequence[torch.Tensor], torch.Tensor]
+NestedTensors = Union[List["NestedTensors"], torch.Tensor]
 """
-Use a list instead of a tensor if the dimensions of each element do not match.
-Currently only supports up to singly nested list of tensors.
+Uses a list instead of a tensor if the dimensions of each element do not match.
 """
 
-BatchedTensors: TypeAlias = JSONTree[torch.Tensor]
-"""
-A nested JSON structure of tensors which have been batched via
-:meth:`MultiModalInputs.batch`.
-"""
-
-BatchedTensorInputs: TypeAlias = Dict[str, JSONTree[torch.Tensor]]
+BatchedTensorInputs: TypeAlias = Dict[str, NestedTensors]
 """
 A dictionary containing nested tensors which have been batched via
 :meth:`MultiModalInputs.batch`.
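The redefinition makes `NestedTensors` arbitrarily nested instead of at most singly nested, which is what lets one prompt carry several images whose tensors cannot be stacked. An illustrative value (the patch counts and sizes are made up):

import torch

# Two prompts: the first has two images with different patch counts, the
# second has one. Lists appear wherever shapes differ; tensors where they
# match.
nested_value = [
    [torch.rand(5, 3, 336, 336), torch.rand(9, 3, 336, 336)],  # prompt 0
    [torch.rand(5, 3, 336, 336)],                              # prompt 1
]
print(type(nested_value[0]), nested_value[0][0].shape)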
@@ -54,26 +46,23 @@ class MultiModalInputs(_MultiModalInputsBase):
     """
 
     @staticmethod
-    def _try_concat(tensors: List[NestedTensors]) -> BatchedTensors:
+    def _try_stack(nested_tensors: NestedTensors) -> NestedTensors:
         """
-        If each input tensor in the batch has the same shape, return a single
-        batched tensor; otherwise, return a list of :class:`NestedTensors` with
-        one element per item in the batch.
+        Recursively stacks lists of tensors when they all have the same shape.
         """
-        # may be list rather than tensors
-        if isinstance(tensors[0], list):
-            return [[t for t in tensor[0]]
-                    for tensor in cast(List[List[torch.Tensor]], tensors)]
-
-        tensors_ = cast(List[torch.Tensor], tensors)
-
-        unbatched_shape = tensors_[0].shape[1:]
-
-        for tensor in tensors_:
-            if tensor.shape[1:] != unbatched_shape:
-                return [tensor.squeeze(0) for tensor in tensors_]
-
-        return torch.cat(tensors_, dim=0)
+        if isinstance(nested_tensors, torch.Tensor):
+            return nested_tensors
+
+        stacked = [MultiModalInputs._try_stack(t) for t in nested_tensors]
+        if any(isinstance(t, list) for t in stacked):
+            return stacked
+
+        tensors_ = cast(List[torch.Tensor], stacked)
+        if any(t.shape != tensors_[0].shape for t in tensors_):
+            # The tensors have incompatible shapes and can't be stacked.
+            return tensors_
+
+        return torch.stack(tensors_)
 
     @staticmethod
     def batch(inputs_list: List["MultiModalInputs"]) -> BatchedTensorInputs:
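`_try_stack` works bottom-up: a level is stacked only when every element reduces to a tensor of identical shape; otherwise that level is left as a list. The quick check below mirrors the new tests in tests/multimodal/test_base.py and assumes a vLLM installation that includes this change; the shapes are arbitrary:

import torch
from vllm.multimodal.base import MultiModalInputs

# Same shape everywhere -> fully stacked batch tensor.
same = MultiModalInputs.batch([{"image": torch.rand(1, 3, 4)},
                               {"image": torch.rand(1, 3, 4)}])
print(same["image"].shape)  # torch.Size([2, 1, 3, 4])

# Mismatched shapes -> that level stays a list of the original tensors.
mixed = MultiModalInputs.batch([{"image": torch.rand(1, 3, 4)},
                                {"image": torch.rand(1, 5, 4)}])
print([t.shape for t in mixed["image"]])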
@@ -102,7 +91,7 @@ class MultiModalInputs(_MultiModalInputsBase):
             item_lists[k].append(v)
 
         return {
-            k: MultiModalInputs._try_concat(item_list)
+            k: MultiModalInputs._try_stack(item_list)
             for k, item_list in item_lists.items()
         }