# SPDX-License-Identifier: Apache-2.0

from abc import ABC, abstractmethod
from collections import UserDict
from collections.abc import Callable, Iterator, Mapping, Sequence
from typing import (TYPE_CHECKING, Any, Generic, Literal, NamedTuple, Optional,
                    TypeVar, Union)

import numpy as np
import torch
from typing_extensions import TypeAlias, TypeGuard, assert_never

from vllm.utils import LazyLoader, is_list_of

from .audio import AudioResampler
from .inputs import (AudioItem, HfAudioItem, HfImageItem, HfVideoItem,
                     ImageItem, ModalityData, MultiModalDataDict,
                     MultiModalFieldConfig, MultiModalKwargs, VideoItem)

_T = TypeVar("_T")
_I = TypeVar("_I")

if TYPE_CHECKING:
    import PIL.Image as PILImage
else:
    PILImage = LazyLoader("PILImage", globals(), "PIL.Image")


class ModalityDataItems(ABC, Generic[_T, _I]):
    """
    Represents data items for a modality in {class}`MultiModalDataItems`.
    """

    def __init__(self, data: _T, modality: str) -> None:
        super().__init__()

        self.data = data
        self.modality = modality

    def __repr__(self) -> str:
        return (f"{type(self).__name__}(modality={self.modality!r}, "
                f"len={len(self)})")

    def __len__(self) -> int:
        return self.get_count()

    def __getitem__(self, index: int) -> _I:
        return self.get(index)

    if TYPE_CHECKING:
        # Auto-generated
        def __iter__(self) -> Iterator[_I]:
            ...

    @abstractmethod
    def get_count(self) -> int:
        """Get the number of data items."""
        raise NotImplementedError

    @abstractmethod
    def get(self, index: int) -> _I:
        """Get a data item by its index."""
        raise NotImplementedError

    def get_all(self) -> list[_I]:
        """Get all data items."""
        return [self.get(idx) for idx in range(self.get_count())]

    @abstractmethod
    def get_processor_data(self) -> Mapping[str, object]:
        """Get the data to pass to the HF processor."""
        raise NotImplementedError

    @abstractmethod
    def get_passthrough_data(self) -> Mapping[str, object]:
        """Get the data to pass directly to the model."""
        raise NotImplementedError


class ProcessorBatchItems(ModalityDataItems[Sequence[_T], _T]):
    """Base class for data items that are arranged in a list."""

    def get_count(self) -> int:
        return len(self.data)

    def get(self, index: int) -> _T:
        return self.data[index]

    def get_processor_data(self) -> Mapping[str, object]:
        return {f"{self.modality}s": self.data}

    def get_passthrough_data(self) -> Mapping[str, object]:
        return {}


class EmbeddingItems(ModalityDataItems[Union[torch.Tensor, list[torch.Tensor]],
                                       torch.Tensor]):
    """
    Base class for data items that are expressed as a batched embedding tensor,
    or a list of embedding tensors (one per item).
    """

    def get_count(self) -> int:
        return len(self.data)

    def get(self, index: int) -> torch.Tensor:
        return self.data[index]

    def get_processor_data(self) -> Mapping[str, object]:
        return {}

    def get_passthrough_data(self) -> Mapping[str, object]:
        return {f"{self.modality}_embeds": self.data}

    def get_feature_size(self, item_idx: int) -> int:
        return len(self.get(item_idx))
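

# The two base containers above encode the central split in this module: raw
# media items are exposed to the HF processor via `get_processor_data`, while
# precomputed embeddings skip the processor and flow straight to the model via
# `get_passthrough_data`. A minimal sketch (illustrative only; the helper
# function below is hypothetical and not part of the upstream API):
def _example_processor_vs_embedding_items() -> None:
    # Raw items: surfaced under the "<modality>s" key for the HF processor.
    raw = ProcessorBatchItems([[0.0, 0.1], [0.2]], "audio")
    assert raw.get_processor_data() == {"audios": [[0.0, 0.1], [0.2]]}
    assert raw.get_passthrough_data() == {}

    # Embeddings: a batched 3-D tensor (one 2-D embedding per item) bypasses
    # the HF processor entirely.
    embeds = EmbeddingItems(torch.rand(2, 4, 8), "image")
    assert embeds.get_count() == 2
    assert embeds.get_processor_data() == {}
    assert embeds.get_passthrough_data().keys() == {"image_embeds"}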


class DictEmbeddingItems(ModalityDataItems[Mapping[str, torch.Tensor],
                                           Mapping[str, torch.Tensor]]):
    """
    Base class for data items that are expressed as a dictionary of tensors.

    Usually, the dictionary keys correspond to the outputs of the HF processor.
    """

    def __init__(
        self,
        data: Mapping[str, torch.Tensor],
        modality: str,
        required_fields: set[str],
        fields_factory: Callable[
            [Mapping[str, torch.Tensor]],
            Mapping[str, MultiModalFieldConfig],
        ],
    ) -> None:
        from transformers.feature_extraction_utils import BatchFeature

        super().__init__(data, modality)

        missing_required_data_keys = required_fields - data.keys()
        if missing_required_data_keys:
            data_keys = set(data.keys())
            msg = (f"The data should contain the fields: {required_fields}, "
                   f"but only found the following keys: {data_keys}")
            raise ValueError(msg)

        fields_config = fields_factory(data)
        missing_required_fields = required_fields - fields_config.keys()
        if missing_required_fields:
            fields = set(fields_config.keys())
            msg = f"{required_fields=} should be a subset of {fields=}"
            raise ValueError(msg)

        self.fields_config = fields_config
        self.required_fields = required_fields

        self._kwargs = MultiModalKwargs.from_hf_inputs(
            BatchFeature(dict(data)),
            fields_config,
        )

    def get_count(self) -> int:
        return self._kwargs.get_item_count(self.modality)

    def get(self, index: int) -> Mapping[str, torch.Tensor]:
        return {
            k: v.data
            for k, v in self._kwargs.get_item(self.modality, index).items()
        }

    def get_processor_data(self) -> Mapping[str, object]:
        return {}

    def get_passthrough_data(self) -> Mapping[str, object]:
        return self.data
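

# A short sketch (illustrative only; the helper function and shapes below are
# made up for this example) of how `DictEmbeddingItems` is typically fed: the
# dict mirrors HF-processor outputs, `required_fields` guards against missing
# keys, and `fields_factory` describes how each tensor is sliced per item.
def _example_dict_embedding_items() -> None:
    data = {"pixel_values": torch.rand(2, 3, 224, 224)}
    items = DictEmbeddingItems(
        data,
        modality="image",
        required_fields={"pixel_values"},
        fields_factory=lambda hf_inputs: {
            "pixel_values": MultiModalFieldConfig.batched("image"),
        },
    )
    assert items.get_count() == 2
    # The whole dict is passed through to the model unchanged.
    assert items.get_passthrough_data() is data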


class AudioProcessorItems(ProcessorBatchItems[HfAudioItem]):

    def __init__(self, data: Sequence[HfAudioItem]) -> None:
        super().__init__(data, "audio")

    def get_audio_length(self, item_idx: int) -> int:
        audio = self.get(item_idx)
        return len(audio)


class AudioEmbeddingItems(EmbeddingItems):

    def __init__(self, data: Union[torch.Tensor, list[torch.Tensor]]) -> None:
        super().__init__(data, "audio")


class ImageSize(NamedTuple):
    width: int
    height: int


class ImageProcessorItems(ProcessorBatchItems[HfImageItem]):

    def __init__(self, data: Sequence[HfImageItem]) -> None:
        super().__init__(data, "image")

    def get_image_size(self, item_idx: int) -> ImageSize:
        image = self.get(item_idx)

        if isinstance(image, PILImage.Image):
            return ImageSize(*image.size)
        if isinstance(image, (np.ndarray, torch.Tensor)):
            _, h, w = image.shape
            return ImageSize(w, h)

        assert_never(image)
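

# Note the axis convention in `get_image_size` above: PIL reports
# (width, height) directly, while array/tensor images are treated as
# channels-first (C, H, W), so the trailing dims are reordered into
# `ImageSize`. A tiny sketch (illustrative only; not part of the module):
def _example_image_size() -> None:
    items = ImageProcessorItems([np.zeros((3, 480, 640), dtype=np.uint8)])
    assert items.get_image_size(0) == ImageSize(width=640, height=480)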


class ImageEmbeddingItems(EmbeddingItems):

    def __init__(self, data: Union[torch.Tensor, list[torch.Tensor]]) -> None:
        super().__init__(data, "image")


class VideoProcessorItems(ProcessorBatchItems[HfVideoItem]):

    def __init__(self, data: Sequence[HfVideoItem]) -> None:
        super().__init__(data, "video")

    def get_num_frames(self, item_idx: int) -> int:
        return len(self.get(item_idx))

    def get_frame_size(self, item_idx: int) -> ImageSize:
        image = self.get(item_idx)[0]  # Assume that the video isn't empty

        if isinstance(image, PILImage.Image):
            return ImageSize(*image.size)
        if isinstance(image, (np.ndarray, torch.Tensor)):
            _, h, w = image.shape
            return ImageSize(w, h)

        assert_never(image)


class VideoEmbeddingItems(EmbeddingItems):

    def __init__(self, data: Union[torch.Tensor, list[torch.Tensor]]) -> None:
        super().__init__(data, "video")


_D = TypeVar("_D", bound=ModalityDataItems[Any, Any])


class MultiModalDataItems(UserDict[str, ModalityDataItems[Any, Any]]):
    """
    As {data}`~vllm.multimodal.inputs.MultiModalDataDict`, but normalized
    such that each entry corresponds to a list.
    """

    def get_count(self, modality: str, *, strict: bool = True) -> int:
        """
        Get the number of data items belonging to a modality.

        If `strict=False`, return `0` instead of raising {exc}`KeyError`
        even if the modality is not found.
        """
        if modality not in self:
            if strict:
                available_modalities = set(self.keys())
                raise KeyError(f"Modality {modality!r} not found. "
                               f"Available modalities: {available_modalities}")

            return 0

        return self[modality].get_count()

    def get_all_counts(self) -> Mapping[str, int]:
        """Get the number of items belonging to each modality."""
        return {m: items.get_count() for m, items in self.items()}

    def get_items(
        self,
        modality: str,
        typ: Union[type[_D], tuple[type[_D], ...]],
    ) -> _D:
        """
        Get the data items belonging to a modality,
        requiring that they belong to a certain type.
        """
        if modality not in self:
            available_modalities = set(self.keys())
            raise KeyError(f"Modality {modality!r} not found. "
                           f"Available modalities: {available_modalities}")

        items = self[modality]
        if not isinstance(items, typ):
            raise TypeError(f"Invalid type of data items for {modality=}. "
                            f"Expected type: {typ}, but "
                            f"found type: {type(items)}")

        return items  # type: ignore[return-value]
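

# A short sketch (illustrative only; the helper function below is not part of
# the module) of the lookup pattern built on `get_items`: callers ask for a
# modality and simultaneously narrow the container type, getting a clear
# KeyError/TypeError instead of a silent mismatch.
def _example_get_items() -> None:
    mm_items = MultiModalDataItems({
        "image": ImageProcessorItems([np.zeros((3, 32, 32), dtype=np.uint8)]),
    })

    images = mm_items.get_items("image",
                                (ImageEmbeddingItems, ImageProcessorItems))
    assert isinstance(images, ImageProcessorItems)
    assert mm_items.get_all_counts() == {"image": 1}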


ModalityDataParser: TypeAlias = Callable[[ModalityData[Any]],
                                         Optional[ModalityDataItems[Any, Any]]]


class MultiModalDataParser:
    """
    Parses {data}`~vllm.multimodal.inputs.MultiModalDataDict` into
    {class}`MultiModalDataItems`.

    Args:
        target_sr (float, optional): Enables automatic resampling of audio
            items to the model's expected sampling rate.
        audio_resample_method: The backend used for resampling, either
            "librosa" or "scipy".
    """

    def __init__(
        self,
        *,
        target_sr: Optional[float] = None,
        audio_resample_method: Literal["librosa", "scipy"] = "librosa",
    ) -> None:
        super().__init__()

        self.audio_resampler = AudioResampler(
            target_sr=target_sr,
            method=audio_resample_method,
        )

    def _is_embeddings(
            self, data: object
    ) -> TypeGuard[Union[torch.Tensor, list[torch.Tensor]]]:
        # Embeddings are either a batched 3-D tensor or a list of 2-D tensors
        # (one per item).
        if isinstance(data, torch.Tensor):
            return data.ndim == 3
        if is_list_of(data, torch.Tensor):
            return data[0].ndim == 2

        return False

    def _is_empty(self, data: object) -> TypeGuard[None]:
        if isinstance(data, list):
            return len(data) == 0
        if isinstance(data, np.ndarray):
            return data.size == 0
        if isinstance(data, torch.Tensor):
            # `Tensor.size` is a method, so use `numel()` for the element count.
            return data.numel() == 0

        return False

    def _get_audio_with_sr(
        self,
        audio: AudioItem,
    ) -> tuple[np.ndarray, Optional[float]]:
        if isinstance(audio, tuple):
            return audio
        if isinstance(audio, list):
            return np.array(audio), None
        if isinstance(audio, np.ndarray):
            return audio, None
        if isinstance(audio, torch.Tensor):
            return audio.numpy(), None

        assert_never(audio)

    def _parse_audio_data(
        self,
        data: ModalityData[AudioItem],
    ) -> Optional[ModalityDataItems[Any, Any]]:
        # also check single audio item with sampling rate
        if self._is_empty(data) or (isinstance(data, tuple)
                                    and self._is_empty(data[0])):
            return None

        if self._is_embeddings(data):
            return AudioEmbeddingItems(data)

        if (is_list_of(data, float)
                or isinstance(data,
                              (np.ndarray, torch.Tensor)) and data.ndim == 1
                or isinstance(data, tuple)):
            data_items = [data]
        elif isinstance(data, (np.ndarray, torch.Tensor)):
            data_items = [elem for elem in data]
        else:
            data_items = data

        new_audios = list[np.ndarray]()
        for data_item in data_items:
            audio, orig_sr = self._get_audio_with_sr(data_item)
            if orig_sr is None:
                new_audio = audio
            else:
                new_audio = self.audio_resampler.resample(audio,
                                                          orig_sr=orig_sr)

            new_audios.append(new_audio)

        return AudioProcessorItems(new_audios)

    def _parse_image_data(
        self,
        data: ModalityData[ImageItem],
    ) -> Optional[ModalityDataItems[Any, Any]]:
        if self._is_empty(data):
            return None

        if self._is_embeddings(data):
            return ImageEmbeddingItems(data)

        if (isinstance(data, PILImage.Image)
                or isinstance(data,
                              (np.ndarray, torch.Tensor)) and data.ndim == 3):
            data_items = [data]
        elif isinstance(data, (np.ndarray, torch.Tensor)):
            data_items = [elem for elem in data]
        else:
            data_items = data

        return ImageProcessorItems(data_items)

    def _parse_video_data(
        self,
        data: ModalityData[VideoItem],
    ) -> Optional[ModalityDataItems[Any, Any]]:
        if self._is_empty(data):
            return None

        if self._is_embeddings(data):
            return VideoEmbeddingItems(data)

        if (is_list_of(data, PILImage.Image)
                or isinstance(data,
                              (np.ndarray, torch.Tensor)) and data.ndim == 4):
            data_items = [data]
        elif isinstance(data, (np.ndarray, torch.Tensor)):
            data_items = [elem for elem in data]
        else:
            data_items = data

        return VideoProcessorItems(data_items)

    def _get_subparsers(self) -> Mapping[str, ModalityDataParser]:
        return {
            "audio": self._parse_audio_data,
            "image": self._parse_image_data,
            "video": self._parse_video_data,
        }

    def parse_mm_data(self,
                      mm_data: MultiModalDataDict) -> MultiModalDataItems:
        subparsers = self._get_subparsers()

        mm_items = MultiModalDataItems()
        for k, v in mm_data.items():
            if k not in subparsers:
                raise ValueError(f"Unsupported modality: {k}")

            # ignore empty embedding data
            if (parsed_data := subparsers[k](v)) is not None:
                mm_items[k] = parsed_data

        return mm_items
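

# End-to-end sketch (illustrative only; the helper function and array shapes
# below are made up) of the parser: a `MultiModalDataDict` with one
# channels-first image and one 8-frame video is normalized into per-modality
# item lists that downstream processors can index uniformly.
def _example_parse_mm_data() -> None:
    parser = MultiModalDataParser()
    mm_items = parser.parse_mm_data({
        "image": np.zeros((3, 64, 64), dtype=np.uint8),
        "video": np.zeros((8, 3, 64, 64), dtype=np.uint8),
    })
    assert mm_items.get_all_counts() == {"image": 1, "video": 1}

    videos = mm_items.get_items("video", VideoProcessorItems)
    assert videos.get_num_frames(0) == 8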