diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index 08fc659ab610f..380eb40d9eb28 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -564,8 +564,7 @@ class InternVLMultiModalProcessor(BaseMultiModalProcessor[_I]): # Since there may be extra tokens in the feature placeholders, # we need to pass the image token ID to the model to select the # tokens to merge from the vision encoder outputs - processed_outputs["image_token_id"] = [image_token_id - ] * len(image_data) + processed_outputs["image_token_id"] = torch.tensor(image_token_id) return processed_outputs @@ -575,13 +574,14 @@ class InternVLMultiModalProcessor(BaseMultiModalProcessor[_I]): hf_processor_mm_kwargs: Mapping[str, object], ) -> Mapping[str, MultiModalFieldConfig]: image_num_patches = hf_inputs.get("image_num_patches", torch.empty(0)) + num_images = len(image_num_patches) return dict( pixel_values_flat=MultiModalFieldConfig.flat_from_sizes( "image", image_num_patches), image_num_patches=MultiModalFieldConfig.batched("image"), image_embeds=MultiModalFieldConfig.batched("image"), - image_token_id=MultiModalFieldConfig.batched("image"), + image_token_id=MultiModalFieldConfig.shared("image", num_images), ) def _get_prompt_replacements( diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py index 8e4af7f88f911..2f2535f368cff 100644 --- a/vllm/multimodal/inputs.py +++ b/vllm/multimodal/inputs.py @@ -4,6 +4,7 @@ from abc import ABC, abstractmethod from collections import UserDict, defaultdict from collections.abc import Mapping, Sequence from dataclasses import dataclass +from functools import partial from itertools import accumulate from typing import (TYPE_CHECKING, Any, Literal, Optional, TypedDict, TypeVar, Union, cast, final) @@ -164,51 +165,112 @@ A dictionary containing nested tensors which have been batched via @dataclass(frozen=True) class MultiModalFieldElem: - """Contains metadata and data of 
an item in :class:`MultiModalKwargs`.""" - field: "BaseMultiModalField" + """ + Represents a keyword argument corresponding to a multi-modal item + in :class:`MultiModalKwargs`. + """ + + modality: str + """ + The modality of the corresponding multi-modal item. + Each multi-modal item can consist of multiple keyword arguments. + """ + + key: str + """ + The key of this field in :class:`MultiModalKwargs`, + i.e. the name of the keyword argument to be passed to the model. + """ + data: NestedTensors + """ + The tensor data of this field in :class:`MultiModalKwargs`, + i.e. the value of the keyword argument to be passed to the model. + """ + + field: "BaseMultiModalField" + """ + Defines how to combine the tensor data of this field with others + in order to batch multi-modal items together for model inference. + """ def __eq__(self, other: object) -> bool: if not isinstance(other, self.__class__): return False - return (self.field == other.field - and nested_tensors_equal(self.data, other.data)) + return ((self.modality, self.key) == (other.modality, other.key) + and nested_tensors_equal(self.data, other.data) + and type(self.field) == type(other.field)) # noqa: E721 @dataclass(frozen=True) class BaseMultiModalField(ABC): - """Abstract base class for a field in :class:`MultiModalKwargs`.""" - key: str - modality: str + """ + Defines how to interpret tensor data belonging to a keyword argument in + :class:`MultiModalKwargs` for multiple multi-modal items, and vice versa. 
+ """ + + def _field_factory(self, *, modality: str, key: str): + f = partial( + MultiModalFieldElem, + modality=modality, + key=key, + field=self, + ) + + # Allow passing data as positional argument + def factory(data: NestedTensors) -> MultiModalFieldElem: + return f(data=data) + + return factory + + @abstractmethod + def build_elems( + self, + modality: str, + key: str, + data: NestedTensors, + ) -> Sequence[MultiModalFieldElem]: + """ + Construct :class:`MultiModalFieldElem` instances to represent + the provided data. + + This is the inverse of :meth:`reduce_data`. + """ + raise NotImplementedError @abstractmethod def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors: raise NotImplementedError - def _build_elem(self, data: NestedTensors) -> MultiModalFieldElem: - return MultiModalFieldElem(self, data) + def reduce_data(self, elems: list[MultiModalFieldElem]) -> NestedTensors: + """ + Merge the data from multiple instances of :class:`MultiModalFieldElem`. - def reduce(self, batch: list[MultiModalFieldElem]) -> MultiModalFieldElem: - """Merge multiple instances of :class:`MultiModalFieldElem` together.""" - fields = [item.field for item in batch] - if len(set(fields)) > 1: - raise ValueError(f"Cannot merge different {fields=}") + This is the inverse of :meth:`build_elems`. + """ + field_types = [type(item.field) for item in elems] + if len(set(field_types)) > 1: + raise ValueError(f"Cannot merge different {field_types=}") - data = self._reduce_data([item.data for item in batch]) - - return self._build_elem(data) + return self._reduce_data([item.data for item in elems]) @dataclass(frozen=True) class MultiModalBatchedField(BaseMultiModalField): """ - A :class:`BaseMultiModalField` implementation where an element in the batch - is obtained by indexing into the first dimension of the underlying data. 
+ See also: + :func:`MultiModalFieldConfig.batched` """ - def build_elems(self, batch: NestedTensors) -> list[MultiModalFieldElem]: - return [self._build_elem(item) for item in batch] + def build_elems( + self, + modality: str, + key: str, + data: NestedTensors, + ) -> Sequence[MultiModalFieldElem]: + field_factory = self._field_factory(modality=modality, key=key) + return [field_factory(item) for item in data] def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors: if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"): @@ -227,16 +289,20 @@ class MultiModalBatchedField(BaseMultiModalField): @dataclass(frozen=True) class MultiModalFlatField(BaseMultiModalField): """ - A :class:`BaseMultiModalField` implementation where an element in the batch - is obtained by slicing along the first dimension of the underlying data. + See also: + :func:`MultiModalFieldConfig.flat` + :func:`MultiModalFieldConfig.flat_from_sizes` """ + slices: Sequence[slice] def build_elems( self, - batch: NestedTensors, - slices: Sequence[slice], - ) -> list[MultiModalFieldElem]: - return [self._build_elem(batch[slice_]) for slice_ in slices] + modality: str, + key: str, + data: NestedTensors, + ) -> Sequence[MultiModalFieldElem]: + field_factory = self._field_factory(modality=modality, key=key) + return [field_factory(data[s]) for s in self.slices] def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors: if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"): @@ -252,25 +318,121 @@ class MultiModalFlatField(BaseMultiModalField): return [e for elem in batch for e in elem] +@dataclass(frozen=True) +class MultiModalSharedField(BaseMultiModalField): + """ + See also: + :func:`MultiModalFieldConfig.shared` + """ + batch_size: int + + def build_elems( + self, + modality: str, + key: str, + data: NestedTensors, + ) -> Sequence[MultiModalFieldElem]: + field_factory = self._field_factory(modality=modality, key=key) + return [field_factory(data)] * 
self.batch_size + + def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors: + return batch[0] + + + class MultiModalFieldConfig: @staticmethod def batched(modality: str): + """ + Defines a field where an element in the batch is obtained by + indexing into the first dimension of the underlying data. + + Args: + modality: The modality of the multi-modal item that uses this + keyword argument. + + Example: + + .. code-block:: + + Input: + Data: [[AAAA] + [BBBB] + [CCCC]] + + Output: + Element 1: [AAAA] + Element 2: [BBBB] + Element 3: [CCCC] + """ return MultiModalFieldConfig( - field_cls=MultiModalBatchedField, + field=MultiModalBatchedField(), modality=modality, ) @staticmethod def flat(modality: str, slices: Sequence[slice]): + """ + Defines a field where an element in the batch is obtained by + slicing along the first dimension of the underlying data. + + Args: + modality: The modality of the multi-modal item that uses this + keyword argument. + slices: For each multi-modal item, a slice that is used to extract + the data corresponding to it. + + Example: + + .. code-block:: + + Given: + slices: [slice(0, 3), slice(3, 7), slice(7, 9)] + + Input: + Data: [AAABBBBCC] + + Output: + Element 1: [AAA] + Element 2: [BBBB] + Element 3: [CC] + """ return MultiModalFieldConfig( - field_cls=MultiModalFlatField, + field=MultiModalFlatField(slices=slices), modality=modality, - slices=slices, ) @staticmethod def flat_from_sizes(modality: str, size_per_item: torch.Tensor): + """ + Defines a field where an element in the batch is obtained by + slicing along the first dimension of the underlying data. + + Args: + modality: The modality of the multi-modal item that uses this + keyword argument. + size_per_item: For each multi-modal item, the size of the slice that + is used to extract the data corresponding to it. + + Example: + + .. 
code-block:: + + Given: + size_per_item: [3, 4, 2] + + Input: + Data: [AAABBBBCC] + + Output: + Element 1: [AAA] + Element 2: [BBBB] + Element 3: [CC] + + See also: + :func:`MultiModalFieldConfig.flat` + """ + slice_idxs = [0, *accumulate(size_per_item)] slices = [ slice(slice_idxs[i], slice_idxs[i + 1]) @@ -279,25 +441,52 @@ class MultiModalFieldConfig: return MultiModalFieldConfig.flat(modality, slices) - def __init__( - self, - field_cls: type[BaseMultiModalField], - modality: str, - **field_config: Any, - ) -> None: + @staticmethod + def shared(modality: str, batch_size: int): + """ + Defines a field where an element in the batch is obtained by + taking the entirety of the underlying data. + + This means that the data is the same for each element in the batch. + + Args: + modality: The modality of the multi-modal item that uses this + keyword argument. + batch_size: The number of multi-modal items which share this data. + + Example: + + .. code-block:: + + Given: + batch_size: 4 + + Input: + Data: [XYZ] + + Output: + Element 1: [XYZ] + Element 2: [XYZ] + Element 3: [XYZ] + Element 4: [XYZ] + """ + return MultiModalFieldConfig( + field=MultiModalSharedField(batch_size), + modality=modality, + ) + + def __init__(self, field: BaseMultiModalField, modality: str) -> None: super().__init__() - self.field_cls = field_cls + self.field = field self.modality = modality - self.field_config = field_config def build_elems( self, key: str, batch: NestedTensors, ) -> Sequence[MultiModalFieldElem]: - field = self.field_cls(key=key, modality=self.modality) - return field.build_elems(batch, **self.field_config) # type: ignore + return self.field.build_elems(self.modality, key, batch) class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]): @@ -308,11 +497,11 @@ class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]): @staticmethod def from_elems(elems: Sequence[MultiModalFieldElem]): - return MultiModalKwargsItem({elem.field.key: elem for elem in elems}) + return 
MultiModalKwargsItem({elem.key: elem for elem in elems}) @property def modality(self) -> str: - modalities = {elem.field.modality for elem in self.data.values()} + modalities = {elem.modality for elem in self.data.values()} assert len(modalities) == 1, f"Found different modalities={modalities}" return next(iter(modalities)) @@ -372,7 +561,7 @@ class MultiModalKwargs(UserDict[str, NestedTensors]): elems_by_key[key].append(elem) data = { - key: elems[0].field.reduce(elems).data + key: elems[0].field.reduce_data(elems) for key, elems in elems_by_key.items() if len(elems) > 0 }