[VLM] Abstract out multi-modal data parsing in merged processor (#11620)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Cyrus Leung 2024-12-30 23:01:35 +08:00 committed by GitHub
parent b12e87f942
commit 8d9b6721e7
15 changed files with 559 additions and 311 deletions

View File

@ -356,7 +356,7 @@ steps:
- pytest -v -s models/decoder_only/language -m 'not core_model and not quant_model'
- pytest -v -s models/embedding/language -m 'not core_model'
- label: Multi-Modal Models Test (Standard) # 28min
- label: Multi-Modal Models Test (Standard) # 40min
#mirror_hardwares: [amd]
source_file_dependencies:
- vllm/
@ -372,7 +372,7 @@ steps:
- pytest -v -s models/encoder_decoder/language -m core_model
- pytest -v -s models/encoder_decoder/vision_language -m core_model
- label: Multi-Modal Models Test (Extended) 1 # 1h16m
- label: Multi-Modal Models Test (Extended) 1 # 48m
optional: true
source_file_dependencies:
- vllm/

View File

@ -33,7 +33,7 @@ from vllm.model_executor.models.glm4_vision_encoder import EVA2CLIPModel
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalData, MultiModalKwargs,
from vllm.multimodal.inputs import (ModalityData, MultiModalKwargs,
NestedTensors)
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
@ -54,7 +54,7 @@ def calculate_image_placeholder(vision_config):
def mm_input_mapper_for_glmv(
ctx: InputContext,
data: MultiModalData[object],
data: ModalityData[object],
) -> Dict:
model_config = ctx.model_config
tokenizer = cached_get_tokenizer(

View File

@ -20,11 +20,13 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalDataItems,
MultiModalFieldConfig, MultiModalInputsV2,
MultiModalKwargs, NestedTensors)
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputsV2, MultiModalKwargs,
NestedTensors)
from vllm.multimodal.parse import ImageProcessorItems
from vllm.multimodal.processing import (BaseMultiModalProcessor,
ProcessorInputs, PromptReplacement,
MultiModalDataItems, ProcessorInputs,
PromptReplacement,
full_groupby_modality)
from vllm.sequence import IntermediateTensors
@ -179,7 +181,9 @@ class LlavaMultiModalProcessor(BaseMultiModalProcessor):
assert isinstance(vision_config, PixtralVisionConfig)
def get_replacement_pixtral(item_idx: int):
image_size = mm_items.get_image_size(item_idx)
images = mm_items.get_items("image", ImageProcessorItems)
image_size = images.get_image_size(item_idx)
(
num_width_tokens,
num_height_tokens,
@ -591,8 +595,8 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
result = super().apply(prompt_text, mm_data, hf_processor_mm_kwargs)
mm_items = self._get_mm_items(mm_data)
mm_item_counts = mm_items.get_item_counts()
mm_items = self._to_mm_items(mm_data)
mm_item_counts = mm_items.get_all_counts()
mm_kwargs = result["mm_kwargs"]
# We reimplement the functionality of MLlavaProcessor from

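The LLaVA/Pixtral replacement callback above switches from the removed `mm_items.get_image_size(...)` helper to typed access through `mm_items.get_items("image", ImageProcessorItems)`. A minimal sketch of that access pattern, using the parser and item classes introduced in this diff (the image size below is arbitrary):

```python
from PIL import Image

from vllm.multimodal.parse import ImageProcessorItems, MultiModalDataParser

# Normalize raw multi-modal data into per-modality item containers.
parser = MultiModalDataParser()
mm_items = parser.parse_mm_data({"image": Image.new("RGB", (336, 448))})

# Typed access: raises TypeError if the stored items are not processor
# inputs (e.g. if the user supplied image embeddings instead).
images = mm_items.get_items("image", ImageProcessorItems)
size = images.get_image_size(0)
print(size.width, size.height)  # 336 448
```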
View File

@ -32,12 +32,13 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.models.clip import CLIPVisionModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalDataItems,
MultiModalFieldConfig, MultiModalInputsV2,
MultiModalKwargs, NestedTensors,
PlaceholderRange)
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputsV2, MultiModalKwargs,
NestedTensors, PlaceholderRange)
from vllm.multimodal.parse import ImageProcessorItems
from vllm.multimodal.processing import (BaseMultiModalProcessor,
ProcessorInputs, PromptReplacement,
MultiModalDataItems, ProcessorInputs,
PromptReplacement,
_BoundPromptReplacement,
_PlaceholderInfo)
from vllm.sequence import IntermediateTensors
@ -381,7 +382,9 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor):
assert isinstance(bos_token_id, int)
def get_replacement_phi3v(item_idx: int):
image_size = mm_items.get_image_size(item_idx)
images = mm_items.get_items("image", ImageProcessorItems)
image_size = images.get_image_size(item_idx)
num_tokens = image_processor.calc_num_image_tokens_from_image_size(
width=image_size.width,
height=image_size.height,
@ -389,12 +392,14 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor):
return [_IMAGE_TOKEN_ID] * num_tokens + [bos_token_id]
num_images = mm_items.get_count("image", strict=False)
return [
PromptReplacement(
modality="image",
target=image_token,
replacement=get_replacement_phi3v,
) for image_token in image_tokens[:len(mm_items.images)]
) for image_token in image_tokens[:num_images]
]
def _apply_prompt_replacements(

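The Phi-3-Vision change replaces `len(mm_items.images)` with `mm_items.get_count("image", strict=False)`, which tolerates prompts that carry no images at all. A short sketch of the two lookup modes, assuming the new parser from this diff:

```python
from vllm.multimodal.parse import MultiModalDataParser

mm_items = MultiModalDataParser().parse_mm_data({})  # no multi-modal data

# Non-strict lookup returns 0 for a missing modality...
assert mm_items.get_count("image", strict=False) == 0

# ...while the strict (default) lookup raises KeyError.
try:
    mm_items.get_count("image")
except KeyError as exc:
    print(exc)
```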
View File

@ -20,8 +20,8 @@
# limitations under the License.
"""Inference-only Qwen2-Audio model compatible with HuggingFace weights."""
from functools import cached_property
from typing import (Any, Iterable, List, Mapping, Optional, Set, Tuple,
TypedDict, Union)
from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict,
Union)
import numpy as np
import torch
@ -38,10 +38,12 @@ from vllm.inputs import InputContext
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataItems, MultiModalFieldConfig,
MultiModalKwargs, NestedTensors)
from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
NestedTensors)
from vllm.multimodal.parse import MultiModalDataParser
from vllm.multimodal.processing import (BaseMultiModalProcessor,
ProcessorInputs, PromptReplacement)
MultiModalDataItems, ProcessorInputs,
PromptReplacement)
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsMultiModal, SupportsPP
@ -99,15 +101,9 @@ class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor):
def _get_feature_extractor(self) -> WhisperFeatureExtractor:
return self._get_hf_processor().feature_extractor # type: ignore
def _get_hf_mm_data(
self,
mm_items: MultiModalDataItems,
) -> tuple[dict[str, Any], dict[str, Any]]:
# resample audio to the model's sampling rate
def _get_data_parser(self) -> MultiModalDataParser:
feature_extractor = self._get_feature_extractor()
mm_items.resample_audios(feature_extractor.sampling_rate)
return super()._get_hf_mm_data(mm_items)
return MultiModalDataParser(target_sr=feature_extractor.sampling_rate)
def _call_hf_processor(
self,

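Audio resampling moves out of the `_get_hf_mm_data` override and into the data parser: the processor now only declares its target sampling rate, and `MultiModalDataParser` resamples `(audio, sampling_rate)` tuples while parsing. A sketch of the same mechanism outside the Qwen2-Audio class (the 16 kHz / 48 kHz rates are illustrative, and resampling requires librosa at runtime):

```python
import numpy as np

from vllm.multimodal.parse import AudioProcessorItems, MultiModalDataParser

# The processor override boils down to:
#     def _get_data_parser(self) -> MultiModalDataParser:
#         return MultiModalDataParser(target_sr=feature_extractor.sampling_rate)
parser = MultiModalDataParser(target_sr=16_000)

# A clip recorded at 48 kHz is resampled to 16 kHz during parsing.
clip = np.zeros(48_000, dtype=np.float32)  # 1 second of silence
items = parser.parse_mm_data({"audio": (clip, 48_000)})

audios = items.get_items("audio", AudioProcessorItems)
print(audios.get(0).shape)  # ~(16000,) after resampling
```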
View File

@ -25,7 +25,6 @@ from functools import cached_property, partial
from typing import (Any, Callable, Iterable, List, Literal, Mapping, Optional,
Set, Tuple, Type, TypedDict, Union)
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
@ -55,15 +54,16 @@ from vllm.model_executor.layers.quantization.gptq_marlin import (
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalDataItems,
from vllm.multimodal.inputs import (ImageItem, ModalityData,
MultiModalFieldConfig, MultiModalKwargs,
NestedTensors)
NestedTensors, VideoItem)
from vllm.multimodal.parse import ModalityDataItems, MultiModalDataParser
from vllm.multimodal.processing import (BaseMultiModalProcessor,
ProcessorInputs, PromptReplacement)
MultiModalDataItems, ProcessorInputs,
PromptReplacement)
from vllm.platforms import _Backend
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.config import uses_mrope
from vllm.utils import is_list_of
from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, WeightsMapper, get_vit_attn_backend,
@ -719,61 +719,81 @@ get_max_qwen2_vl_video_tokens = partial(get_max_qwen2_vl_mm_tokens,
data_type_key="video")
class Qwen2VLMultiModalDataItems(MultiModalDataItems):
class Qwen2EmbeddingItems(ModalityDataItems[dict[str, torch.Tensor],
dict[str, torch.Tensor]]):
@staticmethod
def from_dict(data: MultiModalDataDict) -> "MultiModalDataItems":
"""
Normalize :class:`MultiModalDataDict` to :class:`MultiModalDataItems`.
"""
multi_data = Qwen2VLMultiModalDataItems()
def __init__(self, data: dict, modality: str) -> None:
super().__init__(data)
for k, v in data.items():
# TODO: Make a separate modality for embedding inputs
# to avoid confusion
# yapf: disable
if k == "video":
# Special case since even a single item can be a list
multi_data[k] = ( # type: ignore[index]
v if (
isinstance(v, (dict, torch.Tensor)) # type: ignore[assignment]
or is_list_of(v, list)
or isinstance(v[0], (np.ndarray, torch.Tensor))
and v[0].ndim == 4
) else [v]
)
elif k in ("image", "audio"):
multi_data[k] = ( # type: ignore[index]
v if isinstance(v, (dict, torch.Tensor, list)) else [v]
)
else:
multi_data[k] = v if isinstance(v, list) else [v] # type: ignore[index]
# yapf: enable
self.modality = modality
return multi_data
grid_thw = data[f"{modality}_grid_thw"]
slice_idxs = [0] + grid_thw.prod(-1).cumsum_(0).tolist()
self._slices = [
slice(slice_idxs[i], slice_idxs[i + 1])
for i in range(len(grid_thw))
]
def get_item_counts(self) -> Mapping[str, int]:
return {
m: (
len(items[f"{m}_grid_thw"]) # type: ignore
if isinstance(items, dict) else len(items))
for m, items in self.items()
}
def __repr__(self) -> str:
return (f"{type(self).__name__}(modality={self.modality!r})")
def has_embedding_inputs(self) -> bool:
return any(
isinstance(items, dict) or any(
isinstance(item, torch.Tensor) for item in items)
for items in self.values())
def get_count(self) -> int:
return len(self.data[f"{self.modality}_grid_thw"])
def get(self, index: int) -> dict[str, torch.Tensor]:
out = {}
for k, v in self.data.items():
if k != f"{self.modality}_grid_thw":
v = v[self._slices[index]]
out[k] = v
return out
def get_processor_data(self) -> Mapping[str, object]:
return {}
def get_passthrough_data(self) -> Mapping[str, object]:
return self.data
class Qwen2ImageEmbeddingItems(Qwen2EmbeddingItems):
def __init__(self, data: dict) -> None:
super().__init__(data, "image")
class Qwen2VideoEmbeddingItems(Qwen2EmbeddingItems):
def __init__(self, data: dict) -> None:
super().__init__(data, "video")
class Qwen2MultiModalDataParser(MultiModalDataParser):
def _parse_image_data(
self,
data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]],
) -> ModalityDataItems[Any, Any]:
if isinstance(data, dict):
return Qwen2EmbeddingItems(data, modality="image")
return super()._parse_image_data(data)
def _parse_video_data(
self,
data: Union[dict[str, torch.Tensor], ModalityData[VideoItem]],
) -> ModalityDataItems[Any, Any]:
if isinstance(data, dict):
return Qwen2EmbeddingItems(data, modality="video")
return super()._parse_video_data(data)
class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor):
def _get_mm_items(
self,
mm_data: MultiModalDataDict,
) -> MultiModalDataItems:
return Qwen2VLMultiModalDataItems.from_dict(mm_data)
def _get_data_parser(self) -> MultiModalDataParser:
return Qwen2MultiModalDataParser()
def _get_hf_processor(
self,
@ -796,35 +816,6 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor):
return hf_processor
def _get_hf_mm_data(
self,
mm_items: MultiModalDataItems,
) -> tuple[dict[str, Any], dict[str, Any]]:
processor_data = dict[str, Any]()
passthrough_data = dict[str, Any]()
for k, v in mm_items.items():
# TODO: Make a separate modality for embedding inputs
# to avoid confusion
if k in ("image", "video", "audio"):
if isinstance(v, dict):
# Pass through embedding inputs (dict)
passthrough_data.update(v)
elif isinstance(v, torch.Tensor) and v.ndim == 3:
# Pass through embedding inputs (single)
passthrough_data[f"{k}_embeds"] = [v]
elif (is_list_of(v, torch.Tensor) and len(v) > 0
and v[0].ndim == 2):
# Pass through embedding inputs (multi)
passthrough_data[f"{k}_embeds"] = v
elif len(v) > 0:
# Map keys to plural form, e.g.: image -> images
processor_data[f"{k}s"] = v
else:
processor_data[k] = v
return processor_data, passthrough_data
def _get_prompt_replacements(
self,
mm_items: MultiModalDataItems,

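Qwen2-VL's handling of precomputed embeddings (dicts holding `*_embeds` plus `*_grid_thw`) now lives in `Qwen2MultiModalDataParser`: dict inputs become `Qwen2EmbeddingItems`, whose `get_passthrough_data()` bypasses the HF processor, while ordinary images and videos fall through to the base parser. A sketch of that routing, assuming a working vLLM install; the tensor shapes are illustrative, not the model's real dimensions:

```python
import torch

from vllm.model_executor.models.qwen2_vl import Qwen2MultiModalDataParser

parser = Qwen2MultiModalDataParser()

# Two precomputed image feature grids of 4 and 6 patches, flattened into a
# single tensor; image_grid_thw records the (t, h, w) of each grid.
data = {
    "image_embeds": torch.zeros(10, 1280),
    "image_grid_thw": torch.tensor([[1, 2, 2], [1, 2, 3]]),
}
items = parser.parse_mm_data({"image": data})["image"]

print(items.get_count())                   # 2
print(items.get(1)["image_embeds"].shape)  # torch.Size([6, 1280])
print(list(items.get_passthrough_data()))  # ['image_embeds', 'image_grid_thw']
```

Each item returned by `get()` is sliced out of the flattened embedding tensor using the cumulative per-grid sizes derived from `image_grid_thw` in `Qwen2EmbeddingItems.__init__`.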
View File

@ -3,8 +3,8 @@
import math
from functools import cached_property, lru_cache
from typing import (Any, Iterable, List, Literal, Mapping, Optional, Set,
Tuple, TypedDict, Union)
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
TypedDict, Union)
import numpy as np
import torch
@ -24,10 +24,12 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.model_loader.loader import DefaultModelLoader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataItems, MultiModalFieldConfig,
MultiModalKwargs, NestedTensors)
from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
NestedTensors)
from vllm.multimodal.parse import MultiModalDataParser
from vllm.multimodal.processing import (BaseMultiModalProcessor,
ProcessorInputs, PromptReplacement)
MultiModalDataItems, ProcessorInputs,
PromptReplacement)
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.ultravox import UltravoxConfig
from vllm.utils import is_list_of
@ -85,15 +87,9 @@ class UltravoxMultiModalProcessor(BaseMultiModalProcessor):
hf_processor = self._get_hf_processor()
return hf_processor.audio_processor.feature_extractor # type: ignore
def _get_hf_mm_data(
self,
mm_items: MultiModalDataItems,
) -> tuple[dict[str, Any], dict[str, Any]]:
# resample audio to the model's sampling rate
def _get_data_parser(self) -> MultiModalDataParser:
feature_extractor = self._get_feature_extractor()
mm_items.resample_audios(feature_extractor.sampling_rate)
return super()._get_hf_mm_data(mm_items)
return MultiModalDataParser(target_sr=feature_extractor.sampling_rate)
def _call_hf_processor(
self,

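Ultravox gets the same treatment as Qwen2-Audio: the resampling override is replaced by a parser built with the feature extractor's sampling rate. One behavioural consequence of the new `_parse_audio_data` is that an `(audio, sampling_rate)` tuple passed to a parser created without `target_sr` now fails loudly instead of being silently left unresampled. A small sketch of that guard (the array length is arbitrary):

```python
import numpy as np

from vllm.multimodal.parse import MultiModalDataParser

clip = np.zeros(8_000, dtype=np.float32)

# No target_sr: the parser cannot honour a caller-specified sampling rate.
parser = MultiModalDataParser()
try:
    parser.parse_mm_data({"audio": (clip, 44_100)})
except RuntimeError as exc:
    print(exc)  # Audio resampling is not supported when `target_sr` is not provided

# A bare array is accepted as-is (assumed to already match the model's rate).
items = parser.parse_mm_data({"audio": clip})
print(items.get_all_counts())  # {'audio': 1}
```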
View File

@ -1,8 +1,7 @@
from .base import MultiModalPlaceholderMap, MultiModalPlugin
from .inputs import (BatchedTensorInputs, MultiModalData,
MultiModalDataBuiltins, MultiModalDataDict,
MultiModalKwargs, MultiModalPlaceholderDict,
NestedTensors)
from .inputs import (BatchedTensorInputs, ModalityData, MultiModalDataBuiltins,
MultiModalDataDict, MultiModalKwargs,
MultiModalPlaceholderDict, NestedTensors)
from .registry import MultiModalRegistry
MULTIMODAL_REGISTRY = MultiModalRegistry()
@ -16,7 +15,7 @@ See also:
__all__ = [
"BatchedTensorInputs",
"MultiModalData",
"ModalityData",
"MultiModalDataBuiltins",
"MultiModalDataDict",
"MultiModalKwargs",

View File

@ -9,7 +9,7 @@ from vllm.inputs.registry import InputContext
from vllm.utils import PlaceholderModule
from .base import MediaIO, MultiModalPlugin
from .inputs import AudioItem, MultiModalData, MultiModalKwargs
from .inputs import AudioItem, ModalityData, MultiModalKwargs
try:
import librosa
@ -31,7 +31,7 @@ class AudioPlugin(MultiModalPlugin):
def _default_input_mapper(
self,
ctx: InputContext,
data: MultiModalData[AudioItem],
data: ModalityData[AudioItem],
**mm_processor_kwargs,
) -> MultiModalKwargs:
raise NotImplementedError("There is no default audio input mapper")

View File

@ -15,12 +15,12 @@ if TYPE_CHECKING:
from vllm.config import ModelConfig
from vllm.sequence import SequenceGroupMetadata
from .inputs import (MultiModalData, MultiModalDataDict, MultiModalKwargs,
from .inputs import (ModalityData, MultiModalDataDict, MultiModalKwargs,
PlaceholderRange)
logger = init_logger(__name__)
MultiModalInputMapper = Callable[[InputContext, MultiModalData[object]],
MultiModalInputMapper = Callable[[InputContext, ModalityData[object]],
MultiModalKwargs]
"""
Return a dictionary to be passed as keyword arguments to
@ -69,7 +69,7 @@ class MultiModalPlugin(ABC):
def _default_input_mapper(
self,
ctx: InputContext,
data: MultiModalData[Any],
data: ModalityData[Any],
**mm_processor_kwargs,
) -> MultiModalKwargs:
"""
@ -118,7 +118,7 @@ class MultiModalPlugin(ABC):
def map_input(
self,
model_config: "ModelConfig",
data: MultiModalData[Any],
data: ModalityData[Any],
mm_processor_kwargs: Optional[dict[str, Any]],
) -> MultiModalKwargs:
"""

View File

@ -13,7 +13,7 @@ from vllm.transformers_utils.processor import get_image_processor
from vllm.utils import is_list_of
from .base import MediaIO, MultiModalPlugin
from .inputs import ImageItem, MultiModalData, MultiModalKwargs
from .inputs import ImageItem, ModalityData, MultiModalKwargs
if TYPE_CHECKING:
from vllm.config import ModelConfig
@ -44,7 +44,7 @@ class ImagePlugin(MultiModalPlugin):
def _default_input_mapper(
self,
ctx: InputContext,
data: MultiModalData[ImageItem],
data: ModalityData[ImageItem],
**mm_processor_kwargs,
) -> MultiModalKwargs:
model_config = ctx.model_config

View File

@ -2,53 +2,74 @@ from abc import ABC, abstractmethod
from collections import UserDict, defaultdict
from collections.abc import Mapping, Sequence
from dataclasses import dataclass
from typing import (Any, Literal, NamedTuple, TypedDict, TypeVar, Union, cast,
final)
from typing import Any, Literal, TypedDict, TypeVar, Union, cast, final
import numpy as np
import torch
import torch.types
from PIL.Image import Image
from transformers import BatchFeature
from typing_extensions import NotRequired, TypeAlias, assert_never
from typing_extensions import NotRequired, TypeAlias
from vllm.utils import JSONTree, is_list_of, json_map_leaves
_T = TypeVar("_T")
# yapf: disable
ImageItem: TypeAlias = Union[Image, np.ndarray, torch.Tensor]
HfImageItem: TypeAlias = Union[Image, np.ndarray, torch.Tensor]
"""
A :class:`transformers.image_utils.ImageInput` representing a single image
item, which can be passed to a HuggingFace :code:`ImageProcessor`.
"""
VideoItem: TypeAlias = Union[
list[Image],
np.ndarray,
torch.Tensor,
list[np.ndarray],
list[torch.Tensor],
]
HfVideoItem: TypeAlias = Union[list[Image], np.ndarray, torch.Tensor,
list[np.ndarray], list[torch.Tensor]]
"""
A :class:`transformers.image_utils.VideoInput` representing a single video
item, which can be passed to a HuggingFace :code:`VideoProcessor`.
"""
AudioItem: TypeAlias = Union[
np.ndarray,
list[float],
# `(audio, sampling_rate)`: If the audio's sampling rate is different
# from that expected by the model, we need to resample it.
tuple[np.ndarray, float],
]
HfAudioItem: TypeAlias = Union[list[float], np.ndarray, torch.Tensor]
"""
Represents a single audio
item, which can be passed to a HuggingFace :code:`AudioProcessor`.
"""
# yapf: enable
MultiModalData: TypeAlias = Union[_T, list[_T]]
ImageItem: TypeAlias = Union[HfImageItem, torch.Tensor]
"""
A :class:`transformers.image_utils.ImageInput` representing a single image
item, which can be passed to a HuggingFace :code:`ImageProcessor`.
Alternatively, a 3-D tensor or batch of 2-D tensors,
which are treated as image embeddings;
these are directly passed to the model without HF processing.
"""
VideoItem: TypeAlias = Union[HfVideoItem, torch.Tensor]
"""
A :class:`transformers.image_utils.VideoInput` representing a single video
item, which can be passed to a HuggingFace :code:`VideoProcessor`.
Alternatively, a 3-D tensor or batch of 2-D tensors,
which are treated as video embeddings;
these are directly passed to the model without HF processing.
"""
AudioItem: TypeAlias = Union[HfAudioItem, tuple[np.ndarray, float],
torch.Tensor]
"""
Represents a single audio
item, which can be passed to a HuggingFace :code:`AudioProcessor`.
Alternatively, a tuple `(audio, sampling_rate)`, where the sampling rate
is different from that expected by the model;
these are resampled to the model's sampling rate before being processed by HF.
Alternatively, a 3-D tensor or batch of 2-D tensors,
which are treated as audio embeddings;
these are directly passed to the model without HF processing.
"""
ModalityData: TypeAlias = Union[_T, list[_T]]
"""
Either a single data item, or a list of data items.
@ -61,17 +82,17 @@ The number of data items allowed per modality is restricted by
class MultiModalDataBuiltins(TypedDict, total=False):
"""Type annotations for modality types predefined by vLLM."""
image: MultiModalData[ImageItem]
image: ModalityData[ImageItem]
"""The input image(s)."""
video: MultiModalData[VideoItem]
video: ModalityData[VideoItem]
"""The input video(s)."""
audio: MultiModalData[AudioItem]
audio: ModalityData[AudioItem]
"""The input audio(s)."""
MultiModalDataDict: TypeAlias = Mapping[str, MultiModalData[Any]]
MultiModalDataDict: TypeAlias = Mapping[str, ModalityData[Any]]
"""
A dictionary containing an entry for each modality type to input.
@ -83,123 +104,6 @@ Note:
"""
class ImageSize(NamedTuple):
width: int
height: int
class MultiModalDataItems(UserDict[str, list[Any]]):
"""
As :class:`MultiModalDataDict`, but normalized such that each entry
corresponds to a list.
"""
@staticmethod
def from_dict(data: MultiModalDataDict) -> "MultiModalDataItems":
"""
Normalize :class:`MultiModalDataDict` to :class:`MultiModalDataItems`.
"""
multi_data = MultiModalDataItems()
for k, v in data.items():
# TODO: Make a separate modality for embedding inputs
# to avoid confusion
# yapf: disable
if k == "video":
# Special case since even a single item can be a list
multi_data[k] = ( # type: ignore[index]
v if (
isinstance(v, torch.Tensor)
or is_list_of(v, list)
or isinstance(v[0], (np.ndarray, torch.Tensor))
and v[0].ndim == 4
) else [v]
)
elif k in ("image", "audio"):
multi_data[k] = ( # type: ignore[index]
v if isinstance(v, (torch.Tensor, list)) else [v]
)
else:
multi_data[k] = v if isinstance(v, list) else [v] # type: ignore[index]
# yapf: enable
return multi_data
# NOTE: When a field (e.g. `images`) doesn't exist, directly appending to
# `self.images` doesn't update this dictionary, which may be confusing
# We annotate the getter methods as `Sequence` to prevent others from
# trying to update the list in this way
@property
def images(self) -> Sequence[ImageItem]:
return self.get("image", [])
@property
def videos(self) -> Sequence[VideoItem]:
return self.get("video", [])
@property
def audios(self) -> Sequence[AudioItem]:
return self.get("audio", [])
def get_item_counts(self) -> Mapping[str, int]:
return {m: len(items) for m, items in self.items()}
def has_embedding_inputs(self) -> bool:
return any(
any(isinstance(item, torch.Tensor) for item in items)
for items in self.values())
def get_image_size(self, item_idx: int) -> ImageSize:
image = self.images[item_idx]
if isinstance(image, Image):
return ImageSize(*image.size)
if isinstance(image, (np.ndarray, torch.Tensor)):
_, h, w = image.shape
return ImageSize(w, h)
assert_never(image)
def get_audio_with_sr(
self,
item_idx: int,
*,
default_sr: float,
) -> tuple[np.ndarray, float]:
audio = self.audios[item_idx]
if isinstance(audio, tuple):
return audio
if isinstance(audio, list):
return np.array(audio), default_sr
if isinstance(audio, np.ndarray):
return audio, default_sr
assert_never(audio)
def resample_audios(self, new_sr: float, *, drop_sr: bool = True) -> None:
"""
If :code:`drop_sr=True`, the audio items in this dictionary are updated
to be NumPy arrays which implicitly means that their sampling rate is
the same as the model's expected sampling rate; otherwise, they remain
as :code:`(audio, new_sr)` tuples.
"""
# Avoid circular import
from .audio import resample_audio
if not self.audios:
return
new_audios = []
for item_idx in range(len(self.audios)):
audio, sr = self.get_audio_with_sr(item_idx, default_sr=new_sr)
audio = resample_audio(audio, orig_sr=sr, target_sr=new_sr)
new_audios.append(audio if drop_sr else (audio, new_sr))
self["audio"] = new_audios
class PlaceholderRange(TypedDict):
"""
Placeholder location information for multi-modal data.
@ -436,7 +340,7 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
) -> "MultiModalKwargs":
data = {
key: items[0].field.reduce(items).data
for key, items in items_by_key.items()
for key, items in items_by_key.items() if len(items) > 0
}
return MultiModalKwargs(data,
@ -567,6 +471,11 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
Get the keyword arguments corresponding to an item identified by
its modality and index.
"""
if modality not in self._keys_by_modality:
available_modalities = set(self._keys_by_modality.keys())
raise KeyError(f"Modality {modality!r} not found. "
f"Available modalities: {available_modalities}")
keys_to_gather = self._keys_by_modality[modality]
return {

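The reworked aliases in `inputs.py` separate what HF processors accept (`HfImageItem`, `HfVideoItem`, `HfAudioItem`) from what vLLM accepts, where a 3-D tensor (or batch of 2-D tensors) is interpreted as precomputed embeddings. A minimal sketch of how the new parser routes each form into processor data versus passthrough data (shapes are arbitrary):

```python
import torch
from PIL import Image

from vllm.multimodal.parse import MultiModalDataParser

parser = MultiModalDataParser()

# HfImageItem (a PIL image): forwarded to the HF processor under the plural key.
hf_items = parser.parse_mm_data({"image": Image.new("RGB", (224, 224))})["image"]
print(list(hf_items.get_processor_data()))     # ['images']
print(hf_items.get_passthrough_data())         # {}

# 3-D tensor: treated as two image embeddings and passed through unchanged.
emb_items = parser.parse_mm_data({"image": torch.zeros(2, 576, 4096)})["image"]
print(emb_items.get_count())                   # 2
print(list(emb_items.get_passthrough_data()))  # ['image_embeds']
```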
vllm/multimodal/parse.py (new file, +344 lines)
View File

@ -0,0 +1,344 @@
from abc import ABC, abstractmethod
from collections import UserDict
from collections.abc import Callable, Iterator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Generic, NamedTuple, Optional, TypeVar
import numpy as np
import torch
from PIL.Image import Image
from typing_extensions import TypeAlias, TypeGuard, assert_never
from vllm.utils import is_list_of
from .audio import resample_audio
from .inputs import (AudioItem, HfAudioItem, HfImageItem, HfVideoItem,
ImageItem, ModalityData, MultiModalDataDict,
NestedTensors, VideoItem)
_T = TypeVar("_T")
_I = TypeVar("_I")
class ModalityDataItems(ABC, Generic[_T, _I]):
def __init__(self, data: _T) -> None:
super().__init__()
self.data = data
def __len__(self) -> int:
return self.get_count()
def __getitem__(self, index: int) -> _I:
return self.get(index)
if TYPE_CHECKING:
# Auto-generated
def __iter__(self) -> Iterator[_I]:
...
@abstractmethod
def get_count(self) -> int:
"""Get the number of data items."""
raise NotImplementedError
@abstractmethod
def get(self, index: int) -> _I:
"""Get a data item by its index."""
raise NotImplementedError
def get_all(self) -> list[_I]:
"""Get all data items."""
return [self.get(idx) for idx in range(self.get_count())]
@abstractmethod
def get_processor_data(self) -> Mapping[str, object]:
"""Get the data to pass to the HF processor."""
raise NotImplementedError
@abstractmethod
def get_passthrough_data(self) -> Mapping[str, object]:
"""Get the data to pass directly to the model."""
raise NotImplementedError
class ProcessorBatchItems(ModalityDataItems[Sequence[_T], _T]):
def __init__(self, data: Sequence[_T], modality: str) -> None:
super().__init__(data)
self.modality = modality
def __repr__(self) -> str:
return (f"{type(self).__name__}(modality={self.modality!r})")
def get_count(self) -> int:
return len(self.data)
def get(self, index: int) -> _T:
return self.data[index]
def get_processor_data(self) -> Mapping[str, object]:
return {f"{self.modality}s": self.data}
def get_passthrough_data(self) -> Mapping[str, object]:
return {}
class EmbeddingItems(ModalityDataItems[NestedTensors, torch.Tensor]):
def __init__(self, data: NestedTensors, modality: str) -> None:
super().__init__(data)
self.modality = modality
def __repr__(self) -> str:
return (f"{type(self).__name__}(modality={self.modality!r})")
def get_count(self) -> int:
return len(self.data)
def get(self, index: int) -> object:
return self.data[index]
def get_processor_data(self) -> Mapping[str, object]:
return {}
def get_passthrough_data(self) -> Mapping[str, object]:
return {f"{self.modality}_embeds": self.data}
class AudioProcessorItems(ProcessorBatchItems[HfAudioItem]):
def __init__(self, data: Sequence[HfAudioItem]) -> None:
super().__init__(data, "audio")
class AudioEmbeddingItems(EmbeddingItems):
def __init__(self, data: NestedTensors) -> None:
super().__init__(data, "audio")
class ImageSize(NamedTuple):
width: int
height: int
class ImageProcessorItems(ProcessorBatchItems[HfImageItem]):
def __init__(self, data: Sequence[HfImageItem]) -> None:
super().__init__(data, "image")
def get_image_size(self, item_idx: int) -> ImageSize:
image = self.get(item_idx)
if isinstance(image, Image):
return ImageSize(*image.size)
if isinstance(image, (np.ndarray, torch.Tensor)):
_, h, w = image.shape
return ImageSize(w, h)
assert_never(image)
class ImageEmbeddingItems(EmbeddingItems):
def __init__(self, data: NestedTensors) -> None:
super().__init__(data, "image")
class VideoProcessorItems(ProcessorBatchItems[HfVideoItem]):
def __init__(self, data: Sequence[HfVideoItem]) -> None:
super().__init__(data, "video")
class VideoEmbeddingItems(EmbeddingItems):
def __init__(self, data: NestedTensors) -> None:
super().__init__(data, "video")
_D = TypeVar("_D", bound=ModalityDataItems[Any, Any])
class MultiModalDataItems(UserDict[str, ModalityDataItems[Any, Any]]):
"""
As :class:`MultiModalDataDict`, but normalized such that each entry
corresponds to a :class:`ModalityDataItems` instance.
"""
def get_count(self, modality: str, *, strict: bool = True) -> int:
"""
Get the number of data items belonging to a modality.
If `strict=False`, return `0` instead of raising :exc:`KeyError`
even if the modality is not found.
"""
if modality not in self:
if strict:
available_modalities = set(self.keys())
raise KeyError(f"Modality {modality!r} not found. "
f"Available modalities: {available_modalities}")
return 0
return self[modality].get_count()
def get_all_counts(self) -> Mapping[str, int]:
"""Get the number of items belonging to each modality."""
return {m: items.get_count() for m, items in self.items()}
def get_items(
self,
modality: str,
typ: type[_D],
) -> _D:
"""
Get the data items belonging to a modality,
requiring that they belong to a certain type.
"""
if modality not in self:
available_modalities = set(self.keys())
raise KeyError(f"Modality {modality!r} not found. "
f"Available modalities: {available_modalities}")
items = self[modality]
if not isinstance(items, typ):
raise TypeError(f"Invalid type of data items for {modality=}. "
f"Expected type: {typ}, but "
f"found type: {type(items)}")
return items
ModalityDataParser: TypeAlias = Callable[[ModalityData[Any]],
ModalityDataItems[Any, Any]]
class MultiModalDataParser:
"""
Parses :class:`MultiModalDataDict` into :class:`MultiModalDataItems`.
"""
def __init__(self, *, target_sr: Optional[float] = None) -> None:
super().__init__()
self.target_sr = target_sr
def _is_embeddings(self, data: object) -> TypeGuard[NestedTensors]:
if isinstance(data, torch.Tensor):
return data.ndim == 3
if is_list_of(data, torch.Tensor):
return len(data) == 0 or data[0].ndim == 2
return False
def _get_audio_with_sr(
self,
audio: AudioItem,
) -> tuple[np.ndarray, Optional[float]]:
if isinstance(audio, tuple):
return audio
if isinstance(audio, list):
return np.array(audio), None
if isinstance(audio, np.ndarray):
return audio, None
if isinstance(audio, torch.Tensor):
return audio.numpy(), None
assert_never(audio)
def _parse_audio_data(
self,
data: ModalityData[AudioItem],
) -> ModalityDataItems[Any, Any]:
if self._is_embeddings(data):
return AudioEmbeddingItems(data)
if (is_list_of(data, float)
or isinstance(data,
(np.ndarray, torch.Tensor)) and data.ndim == 1
or isinstance(data, tuple)):
data_items = [data]
elif isinstance(data, (np.ndarray, torch.Tensor)):
data_items = [elem for elem in data]
else:
data_items = data
new_audios = list[np.ndarray]()
for data_item in data_items:
audio, orig_sr = self._get_audio_with_sr(data_item)
if orig_sr is None:
new_audio = audio
else:
target_sr = self.target_sr
if target_sr is None:
raise RuntimeError(
"Audio resampling is not supported when "
"`target_sr` is not provided")
new_audio = resample_audio(audio,
orig_sr=orig_sr,
target_sr=target_sr)
new_audios.append(new_audio)
return AudioProcessorItems(new_audios)
def _parse_image_data(
self,
data: ModalityData[ImageItem],
) -> ModalityDataItems[Any, Any]:
if self._is_embeddings(data):
return ImageEmbeddingItems(data)
if (isinstance(data, Image)
or isinstance(data,
(np.ndarray, torch.Tensor)) and data.ndim == 3):
data_items = [data]
elif isinstance(data, (np.ndarray, torch.Tensor)):
data_items = [elem for elem in data]
else:
data_items = data
return ImageProcessorItems(data_items)
def _parse_video_data(
self,
data: ModalityData[VideoItem],
) -> ModalityDataItems[Any, Any]:
if self._is_embeddings(data):
return VideoEmbeddingItems(data)
if (is_list_of(data, Image)
or isinstance(data,
(np.ndarray, torch.Tensor)) and data.ndim == 4):
data_items = [data]
elif isinstance(data, (np.ndarray, torch.Tensor)):
data_items = [elem for elem in data]
else:
data_items = data
return VideoProcessorItems(data_items)
def _get_subparsers(self) -> Mapping[str, ModalityDataParser]:
return {
"audio": self._parse_audio_data,
"image": self._parse_image_data,
"video": self._parse_video_data,
}
def parse_mm_data(self,
mm_data: MultiModalDataDict) -> MultiModalDataItems:
subparsers = self._get_subparsers()
mm_items = MultiModalDataItems()
for k, v in mm_data.items():
if k not in subparsers:
raise ValueError(f"Unsupported modality: {k}")
mm_items[k] = subparsers[k](v)
return mm_items

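The new `vllm/multimodal/parse.py` gives each modality a sequence-like item container plus a pluggable parser. A short usage sketch of the container API added above (the `point_cloud` modality below is deliberately unsupported, to show the error path):

```python
from PIL import Image

from vllm.multimodal.parse import MultiModalDataParser

parser = MultiModalDataParser()
mm_items = parser.parse_mm_data({
    "image": [Image.new("RGB", (64, 64)), Image.new("RGB", (128, 96))],
})

print(mm_items.get_all_counts())    # {'image': 2}

# Each per-modality container behaves like a sequence of items.
images = mm_items["image"]
print(len(images), images[1].size)  # 2 (128, 96)
print(len(images.get_all()))        # 2

# Only modalities with a registered subparser are accepted.
try:
    parser.parse_mm_data({"point_cloud": [object()]})
except ValueError as exc:
    print(exc)  # Unsupported modality: point_cloud
```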
View File

@ -15,11 +15,12 @@ from transformers import BatchFeature, ProcessorMixin
from vllm.inputs import DummyData, InputProcessingContext
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import LRUCache, flatten_2d_lists, full_groupby, is_list_of
from vllm.utils import LRUCache, flatten_2d_lists, full_groupby
from .inputs import (MultiModalDataDict, MultiModalDataItems,
MultiModalFieldConfig, MultiModalFieldItem,
MultiModalInputsV2, MultiModalKwargs, PlaceholderRange)
from .inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalFieldItem, MultiModalInputsV2, MultiModalKwargs,
PlaceholderRange)
from .parse import MultiModalDataItems, MultiModalDataParser
logger = init_logger(__name__)
@ -621,6 +622,16 @@ class BaseMultiModalProcessor(ABC):
) -> MultiModalInputsV2:
return self.apply(prompt, mm_data, hf_processor_mm_kwargs)
def _get_data_parser(self) -> MultiModalDataParser:
"""
Construct a data parser to preprocess multi-modal data items
before passing them to :meth:`_get_hf_mm_data`.
You can support additional modalities by creating a subclass
of :class:`MultiModalDataParser` that has additional subparsers.
"""
return MultiModalDataParser()
def _get_hf_processor(self) -> ProcessorMixin:
"""
Subclasses can add keyword arguments to this method to accept
@ -631,11 +642,16 @@ class BaseMultiModalProcessor(ABC):
def _get_tokenizer(self) -> AnyTokenizer:
return self.ctx.tokenizer
def _get_mm_items(
def _to_mm_items(
self,
mm_data: MultiModalDataDict,
) -> MultiModalDataItems:
return MultiModalDataItems.from_dict(mm_data)
"""
Normalize :class:`MultiModalDataDict` to :class:`MultiModalDataItems`
before passing them to :meth:`_get_hf_mm_data`.
"""
parser = self._get_data_parser()
return parser.parse_mm_data(mm_data)
@abstractmethod
def _get_mm_fields_config(
@ -680,22 +696,9 @@ class BaseMultiModalProcessor(ABC):
processor_data = dict[str, Any]()
passthrough_data = dict[str, Any]()
for k, v in mm_items.items():
# TODO: Make a separate modality for embedding inputs
# to avoid confusion
if k in ("image", "video", "audio"):
if isinstance(v, torch.Tensor) and v.ndim == 3:
# Pass through embedding inputs (single)
passthrough_data[f"{k}_embeds"] = [v]
elif (is_list_of(v, torch.Tensor) and len(v) > 0
and v[0].ndim == 2):
# Pass through embedding inputs (multi)
passthrough_data[f"{k}_embeds"] = v
elif len(v) > 0:
# Map keys to plural form, e.g.: image -> images
processor_data[f"{k}s"] = v
else:
processor_data[k] = v
for items in mm_items.values():
processor_data.update(items.get_processor_data())
passthrough_data.update(items.get_passthrough_data())
return processor_data, passthrough_data
@ -756,7 +759,7 @@ class BaseMultiModalProcessor(ABC):
cached items; instead, we rely on our own prompt replacement logic
for the full text.
"""
mm_missing_counts = mm_missing_data_items.get_item_counts()
mm_missing_counts = mm_missing_data_items.get_all_counts()
prompt_ids, _ = self._apply_hf_processor(
prompt_text=prompt_text,
@ -789,7 +792,8 @@ class BaseMultiModalProcessor(ABC):
cache = self.cache
model_id = self.ctx.model_config.model
if cache is None or mm_data_items.has_embedding_inputs():
_, passthrough_data = self._get_hf_mm_data(mm_data_items)
if cache is None or passthrough_data:
return self._apply_hf_processor(
prompt_text=prompt_text,
mm_items=mm_data_items,
@ -812,7 +816,7 @@ class BaseMultiModalProcessor(ABC):
modality: [mm_data_items[modality][idx] for idx in idxs]
for modality, idxs in mm_missing_idxs.items()
}
mm_missing_data_items = self._get_mm_items(mm_missing_data)
mm_missing_data_items = self._to_mm_items(mm_missing_data)
prompt_ids, mm_missing_kwargs = self._apply_hf_processor_missing(
prompt_text=prompt_text,
@ -852,7 +856,7 @@ class BaseMultiModalProcessor(ABC):
mm_merged_field_items[modality] = merged_modal_items_lst
if self.enable_sanity_checks:
mm_missing_counts = mm_missing_data_items.get_item_counts()
mm_missing_counts = mm_missing_data_items.get_all_counts()
assert all(
item_count == mm_missing_counts[modality]
for modality, item_count in mm_missing_next_idx.items()), dict(
@ -865,7 +869,7 @@ class BaseMultiModalProcessor(ABC):
)
if self.enable_sanity_checks:
mm_item_counts = mm_data_items.get_item_counts()
mm_item_counts = mm_data_items.get_all_counts()
for modality, item_count in mm_item_counts.items():
for item_idx in range(item_count):
@ -958,7 +962,7 @@ class BaseMultiModalProcessor(ABC):
3. Extract information about the placeholder tokens from the
processed token IDs.
"""
mm_items = self._get_mm_items(mm_data)
mm_items = self._to_mm_items(mm_data)
prompt_ids, mm_kwargs = self._cached_apply_hf_processor(
prompt_text,
@ -975,7 +979,7 @@ class BaseMultiModalProcessor(ABC):
# If HF processor already inserts placeholder tokens,
# there is no need for us to insert them
mm_item_counts = mm_items.get_item_counts()
mm_item_counts = mm_items.get_all_counts()
all_placeholders = self._find_placeholders(prompt_repls, prompt_ids,
mm_item_counts)

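`_get_data_parser` is the new extension point in `BaseMultiModalProcessor`: as the docstring above says, extra modalities can be supported by subclassing `MultiModalDataParser` with additional subparsers. A hedged sketch of what such a subclass might look like; the `document` modality, `DocumentProcessorItems`, and `DocumentAwareParser` are hypothetical names, only the base classes come from this diff:

```python
from collections.abc import Mapping
from typing import Any

from vllm.multimodal.parse import (ModalityDataItems, ModalityDataParser,
                                   MultiModalDataParser, ProcessorBatchItems)


class DocumentProcessorItems(ProcessorBatchItems[str]):
    """Items for a hypothetical "document" modality."""

    def __init__(self, data: list[str]) -> None:
        super().__init__(data, "document")


class DocumentAwareParser(MultiModalDataParser):

    def _parse_document_data(self, data: Any) -> ModalityDataItems[Any, Any]:
        return DocumentProcessorItems(data if isinstance(data, list) else [data])

    def _get_subparsers(self) -> Mapping[str, ModalityDataParser]:
        # Keep the built-in audio/image/video subparsers and add our own.
        return {**super()._get_subparsers(),
                "document": self._parse_document_data}


# A processor subclass would then plug it in via:
#     def _get_data_parser(self) -> MultiModalDataParser:
#         return DocumentAwareParser()
items = DocumentAwareParser().parse_mm_data({"document": ["some text"]})
print(items.get_all_counts())  # {'document': 1}
```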
View File

@ -15,7 +15,7 @@ from vllm.transformers_utils.processor import get_video_processor
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.utils import PlaceholderModule, is_list_of
from .base import MediaIO, MultiModalData
from .base import MediaIO, ModalityData
from .image import ImageMediaIO, ImagePlugin
from .inputs import MultiModalKwargs, VideoItem
@ -54,7 +54,7 @@ class VideoPlugin(ImagePlugin):
def _default_input_mapper(
self,
ctx: InputContext,
data: MultiModalData[VideoItem],
data: ModalityData[VideoItem],
**mm_processor_kwargs,
) -> MultiModalKwargs:
model_config = ctx.model_config