[VLM] Abstract out multi-modal data parsing in merged processor (#11620)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent b12e87f942
commit 8d9b6721e7
@@ -356,7 +356,7 @@ steps:
     - pytest -v -s models/decoder_only/language -m 'not core_model and not quant_model'
     - pytest -v -s models/embedding/language -m 'not core_model'

-- label: Multi-Modal Models Test (Standard) # 28min
+- label: Multi-Modal Models Test (Standard) # 40min
   #mirror_hardwares: [amd]
   source_file_dependencies:
   - vllm/
@@ -372,7 +372,7 @@ steps:
     - pytest -v -s models/encoder_decoder/language -m core_model
     - pytest -v -s models/encoder_decoder/vision_language -m core_model

-- label: Multi-Modal Models Test (Extended) 1 # 1h16m
+- label: Multi-Modal Models Test (Extended) 1 # 48m
   optional: true
   source_file_dependencies:
   - vllm/
@@ -33,7 +33,7 @@ from vllm.model_executor.models.glm4_vision_encoder import EVA2CLIPModel
 from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import (MultiModalData, MultiModalKwargs,
+from vllm.multimodal.inputs import (ModalityData, MultiModalKwargs,
                                     NestedTensors)
 from vllm.multimodal.utils import cached_get_tokenizer
 from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
@@ -54,7 +54,7 @@ def calculate_image_placeholder(vision_config):

 def mm_input_mapper_for_glmv(
     ctx: InputContext,
-    data: MultiModalData[object],
+    data: ModalityData[object],
 ) -> Dict:
     model_config = ctx.model_config
     tokenizer = cached_get_tokenizer(
@@ -20,11 +20,13 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalDataItems,
-                                    MultiModalFieldConfig, MultiModalInputsV2,
-                                    MultiModalKwargs, NestedTensors)
+from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
+                                    MultiModalInputsV2, MultiModalKwargs,
+                                    NestedTensors)
+from vllm.multimodal.parse import ImageProcessorItems
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                        ProcessorInputs, PromptReplacement,
+                                        MultiModalDataItems, ProcessorInputs,
+                                        PromptReplacement,
                                         full_groupby_modality)
 from vllm.sequence import IntermediateTensors

@@ -179,7 +181,9 @@ class LlavaMultiModalProcessor(BaseMultiModalProcessor):
         assert isinstance(vision_config, PixtralVisionConfig)

         def get_replacement_pixtral(item_idx: int):
-            image_size = mm_items.get_image_size(item_idx)
+            images = mm_items.get_items("image", ImageProcessorItems)
+            image_size = images.get_image_size(item_idx)
+
             (
                 num_width_tokens,
                 num_height_tokens,
@@ -591,8 +595,8 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):

         result = super().apply(prompt_text, mm_data, hf_processor_mm_kwargs)

-        mm_items = self._get_mm_items(mm_data)
-        mm_item_counts = mm_items.get_item_counts()
+        mm_items = self._to_mm_items(mm_data)
+        mm_item_counts = mm_items.get_all_counts()
         mm_kwargs = result["mm_kwargs"]

         # We reimplement the functionality of MLlavaProcessor from
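Illustrative usage of the item-access pattern adopted above (a hedged sketch, not part of this commit's diff; the image sizes are made up):

    from PIL import Image

    from vllm.multimodal.parse import ImageProcessorItems, MultiModalDataItems

    # Build the normalized structure that _to_mm_items() would return for two images.
    mm_items = MultiModalDataItems(
        {"image": ImageProcessorItems([Image.new("RGB", (336, 336)),
                                       Image.new("RGB", (672, 448))])})

    # Same calls as in get_replacement_pixtral() above.
    images = mm_items.get_items("image", ImageProcessorItems)
    print(images.get_image_size(1))   # ImageSize(width=672, height=448)
    print(mm_items.get_all_counts())  # {'image': 2}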
@@ -32,12 +32,13 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 from vllm.model_executor.models.clip import CLIPVisionModel
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalDataItems,
-                                    MultiModalFieldConfig, MultiModalInputsV2,
-                                    MultiModalKwargs, NestedTensors,
-                                    PlaceholderRange)
+from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
+                                    MultiModalInputsV2, MultiModalKwargs,
+                                    NestedTensors, PlaceholderRange)
+from vllm.multimodal.parse import ImageProcessorItems
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                        ProcessorInputs, PromptReplacement,
+                                        MultiModalDataItems, ProcessorInputs,
+                                        PromptReplacement,
                                         _BoundPromptReplacement,
                                         _PlaceholderInfo)
 from vllm.sequence import IntermediateTensors
@@ -381,7 +382,9 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor):
         assert isinstance(bos_token_id, int)

         def get_replacement_phi3v(item_idx: int):
-            image_size = mm_items.get_image_size(item_idx)
+            images = mm_items.get_items("image", ImageProcessorItems)
+            image_size = images.get_image_size(item_idx)
+
             num_tokens = image_processor.calc_num_image_tokens_from_image_size(
                 width=image_size.width,
                 height=image_size.height,
@@ -389,12 +392,14 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor):

             return [_IMAGE_TOKEN_ID] * num_tokens + [bos_token_id]

+        num_images = mm_items.get_count("image", strict=False)
+
         return [
             PromptReplacement(
                 modality="image",
                 target=image_token,
                 replacement=get_replacement_phi3v,
-            ) for image_token in image_tokens[:len(mm_items.images)]
+            ) for image_token in image_tokens[:num_images]
         ]

     def _apply_prompt_replacements(
@@ -20,8 +20,8 @@
 # limitations under the License.
 """Inference-only Qwen2-Audio model compatible with HuggingFace weights."""
 from functools import cached_property
-from typing import (Any, Iterable, List, Mapping, Optional, Set, Tuple,
-                    TypedDict, Union)
+from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict,
+                    Union)

 import numpy as np
 import torch
@@ -38,10 +38,12 @@ from vllm.inputs import InputContext
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import (MultiModalDataItems, MultiModalFieldConfig,
-                                    MultiModalKwargs, NestedTensors)
+from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
+                                    NestedTensors)
+from vllm.multimodal.parse import MultiModalDataParser
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                        ProcessorInputs, PromptReplacement)
+                                        MultiModalDataItems, ProcessorInputs,
+                                        PromptReplacement)
 from vllm.sequence import IntermediateTensors

 from .interfaces import SupportsMultiModal, SupportsPP
@@ -99,15 +101,9 @@ class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor):
     def _get_feature_extractor(self) -> WhisperFeatureExtractor:
         return self._get_hf_processor().feature_extractor  # type: ignore

-    def _get_hf_mm_data(
-        self,
-        mm_items: MultiModalDataItems,
-    ) -> tuple[dict[str, Any], dict[str, Any]]:
-        # resample audio to the model's sampling rate
+    def _get_data_parser(self) -> MultiModalDataParser:
         feature_extractor = self._get_feature_extractor()
-        mm_items.resample_audios(feature_extractor.sampling_rate)
-
-        return super()._get_hf_mm_data(mm_items)
+        return MultiModalDataParser(target_sr=feature_extractor.sampling_rate)

     def _call_hf_processor(
         self,
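For context, a hedged sketch of what the new _get_data_parser hook buys (not from the diff; the 16 kHz target rate is an assumption):

    import numpy as np

    from vllm.multimodal.parse import MultiModalDataParser

    # Resampling now happens while parsing, so the HF processor only ever sees
    # audio at the model's expected sampling rate.
    parser = MultiModalDataParser(target_sr=16000)
    items = parser.parse_mm_data(
        {"audio": (np.zeros(44100, dtype=np.float32), 44100)})
    print(items.get_all_counts())  # {'audio': 1}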
@@ -25,7 +25,6 @@ from functools import cached_property, partial
 from typing import (Any, Callable, Iterable, List, Literal, Mapping, Optional,
                     Set, Tuple, Type, TypedDict, Union)

-import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -55,15 +54,16 @@ from vllm.model_executor.layers.quantization.gptq_marlin import (
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalDataItems,
+from vllm.multimodal.inputs import (ImageItem, ModalityData,
                                     MultiModalFieldConfig, MultiModalKwargs,
-                                    NestedTensors)
+                                    NestedTensors, VideoItem)
+from vllm.multimodal.parse import ModalityDataItems, MultiModalDataParser
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                        ProcessorInputs, PromptReplacement)
+                                        MultiModalDataItems, ProcessorInputs,
+                                        PromptReplacement)
 from vllm.platforms import _Backend
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.config import uses_mrope
-from vllm.utils import is_list_of

 from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
 from .utils import (AutoWeightsLoader, WeightsMapper, get_vit_attn_backend,
@@ -719,61 +719,81 @@ get_max_qwen2_vl_video_tokens = partial(get_max_qwen2_vl_mm_tokens,
                                         data_type_key="video")


-class Qwen2VLMultiModalDataItems(MultiModalDataItems):
+class Qwen2EmbeddingItems(ModalityDataItems[dict[str, torch.Tensor],
+                                            dict[str, torch.Tensor]]):

-    @staticmethod
-    def from_dict(data: MultiModalDataDict) -> "MultiModalDataItems":
-        """
-        Normalize :class:`MultiModalDataDict` to :class:`MultiModalDataItems`.
-        """
-        multi_data = Qwen2VLMultiModalDataItems()
+    def __init__(self, data: dict, modality: str) -> None:
+        super().__init__(data)

-        for k, v in data.items():
-            # TODO: Make a separate modality for embedding inputs
-            # to avoid confusion
-            # yapf: disable
-            if k == "video":
-                # Special case since even a single item can be a list
-                multi_data[k] = (  # type: ignore[index]
-                    v if (
-                        isinstance(v, (dict, torch.Tensor))  # type: ignore[assignment]
-                        or is_list_of(v, list)
-                        or isinstance(v[0], (np.ndarray, torch.Tensor))
-                        and v[0].ndim == 4
-                    ) else [v]
-                )
-            elif k in ("image", "audio"):
-                multi_data[k] = (  # type: ignore[index]
-                    v if isinstance(v, (dict, torch.Tensor, list)) else [v]
-                )
-            else:
-                multi_data[k] = v if isinstance(v, list) else [v]  # type: ignore[index]
-            # yapf: enable
+        self.modality = modality

-        return multi_data
+        grid_thw = data[f"{modality}_grid_thw"]
+        slice_idxs = [0] + grid_thw.prod(-1).cumsum_(0).tolist()
+        self._slices = [
+            slice(slice_idxs[i], slice_idxs[i + 1])
+            for i in range(len(grid_thw))
+        ]

-    def get_item_counts(self) -> Mapping[str, int]:
-        return {
-            m: (
-                len(items[f"{m}_grid_thw"])  # type: ignore
-                if isinstance(items, dict) else len(items))
-            for m, items in self.items()
-        }
+    def __repr__(self) -> str:
+        return (f"{type(self).__name__}(modality={self.modality!r})")

-    def has_embedding_inputs(self) -> bool:
-        return any(
-            isinstance(items, dict) or any(
-                isinstance(item, torch.Tensor) for item in items)
-            for items in self.values())
+    def get_count(self) -> int:
+        return len(self.data[f"{self.modality}_grid_thw"])
+
+    def get(self, index: int) -> dict[str, torch.Tensor]:
+        out = {}
+        for k, v in self.data.items():
+            if v != f"{self.modality}_grid_thw":
+                v = v[self._slices[index]]
+
+            out[k] = v
+
+        return out
+
+    def get_processor_data(self) -> Mapping[str, object]:
+        return {}
+
+    def get_passthrough_data(self) -> Mapping[str, object]:
+        return self.data
+
+
+class Qwen2ImageEmbeddingItems(Qwen2EmbeddingItems):
+
+    def __init__(self, data: dict) -> None:
+        super().__init__(data, "image")
+
+
+class Qwen2VideoEmbeddingItems(Qwen2EmbeddingItems):
+
+    def __init__(self, data: dict) -> None:
+        super().__init__(data, "video")
+
+
+class Qwen2MultiModalDataParser(MultiModalDataParser):
+
+    def _parse_image_data(
+        self,
+        data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]],
+    ) -> ModalityDataItems[Any, Any]:
+        if isinstance(data, dict):
+            return Qwen2EmbeddingItems(data, modality="image")
+
+        return super()._parse_image_data(data)
+
+    def _parse_video_data(
+        self,
+        data: Union[dict[str, torch.Tensor], ModalityData[VideoItem]],
+    ) -> ModalityDataItems[Any, Any]:
+        if isinstance(data, dict):
+            return Qwen2EmbeddingItems(data, modality="video")
+
+        return super()._parse_video_data(data)


 class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor):

-    def _get_mm_items(
-        self,
-        mm_data: MultiModalDataDict,
-    ) -> MultiModalDataItems:
-        return Qwen2VLMultiModalDataItems.from_dict(mm_data)
+    def _get_data_parser(self) -> MultiModalDataParser:
+        return Qwen2MultiModalDataParser()

     def _get_hf_processor(
         self,
@@ -796,35 +816,6 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor):

         return hf_processor

-    def _get_hf_mm_data(
-        self,
-        mm_items: MultiModalDataItems,
-    ) -> tuple[dict[str, Any], dict[str, Any]]:
-        processor_data = dict[str, Any]()
-        passthrough_data = dict[str, Any]()
-
-        for k, v in mm_items.items():
-            # TODO: Make a separate modality for embedding inputs
-            # to avoid confusion
-            if k in ("image", "video", "audio"):
-                if isinstance(v, dict):
-                    # Pass through embedding inputs (dict)
-                    passthrough_data.update(v)
-                elif isinstance(v, torch.Tensor) and v.ndim == 3:
-                    # Pass through embedding inputs (single)
-                    passthrough_data[f"{k}_embeds"] = [v]
-                elif (is_list_of(v, torch.Tensor) and len(v) > 0
-                      and v[0].ndim == 2):
-                    # Pass through embedding inputs (multi)
-                    passthrough_data[f"{k}_embeds"] = v
-                elif len(v) > 0:
-                    # Map keys to plural form, e.g.: image -> images
-                    processor_data[f"{k}s"] = v
-            else:
-                processor_data[k] = v
-
-        return processor_data, passthrough_data
-
     def _get_prompt_replacements(
         self,
         mm_items: MultiModalDataItems,
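A hedged sketch (tensor shapes are invented) of how the custom parser above routes dict-shaped embedding inputs; not part of the diff:

    import torch

    # Qwen2MultiModalDataParser / Qwen2EmbeddingItems as defined in the hunk above.
    parser = Qwen2MultiModalDataParser()

    items = parser.parse_mm_data({
        "image": {
            "image_embeds": torch.randn(20, 1152),  # hypothetical hidden size
            "image_grid_thw": torch.tensor([[1, 4, 4], [1, 2, 2]]),
        },
    })

    embeds = items["image"]                      # Qwen2EmbeddingItems(modality='image')
    print(items.get_all_counts())                # {'image': 2}
    print(embeds.get_passthrough_data().keys())  # passed straight through to the model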
@@ -3,8 +3,8 @@

 import math
 from functools import cached_property, lru_cache
-from typing import (Any, Iterable, List, Literal, Mapping, Optional, Set,
-                    Tuple, TypedDict, Union)
+from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
+                    TypedDict, Union)

 import numpy as np
 import torch
@@ -24,10 +24,12 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.model_loader.loader import DefaultModelLoader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import (MultiModalDataItems, MultiModalFieldConfig,
-                                    MultiModalKwargs, NestedTensors)
+from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
+                                    NestedTensors)
+from vllm.multimodal.parse import MultiModalDataParser
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                        ProcessorInputs, PromptReplacement)
+                                        MultiModalDataItems, ProcessorInputs,
+                                        PromptReplacement)
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.configs.ultravox import UltravoxConfig
 from vllm.utils import is_list_of
@@ -85,15 +87,9 @@ class UltravoxMultiModalProcessor(BaseMultiModalProcessor):
         hf_processor = self._get_hf_processor()
         return hf_processor.audio_processor.feature_extractor  # type: ignore

-    def _get_hf_mm_data(
-        self,
-        mm_items: MultiModalDataItems,
-    ) -> tuple[dict[str, Any], dict[str, Any]]:
-        # resample audio to the model's sampling rate
+    def _get_data_parser(self) -> MultiModalDataParser:
         feature_extractor = self._get_feature_extractor()
-        mm_items.resample_audios(feature_extractor.sampling_rate)
-
-        return super()._get_hf_mm_data(mm_items)
+        return MultiModalDataParser(target_sr=feature_extractor.sampling_rate)

     def _call_hf_processor(
         self,
@@ -1,8 +1,7 @@
 from .base import MultiModalPlaceholderMap, MultiModalPlugin
-from .inputs import (BatchedTensorInputs, MultiModalData,
-                     MultiModalDataBuiltins, MultiModalDataDict,
-                     MultiModalKwargs, MultiModalPlaceholderDict,
-                     NestedTensors)
+from .inputs import (BatchedTensorInputs, ModalityData, MultiModalDataBuiltins,
+                     MultiModalDataDict, MultiModalKwargs,
+                     MultiModalPlaceholderDict, NestedTensors)
 from .registry import MultiModalRegistry

 MULTIMODAL_REGISTRY = MultiModalRegistry()
@@ -16,7 +15,7 @@ See also:

 __all__ = [
     "BatchedTensorInputs",
-    "MultiModalData",
+    "ModalityData",
     "MultiModalDataBuiltins",
     "MultiModalDataDict",
     "MultiModalKwargs",
@@ -9,7 +9,7 @@ from vllm.inputs.registry import InputContext
 from vllm.utils import PlaceholderModule

 from .base import MediaIO, MultiModalPlugin
-from .inputs import AudioItem, MultiModalData, MultiModalKwargs
+from .inputs import AudioItem, ModalityData, MultiModalKwargs

 try:
     import librosa
@@ -31,7 +31,7 @@ class AudioPlugin(MultiModalPlugin):
     def _default_input_mapper(
         self,
         ctx: InputContext,
-        data: MultiModalData[AudioItem],
+        data: ModalityData[AudioItem],
         **mm_processor_kwargs,
     ) -> MultiModalKwargs:
         raise NotImplementedError("There is no default audio input mapper")
@@ -15,12 +15,12 @@ if TYPE_CHECKING:
     from vllm.config import ModelConfig
     from vllm.sequence import SequenceGroupMetadata

-from .inputs import (MultiModalData, MultiModalDataDict, MultiModalKwargs,
+from .inputs import (ModalityData, MultiModalDataDict, MultiModalKwargs,
                      PlaceholderRange)

 logger = init_logger(__name__)

-MultiModalInputMapper = Callable[[InputContext, MultiModalData[object]],
+MultiModalInputMapper = Callable[[InputContext, ModalityData[object]],
                                  MultiModalKwargs]
 """
 Return a dictionary to be passed as keyword arguments to
@@ -69,7 +69,7 @@ class MultiModalPlugin(ABC):
     def _default_input_mapper(
         self,
         ctx: InputContext,
-        data: MultiModalData[Any],
+        data: ModalityData[Any],
         **mm_processor_kwargs,
     ) -> MultiModalKwargs:
         """
@@ -118,7 +118,7 @@ class MultiModalPlugin(ABC):
     def map_input(
         self,
         model_config: "ModelConfig",
-        data: MultiModalData[Any],
+        data: ModalityData[Any],
         mm_processor_kwargs: Optional[dict[str, Any]],
     ) -> MultiModalKwargs:
         """
@@ -13,7 +13,7 @@ from vllm.transformers_utils.processor import get_image_processor
 from vllm.utils import is_list_of

 from .base import MediaIO, MultiModalPlugin
-from .inputs import ImageItem, MultiModalData, MultiModalKwargs
+from .inputs import ImageItem, ModalityData, MultiModalKwargs

 if TYPE_CHECKING:
     from vllm.config import ModelConfig
@@ -44,7 +44,7 @@ class ImagePlugin(MultiModalPlugin):
     def _default_input_mapper(
         self,
         ctx: InputContext,
-        data: MultiModalData[ImageItem],
+        data: ModalityData[ImageItem],
         **mm_processor_kwargs,
     ) -> MultiModalKwargs:
         model_config = ctx.model_config
@@ -2,53 +2,74 @@ from abc import ABC, abstractmethod
 from collections import UserDict, defaultdict
 from collections.abc import Mapping, Sequence
 from dataclasses import dataclass
-from typing import (Any, Literal, NamedTuple, TypedDict, TypeVar, Union, cast,
-                    final)
+from typing import Any, Literal, TypedDict, TypeVar, Union, cast, final

 import numpy as np
 import torch
 import torch.types
 from PIL.Image import Image
 from transformers import BatchFeature
-from typing_extensions import NotRequired, TypeAlias, assert_never
+from typing_extensions import NotRequired, TypeAlias

 from vllm.utils import JSONTree, is_list_of, json_map_leaves

 _T = TypeVar("_T")

 # yapf: disable
-ImageItem: TypeAlias = Union[Image, np.ndarray, torch.Tensor]
+HfImageItem: TypeAlias = Union[Image, np.ndarray, torch.Tensor]
 """
 A :class:`transformers.image_utils.ImageInput` representing a single image
 item, which can be passed to a HuggingFace :code:`ImageProcessor`.
 """

-VideoItem: TypeAlias = Union[
-    list[Image],
-    np.ndarray,
-    torch.Tensor,
-    list[np.ndarray],
-    list[torch.Tensor],
-]
+HfVideoItem: TypeAlias = Union[list[Image], np.ndarray, torch.Tensor,
+                               list[np.ndarray], list[torch.Tensor]]
 """
 A :class:`transformers.image_utils.VideoInput` representing a single video
 item, which can be passed to a HuggingFace :code:`VideoProcessor`.
 """

-AudioItem: TypeAlias = Union[
-    np.ndarray,
-    list[float],
-    # `(audio, sampling_rate)`: If the audio's sampling rate is different
-    # from that expected by the model, we need to resample it.
-    tuple[np.ndarray, float],
-]
+HfAudioItem: TypeAlias = Union[list[float], np.ndarray, torch.Tensor]
 """
 Represents a single audio
 item, which can be passed to a HuggingFace :code:`AudioProcessor`.
 """
 # yapf: enable

-MultiModalData: TypeAlias = Union[_T, list[_T]]
+ImageItem: TypeAlias = Union[HfImageItem, torch.Tensor]
+"""
+A :class:`transformers.image_utils.ImageInput` representing a single image
+item, which can be passed to a HuggingFace :code:`ImageProcessor`.
+
+Alternatively, a 3-D tensor or batch of 2-D tensors,
+which are treated as image embeddings;
+these are directly passed to the model without HF processing.
+"""
+
+VideoItem: TypeAlias = Union[HfVideoItem, torch.Tensor]
+"""
+A :class:`transformers.image_utils.VideoInput` representing a single video
+item, which can be passed to a HuggingFace :code:`VideoProcessor`.
+
+Alternatively, a 3-D tensor or batch of 2-D tensors,
+which are treated as video embeddings;
+these are directly passed to the model without HF processing.
+"""
+
+AudioItem: TypeAlias = Union[HfAudioItem, tuple[np.ndarray, float],
+                             torch.Tensor]
+"""
+Represents a single audio
+item, which can be passed to a HuggingFace :code:`AudioProcessor`.
+
+Alternatively, a tuple `(audio, sampling_rate)`, where the sampling rate
+is different from that expected by the model;
+these are resampled to the model's sampling rate before being processed by HF.
+
+Alternatively, a 3-D tensor or batch of 2-D tensors,
+which are treated as audio embeddings;
+these are directly passed to the model without HF processing.
+"""
+
+ModalityData: TypeAlias = Union[_T, list[_T]]
 """
 Either a single data item, or a list of data items.

@@ -61,17 +82,17 @@ The number of data items allowed per modality is restricted by
 class MultiModalDataBuiltins(TypedDict, total=False):
     """Type annotations for modality types predefined by vLLM."""

-    image: MultiModalData[ImageItem]
+    image: ModalityData[ImageItem]
     """The input image(s)."""

-    video: MultiModalData[VideoItem]
+    video: ModalityData[VideoItem]
     """The input video(s)."""

-    audio: MultiModalData[AudioItem]
+    audio: ModalityData[AudioItem]
     """The input audio(s)."""


-MultiModalDataDict: TypeAlias = Mapping[str, MultiModalData[Any]]
+MultiModalDataDict: TypeAlias = Mapping[str, ModalityData[Any]]
 """
 A dictionary containing an entry for each modality type to input.

@@ -83,123 +104,6 @@ Note:
 """


-class ImageSize(NamedTuple):
-    width: int
-    height: int
-
-
-class MultiModalDataItems(UserDict[str, list[Any]]):
-    """
-    As :class:`MultiModalDataDict`, but normalized such that each entry
-    corresponds to a list.
-    """
-
-    @staticmethod
-    def from_dict(data: MultiModalDataDict) -> "MultiModalDataItems":
-        """
-        Normalize :class:`MultiModalDataDict` to :class:`MultiModalDataItems`.
-        """
-        multi_data = MultiModalDataItems()
-
-        for k, v in data.items():
-            # TODO: Make a separate modality for embedding inputs
-            # to avoid confusion
-            # yapf: disable
-            if k == "video":
-                # Special case since even a single item can be a list
-                multi_data[k] = (  # type: ignore[index]
-                    v if (
-                        isinstance(v, torch.Tensor)
-                        or is_list_of(v, list)
-                        or isinstance(v[0], (np.ndarray, torch.Tensor))
-                        and v[0].ndim == 4
-                    ) else [v]
-                )
-            elif k in ("image", "audio"):
-                multi_data[k] = (  # type: ignore[index]
-                    v if isinstance(v, (torch.Tensor, list)) else [v]
-                )
-            else:
-                multi_data[k] = v if isinstance(v, list) else [v]  # type: ignore[index]
-            # yapf: enable
-
-        return multi_data
-
-    # NOTE: When a field (e.g. `images`) doesn't exist, directly appending to
-    # `self.images` doesn't update this dictionary, which may be confusing
-    # We annotate the getter methods as `Sequence` to prevent others from
-    # trying to update the list in this way
-    @property
-    def images(self) -> Sequence[ImageItem]:
-        return self.get("image", [])
-
-    @property
-    def videos(self) -> Sequence[VideoItem]:
-        return self.get("video", [])
-
-    @property
-    def audios(self) -> Sequence[AudioItem]:
-        return self.get("audio", [])
-
-    def get_item_counts(self) -> Mapping[str, int]:
-        return {m: len(items) for m, items in self.items()}
-
-    def has_embedding_inputs(self) -> bool:
-        return any(
-            any(isinstance(item, torch.Tensor) for item in items)
-            for items in self.values())
-
-    def get_image_size(self, item_idx: int) -> ImageSize:
-        image = self.images[item_idx]
-
-        if isinstance(image, Image):
-            return ImageSize(*image.size)
-        if isinstance(image, (np.ndarray, torch.Tensor)):
-            _, h, w = image.shape
-            return ImageSize(w, h)
-
-        assert_never(image)
-
-    def get_audio_with_sr(
-        self,
-        item_idx: int,
-        *,
-        default_sr: float,
-    ) -> tuple[np.ndarray, float]:
-        audio = self.audios[item_idx]
-
-        if isinstance(audio, tuple):
-            return audio
-        if isinstance(audio, list):
-            return np.array(audio), default_sr
-        if isinstance(audio, np.ndarray):
-            return audio, default_sr
-
-        assert_never(audio)
-
-    def resample_audios(self, new_sr: float, *, drop_sr: bool = True) -> None:
-        """
-        If :code:`drop_sr=True`, the audio items in this dictionary are updated
-        to be NumPy arrays which implicitly means that their sampling rate is
-        the same as the model's expected sampling rate; otherwise, they remain
-        as :code:`(audio, new_sr)` tuples.
-        """
-        # Avoid circular import
-        from .audio import resample_audio
-
-        if not self.audios:
-            return
-
-        new_audios = []
-        for item_idx in range(len(self.audios)):
-            audio, sr = self.get_audio_with_sr(item_idx, default_sr=new_sr)
-            audio = resample_audio(audio, orig_sr=sr, target_sr=new_sr)
-
-            new_audios.append(audio if drop_sr else (audio, new_sr))
-
-        self["audio"] = new_audios
-
-
 class PlaceholderRange(TypedDict):
     """
     Placeholder location information for multi-modal data.
@@ -436,7 +340,7 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
     ) -> "MultiModalKwargs":
         data = {
             key: items[0].field.reduce(items).data
-            for key, items in items_by_key.items()
+            for key, items in items_by_key.items() if len(items) > 0
         }

         return MultiModalKwargs(data,
@@ -567,6 +471,11 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
         Get the keyword arguments corresponding to an item identified by
         its modality and index.
         """
+        if modality not in self._keys_by_modality:
+            available_modalities = set(self._keys_by_modality.keys())
+            raise KeyError(f"Modality {modality!r} not found. "
+                           f"Available modalities: {available_modalities}")
+
         keys_to_gather = self._keys_by_modality[modality]

         return {
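A small sketch of the two input forms the new aliases above describe (the shapes are illustrative only, not part of the diff):

    import torch
    from PIL import Image

    from vllm.multimodal.inputs import MultiModalDataDict

    # HF-processed form: a regular image (or a list of them).
    hf_form: MultiModalDataDict = {"image": Image.new("RGB", (336, 336))}

    # Embedding form: a 3-D tensor (items x patches x hidden size) is passed
    # through to the model without any HF processing.
    embedding_form: MultiModalDataDict = {"image": torch.randn(1, 576, 1024)}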
vllm/multimodal/parse.py (new file, 344 lines)
@@ -0,0 +1,344 @@
+from abc import ABC, abstractmethod
+from collections import UserDict
+from collections.abc import Callable, Iterator, Mapping, Sequence
+from typing import TYPE_CHECKING, Any, Generic, NamedTuple, Optional, TypeVar
+
+import numpy as np
+import torch
+from PIL.Image import Image
+from typing_extensions import TypeAlias, TypeGuard, assert_never
+
+from vllm.utils import is_list_of
+
+from .audio import resample_audio
+from .inputs import (AudioItem, HfAudioItem, HfImageItem, HfVideoItem,
+                     ImageItem, ModalityData, MultiModalDataDict,
+                     NestedTensors, VideoItem)
+
+_T = TypeVar("_T")
+_I = TypeVar("_I")
+
+
+class ModalityDataItems(ABC, Generic[_T, _I]):
+
+    def __init__(self, data: _T) -> None:
+        super().__init__()
+
+        self.data = data
+
+    def __len__(self) -> int:
+        return self.get_count()
+
+    def __getitem__(self, index: int) -> _I:
+        return self.get(index)
+
+    if TYPE_CHECKING:
+        # Auto-generated
+        def __iter__(self) -> Iterator[_I]:
+            ...
+
+    @abstractmethod
+    def get_count(self) -> int:
+        """Get the number of data items."""
+        raise NotImplementedError
+
+    @abstractmethod
+    def get(self, index: int) -> _I:
+        """Get a data item by its index."""
+        raise NotImplementedError
+
+    def get_all(self) -> list[_I]:
+        """Get all data items."""
+        return [self.get(idx) for idx in range(self.get_count())]
+
+    @abstractmethod
+    def get_processor_data(self) -> Mapping[str, object]:
+        """Get the data to pass to the HF processor."""
+        raise NotImplementedError
+
+    @abstractmethod
+    def get_passthrough_data(self) -> Mapping[str, object]:
+        """Get the data to pass directly to the model."""
+        raise NotImplementedError
+
+
+class ProcessorBatchItems(ModalityDataItems[Sequence[_T], _T]):
+
+    def __init__(self, data: Sequence[_T], modality: str) -> None:
+        super().__init__(data)
+
+        self.modality = modality
+
+    def __repr__(self) -> str:
+        return (f"{type(self).__name__}(modality={self.modality!r})")
+
+    def get_count(self) -> int:
+        return len(self.data)
+
+    def get(self, index: int) -> _T:
+        return self.data[index]
+
+    def get_processor_data(self) -> Mapping[str, object]:
+        return {f"{self.modality}s": self.data}
+
+    def get_passthrough_data(self) -> Mapping[str, object]:
+        return {}
+
+
+class EmbeddingItems(ModalityDataItems[NestedTensors, torch.Tensor]):
+
+    def __init__(self, data: NestedTensors, modality: str) -> None:
+        super().__init__(data)
+
+        self.modality = modality
+
+    def __repr__(self) -> str:
+        return (f"{type(self).__name__}(modality={self.modality!r})")
+
+    def get_count(self) -> int:
+        return len(self.data)
+
+    def get(self, index: int) -> object:
+        return self.data[index]
+
+    def get_processor_data(self) -> Mapping[str, object]:
+        return {}
+
+    def get_passthrough_data(self) -> Mapping[str, object]:
+        return {f"{self.modality}_embeds": self.data}
+
+
+class AudioProcessorItems(ProcessorBatchItems[HfAudioItem]):
+
+    def __init__(self, data: Sequence[HfAudioItem]) -> None:
+        super().__init__(data, "audio")
+
+
+class AudioEmbeddingItems(EmbeddingItems):
+
+    def __init__(self, data: NestedTensors) -> None:
+        super().__init__(data, "audio")
+
+
+class ImageSize(NamedTuple):
+    width: int
+    height: int
+
+
+class ImageProcessorItems(ProcessorBatchItems[HfImageItem]):
+
+    def __init__(self, data: Sequence[HfImageItem]) -> None:
+        super().__init__(data, "image")
+
+    def get_image_size(self, item_idx: int) -> ImageSize:
+        image = self.get(item_idx)
+
+        if isinstance(image, Image):
+            return ImageSize(*image.size)
+        if isinstance(image, (np.ndarray, torch.Tensor)):
+            _, h, w = image.shape
+            return ImageSize(w, h)
+
+        assert_never(image)
+
+
+class ImageEmbeddingItems(EmbeddingItems):
+
+    def __init__(self, data: NestedTensors) -> None:
+        super().__init__(data, "image")
+
+
+class VideoProcessorItems(ProcessorBatchItems[HfVideoItem]):
+
+    def __init__(self, data: Sequence[HfVideoItem]) -> None:
+        super().__init__(data, "video")
+
+
+class VideoEmbeddingItems(EmbeddingItems):
+
+    def __init__(self, data: NestedTensors) -> None:
+        super().__init__(data, "video")
+
+
+_D = TypeVar("_D", bound=ModalityDataItems[Any, Any])
+
+
+class MultiModalDataItems(UserDict[str, ModalityDataItems[Any, Any]]):
+    """
+    As :class:`MultiModalDataDict`, but normalized such that each entry
+    corresponds to a list.
+    """
+
+    def get_count(self, modality: str, *, strict: bool = True) -> int:
+        """
+        Get the number of data items belonging to a modality.
+
+        If `strict=False`, return `0` instead of raising :exc:`KeyError`
+        even if the modality is not found.
+        """
+        if modality not in self:
+            if strict:
+                available_modalities = set(self.keys())
+                raise KeyError(f"Modality {modality!r} not found. "
+                               f"Available modalities: {available_modalities}")
+
+            return 0
+
+        return self[modality].get_count()
+
+    def get_all_counts(self) -> Mapping[str, int]:
+        """Get the number of items belonging to each modality."""
+        return {m: items.get_count() for m, items in self.items()}
+
+    def get_items(
+        self,
+        modality: str,
+        typ: type[_D],
+    ) -> _D:
+        """
+        Get the data items belonging to a modality,
+        requiring that they belong to a certain type.
+        """
+        if modality not in self:
+            available_modalities = set(self.keys())
+            raise KeyError(f"Modality {modality!r} not found. "
+                           f"Available modalities: {available_modalities}")
+
+        items = self[modality]
+        if not isinstance(items, typ):
+            raise TypeError(f"Invalid type of data items for {modality=}. "
+                            f"Expected type: {typ}, but "
+                            f"found type: {type(items)}")
+
+        return items
+
+
+ModalityDataParser: TypeAlias = Callable[[ModalityData[Any]],
+                                         ModalityDataItems[Any, Any]]
+
+
+class MultiModalDataParser:
+    """
+    Parses :class:`MultiModalDataDict` into :class:`MultiModalDataItems`.
+    """
+
+    def __init__(self, *, target_sr: Optional[float] = None) -> None:
+        super().__init__()
+
+        self.target_sr = target_sr
+
+    def _is_embeddings(self, data: object) -> TypeGuard[NestedTensors]:
+        if isinstance(data, torch.Tensor):
+            return data.ndim == 3
+        if is_list_of(data, torch.Tensor):
+            return len(data) == 0 or data[0].ndim == 2
+
+        return False
+
+    def _get_audio_with_sr(
+        self,
+        audio: AudioItem,
+    ) -> tuple[np.ndarray, Optional[float]]:
+        if isinstance(audio, tuple):
+            return audio
+        if isinstance(audio, list):
+            return np.array(audio), None
+        if isinstance(audio, np.ndarray):
+            return audio, None
+        if isinstance(audio, torch.Tensor):
+            return audio.numpy(), None
+
+        assert_never(audio)
+
+    def _parse_audio_data(
+        self,
+        data: ModalityData[AudioItem],
+    ) -> ModalityDataItems[Any, Any]:
+        if self._is_embeddings(data):
+            return AudioEmbeddingItems(data)
+
+        if (is_list_of(data, float)
+                or isinstance(data,
+                              (np.ndarray, torch.Tensor)) and data.ndim == 1
+                or isinstance(data, tuple)):
+            data_items = [data]
+        elif isinstance(data, (np.ndarray, torch.Tensor)):
+            data_items = [elem for elem in data]
+        else:
+            data_items = data
+
+        new_audios = list[np.ndarray]()
+        for data_item in data_items:
+            audio, orig_sr = self._get_audio_with_sr(data_item)
+            if orig_sr is None:
+                new_audio = audio
+            else:
+                target_sr = self.target_sr
+                if target_sr is None:
+                    raise RuntimeError(
+                        "Audio resampling is not supported when "
+                        "`target_sr` is not provided")
+
+                new_audio = resample_audio(audio,
+                                           orig_sr=orig_sr,
+                                           target_sr=target_sr)
+
+            new_audios.append(new_audio)
+
+        return AudioProcessorItems(new_audios)
+
+    def _parse_image_data(
+        self,
+        data: ModalityData[ImageItem],
+    ) -> ModalityDataItems[Any, Any]:
+        if self._is_embeddings(data):
+            return ImageEmbeddingItems(data)
+
+        if (isinstance(data, Image)
+                or isinstance(data,
+                              (np.ndarray, torch.Tensor)) and data.ndim == 3):
+            data_items = [data]
+        elif isinstance(data, (np.ndarray, torch.Tensor)):
+            data_items = [elem for elem in data]
+        else:
+            data_items = data
+
+        return ImageProcessorItems(data_items)
+
+    def _parse_video_data(
+        self,
+        data: ModalityData[VideoItem],
+    ) -> ModalityDataItems[Any, Any]:
+        if self._is_embeddings(data):
+            return VideoEmbeddingItems(data)
+
+        if (is_list_of(data, Image)
+                or isinstance(data,
+                              (np.ndarray, torch.Tensor)) and data.ndim == 4):
+            data_items = [data]
+        elif isinstance(data, (np.ndarray, torch.Tensor)):
+            data_items = [elem for elem in data]
+        else:
+            data_items = data
+
+        return VideoProcessorItems(data_items)
+
+    def _get_subparsers(self) -> Mapping[str, ModalityDataParser]:
+        return {
+            "audio": self._parse_audio_data,
+            "image": self._parse_image_data,
+            "video": self._parse_video_data,
+        }
+
+    def parse_mm_data(self,
+                      mm_data: MultiModalDataDict) -> MultiModalDataItems:
+        subparsers = self._get_subparsers()
+
+        mm_items = MultiModalDataItems()
+        for k, v in mm_data.items():
+            if k not in subparsers:
+                raise ValueError(f"Unsupported modality: {k}")
+
+            mm_items[k] = subparsers[k](v)
+
+        return mm_items
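End-to-end usage sketch of the new parser (illustrative values, not part of the diff; the 16 kHz target rate is an assumption):

    import numpy as np
    import torch
    from PIL import Image

    from vllm.multimodal.parse import ImageProcessorItems, MultiModalDataParser

    parser = MultiModalDataParser(target_sr=16000)
    mm_items = parser.parse_mm_data({
        "image": [Image.new("RGB", (224, 224))],              # HF-processed items
        "audio": (np.zeros(32000, dtype=np.float32), 32000),  # resampled to 16 kHz
    })

    print(mm_items.get_all_counts())           # {'image': 1, 'audio': 1}
    images = mm_items.get_items("image", ImageProcessorItems)
    print(images.get_processor_data().keys())  # dict_keys(['images'])

    # A 3-D tensor is detected as precomputed embeddings instead of raw items.
    embeds = parser.parse_mm_data({"image": torch.randn(2, 576, 1024)})
    print(type(embeds["image"]).__name__)      # ImageEmbeddingItems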
@@ -15,11 +15,12 @@ from transformers import BatchFeature, ProcessorMixin
 from vllm.inputs import DummyData, InputProcessingContext
 from vllm.logger import init_logger
 from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
-from vllm.utils import LRUCache, flatten_2d_lists, full_groupby, is_list_of
+from vllm.utils import LRUCache, flatten_2d_lists, full_groupby

-from .inputs import (MultiModalDataDict, MultiModalDataItems,
-                     MultiModalFieldConfig, MultiModalFieldItem,
-                     MultiModalInputsV2, MultiModalKwargs, PlaceholderRange)
+from .inputs import (MultiModalDataDict, MultiModalFieldConfig,
+                     MultiModalFieldItem, MultiModalInputsV2, MultiModalKwargs,
+                     PlaceholderRange)
+from .parse import MultiModalDataItems, MultiModalDataParser

 logger = init_logger(__name__)

@@ -621,6 +622,16 @@ class BaseMultiModalProcessor(ABC):
     ) -> MultiModalInputsV2:
         return self.apply(prompt, mm_data, hf_processor_mm_kwargs)

+    def _get_data_parser(self) -> MultiModalDataParser:
+        """
+        Construct a data parser to preprocess multi-modal data items
+        before passing them to :meth:`_get_hf_mm_data`.
+
+        You can support additional modalities by creating a subclass
+        of :class:`MultiModalDataParser` that has additional subparsers.
+        """
+        return MultiModalDataParser()
+
     def _get_hf_processor(self) -> ProcessorMixin:
         """
         Subclasses can add keyword arguments to this method to accept
@@ -631,11 +642,16 @@ class BaseMultiModalProcessor(ABC):
     def _get_tokenizer(self) -> AnyTokenizer:
         return self.ctx.tokenizer

-    def _get_mm_items(
+    def _to_mm_items(
         self,
         mm_data: MultiModalDataDict,
     ) -> MultiModalDataItems:
-        return MultiModalDataItems.from_dict(mm_data)
+        """
+        Normalize :class:`MultiModalDataDict` to :class:`MultiModalDataItems`
+        before passing them to :meth:`_get_hf_mm_data`.
+        """
+        parser = self._get_data_parser()
+        return parser.parse_mm_data(mm_data)

     @abstractmethod
     def _get_mm_fields_config(
@@ -680,22 +696,9 @@ class BaseMultiModalProcessor(ABC):
         processor_data = dict[str, Any]()
         passthrough_data = dict[str, Any]()

-        for k, v in mm_items.items():
-            # TODO: Make a separate modality for embedding inputs
-            # to avoid confusion
-            if k in ("image", "video", "audio"):
-                if isinstance(v, torch.Tensor) and v.ndim == 3:
-                    # Pass through embedding inputs (single)
-                    passthrough_data[f"{k}_embeds"] = [v]
-                elif (is_list_of(v, torch.Tensor) and len(v) > 0
-                      and v[0].ndim == 2):
-                    # Pass through embedding inputs (multi)
-                    passthrough_data[f"{k}_embeds"] = v
-                elif len(v) > 0:
-                    # Map keys to plural form, e.g.: image -> images
-                    processor_data[f"{k}s"] = v
-            else:
-                processor_data[k] = v
+        for items in mm_items.values():
+            processor_data.update(items.get_processor_data())
+            passthrough_data.update(items.get_passthrough_data())

         return processor_data, passthrough_data

@@ -756,7 +759,7 @@ class BaseMultiModalProcessor(ABC):
         cached items; instead, we rely on our own prompt replacement logic
         for the full text.
         """
-        mm_missing_counts = mm_missing_data_items.get_item_counts()
+        mm_missing_counts = mm_missing_data_items.get_all_counts()

         prompt_ids, _ = self._apply_hf_processor(
             prompt_text=prompt_text,
@@ -789,7 +792,8 @@ class BaseMultiModalProcessor(ABC):
         cache = self.cache
         model_id = self.ctx.model_config.model

-        if cache is None or mm_data_items.has_embedding_inputs():
+        _, passthrough_data = self._get_hf_mm_data(mm_data_items)
+        if cache is None or passthrough_data:
             return self._apply_hf_processor(
                 prompt_text=prompt_text,
                 mm_items=mm_data_items,
@@ -812,7 +816,7 @@ class BaseMultiModalProcessor(ABC):
             modality: [mm_data_items[modality][idx] for idx in idxs]
             for modality, idxs in mm_missing_idxs.items()
         }
-        mm_missing_data_items = self._get_mm_items(mm_missing_data)
+        mm_missing_data_items = self._to_mm_items(mm_missing_data)

         prompt_ids, mm_missing_kwargs = self._apply_hf_processor_missing(
             prompt_text=prompt_text,
@@ -852,7 +856,7 @@ class BaseMultiModalProcessor(ABC):
             mm_merged_field_items[modality] = merged_modal_items_lst

         if self.enable_sanity_checks:
-            mm_missing_counts = mm_missing_data_items.get_item_counts()
+            mm_missing_counts = mm_missing_data_items.get_all_counts()
             assert all(
                 item_count == mm_missing_counts[modality]
                 for modality, item_count in mm_missing_next_idx.items()), dict(
@@ -865,7 +869,7 @@ class BaseMultiModalProcessor(ABC):
         )

         if self.enable_sanity_checks:
-            mm_item_counts = mm_data_items.get_item_counts()
+            mm_item_counts = mm_data_items.get_all_counts()

             for modality, item_count in mm_item_counts.items():
                 for item_idx in range(item_count):
@@ -958,7 +962,7 @@ class BaseMultiModalProcessor(ABC):
         3. Extract information about the placeholder tokens from the
            processed token IDs.
         """
-        mm_items = self._get_mm_items(mm_data)
+        mm_items = self._to_mm_items(mm_data)

         prompt_ids, mm_kwargs = self._cached_apply_hf_processor(
             prompt_text,
@@ -975,7 +979,7 @@ class BaseMultiModalProcessor(ABC):

         # If HF processor already inserts placeholder tokens,
         # there is no need for us to insert them
-        mm_item_counts = mm_items.get_item_counts()
+        mm_item_counts = mm_items.get_all_counts()
         all_placeholders = self._find_placeholders(prompt_repls, prompt_ids,
                                                    mm_item_counts)

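A hedged sketch of the extension point described in the _get_data_parser docstring above: a hypothetical extra modality added via a parser subclass (all names here are invented for illustration):

    import numpy as np

    from vllm.multimodal.parse import (ModalityDataItems, MultiModalDataParser,
                                       ProcessorBatchItems)


    class ThermalProcessorItems(ProcessorBatchItems):

        def __init__(self, data) -> None:
            super().__init__(data, "thermal")


    class ThermalAwareDataParser(MultiModalDataParser):

        def _parse_thermal_data(self, data) -> ModalityDataItems:
            return ThermalProcessorItems(
                data if isinstance(data, list) else [data])

        def _get_subparsers(self):
            # Keep the built-in audio/image/video subparsers and add one more.
            return {**super()._get_subparsers(),
                    "thermal": self._parse_thermal_data}


    # A model's processor would return this parser from _get_data_parser().
    parser = ThermalAwareDataParser()
    items = parser.parse_mm_data({"thermal": [np.zeros((8, 8))]})
    print(items.get_all_counts())  # {'thermal': 1}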
@@ -15,7 +15,7 @@ from vllm.transformers_utils.processor import get_video_processor
 from vllm.transformers_utils.tokenizer import get_tokenizer
 from vllm.utils import PlaceholderModule, is_list_of

-from .base import MediaIO, MultiModalData
+from .base import MediaIO, ModalityData
 from .image import ImageMediaIO, ImagePlugin
 from .inputs import MultiModalKwargs, VideoItem

@@ -54,7 +54,7 @@ class VideoPlugin(ImagePlugin):
     def _default_input_mapper(
         self,
         ctx: InputContext,
-        data: MultiModalData[VideoItem],
+        data: ModalityData[VideoItem],
         **mm_processor_kwargs,
     ) -> MultiModalKwargs:
         model_config = ctx.model_config