mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-27 15:15:14 +08:00
[VLM] Use shared field to pass token ids to model
This commit is contained in:
parent
3b2005e1db
commit
a4ce74c14a
@ -564,8 +564,7 @@ class InternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
|
||||
# Since there may be extra tokens in the feature placeholders,
|
||||
# we need to pass the image token ID to the model to select the
|
||||
# tokens to merge from the vision encoder outputs
|
||||
processed_outputs["image_token_id"] = [image_token_id
|
||||
] * len(image_data)
|
||||
processed_outputs["image_token_id"] = torch.tensor(image_token_id)
|
||||
|
||||
return processed_outputs
|
||||
|
||||
@ -575,13 +574,14 @@ class InternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
) -> Mapping[str, MultiModalFieldConfig]:
|
||||
image_num_patches = hf_inputs.get("image_num_patches", torch.empty(0))
|
||||
num_images = len(image_num_patches)
|
||||
|
||||
return dict(
|
||||
pixel_values_flat=MultiModalFieldConfig.flat_from_sizes(
|
||||
"image", image_num_patches),
|
||||
image_num_patches=MultiModalFieldConfig.batched("image"),
|
||||
image_embeds=MultiModalFieldConfig.batched("image"),
|
||||
image_token_id=MultiModalFieldConfig.batched("image"),
|
||||
image_token_id=MultiModalFieldConfig.shared("image", num_images),
|
||||
)
|
||||
|
||||
def _get_prompt_replacements(
|
||||
|
||||
@ -4,6 +4,7 @@ from abc import ABC, abstractmethod
|
||||
from collections import UserDict, defaultdict
|
||||
from collections.abc import Mapping, Sequence
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from itertools import accumulate
|
||||
from typing import (TYPE_CHECKING, Any, Literal, Optional, TypedDict, TypeVar,
|
||||
Union, cast, final)
|
||||
@ -164,51 +165,112 @@ A dictionary containing nested tensors which have been batched via
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class MultiModalFieldElem:
|
||||
"""Contains metadata and data of an item in :class:`MultiModalKwargs`."""
|
||||
field: "BaseMultiModalField"
|
||||
"""
|
||||
Represents a keyword argument corresponding to a multi-modal item
|
||||
in :class:`MultiModalKwargs`.
|
||||
"""
|
||||
|
||||
modality: str
|
||||
"""
|
||||
The modality of the corresponding multi-modal item.
|
||||
Each multi-modal item can consist of multiple keyword arguments.
|
||||
"""
|
||||
|
||||
key: str
|
||||
"""
|
||||
The key of this field in :class:`MultiModalKwargs`,
|
||||
i.e. the name of the keyword argument to be passed to the model.
|
||||
"""
|
||||
|
||||
data: NestedTensors
|
||||
"""
|
||||
The tensor data of this field in :class:`MultiModalKwargs`,
|
||||
i.e. the value of the keyword argument to be passed to the model.
|
||||
"""
|
||||
|
||||
field: "BaseMultiModalField"
|
||||
"""
|
||||
Defines how to combine the tensor data of this field with others
|
||||
in order to batch multi-modal items together for model inference.
|
||||
"""
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, self.__class__):
|
||||
return False
|
||||
|
||||
return (self.field == other.field
|
||||
and nested_tensors_equal(self.data, other.data))
|
||||
return ((self.modality, self.key) == (other.modality, other.key)
|
||||
and nested_tensors_equal(self.data, other.data)
|
||||
and type(self.field) == type(other.field)) # noqa: E721
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class BaseMultiModalField(ABC):
|
||||
"""Abstract base class for a field in :class:`MultiModalKwargs`."""
|
||||
key: str
|
||||
modality: str
|
||||
"""
|
||||
Defines how to interpret tensor data belonging to a keyword argument in
|
||||
:class:`MultiModalKwargs` for multiple multi-modal items, and vice versa.
|
||||
"""
|
||||
|
||||
def _field_factory(self, *, modality: str, key: str):
|
||||
f = partial(
|
||||
MultiModalFieldElem,
|
||||
modality=modality,
|
||||
key=key,
|
||||
field=self,
|
||||
)
|
||||
|
||||
# Allow passing data as positional argument
|
||||
def factory(data: NestedTensors) -> MultiModalFieldElem:
|
||||
return f(data=data)
|
||||
|
||||
return factory
|
||||
|
||||
@abstractmethod
|
||||
def build_elems(
|
||||
self,
|
||||
modality: str,
|
||||
key: str,
|
||||
data: NestedTensors,
|
||||
) -> Sequence[MultiModalFieldElem]:
|
||||
"""
|
||||
Construct :class:`MultiModalFieldElem` instances to represent
|
||||
the provided data.
|
||||
|
||||
This is the inverse of :meth:`reduce_data`.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
|
||||
raise NotImplementedError
|
||||
|
||||
def _build_elem(self, data: NestedTensors) -> MultiModalFieldElem:
|
||||
return MultiModalFieldElem(self, data)
|
||||
def reduce_data(self, elems: list[MultiModalFieldElem]) -> NestedTensors:
|
||||
"""
|
||||
Merge the data from multiple instances of :class:`MultiModalFieldElem`.
|
||||
|
||||
def reduce(self, batch: list[MultiModalFieldElem]) -> MultiModalFieldElem:
|
||||
"""Merge multiple instances of :class:`MultiModalFieldElem` together."""
|
||||
fields = [item.field for item in batch]
|
||||
if len(set(fields)) > 1:
|
||||
raise ValueError(f"Cannot merge different {fields=}")
|
||||
This is the inverse of :meth:`build_elems`.
|
||||
"""
|
||||
field_types = [type(item.field) for item in elems]
|
||||
if len(set(field_types)) > 1:
|
||||
raise ValueError(f"Cannot merge different {field_types=}")
|
||||
|
||||
data = self._reduce_data([item.data for item in batch])
|
||||
|
||||
return self._build_elem(data)
|
||||
return self._reduce_data([item.data for item in elems])
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class MultiModalBatchedField(BaseMultiModalField):
|
||||
"""
|
||||
A :class:`BaseMultiModalField` implementation where an element in the batch
|
||||
is obtained by indexing into the first dimension of the underlying data.
|
||||
See also:
|
||||
:func:`MultiModalFieldConfig.batched`
|
||||
"""
|
||||
|
||||
def build_elems(self, batch: NestedTensors) -> list[MultiModalFieldElem]:
|
||||
return [self._build_elem(item) for item in batch]
|
||||
def build_elems(
|
||||
self,
|
||||
modality: str,
|
||||
key: str,
|
||||
data: NestedTensors,
|
||||
) -> Sequence[MultiModalFieldElem]:
|
||||
field_factory = self._field_factory(modality=modality, key=key)
|
||||
return [field_factory(item) for item in data]
|
||||
|
||||
def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
|
||||
if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
|
||||
@ -227,16 +289,20 @@ class MultiModalBatchedField(BaseMultiModalField):
|
||||
@dataclass(frozen=True)
|
||||
class MultiModalFlatField(BaseMultiModalField):
|
||||
"""
|
||||
A :class:`BaseMultiModalField` implementation where an element in the batch
|
||||
is obtained by slicing along the first dimension of the underlying data.
|
||||
See also:
|
||||
:func:`MultiModalFieldConfig.flat`
|
||||
:func:`MultiModalFieldConfig.flat_from_sizes`
|
||||
"""
|
||||
slices: Sequence[slice]
|
||||
|
||||
def build_elems(
|
||||
self,
|
||||
batch: NestedTensors,
|
||||
slices: Sequence[slice],
|
||||
) -> list[MultiModalFieldElem]:
|
||||
return [self._build_elem(batch[slice_]) for slice_ in slices]
|
||||
modality: str,
|
||||
key: str,
|
||||
data: NestedTensors,
|
||||
) -> Sequence[MultiModalFieldElem]:
|
||||
field_factory = self._field_factory(modality=modality, key=key)
|
||||
return [field_factory(data[s]) for s in self.slices]
|
||||
|
||||
def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
|
||||
if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
|
||||
@ -252,25 +318,121 @@ class MultiModalFlatField(BaseMultiModalField):
|
||||
return [e for elem in batch for e in elem]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class MultiModalSharedField(BaseMultiModalField):
|
||||
"""
|
||||
See also:
|
||||
:func:`MultiModalFieldConfig.shared`
|
||||
"""
|
||||
batch_size: int
|
||||
|
||||
def build_elems(
|
||||
self,
|
||||
modality: str,
|
||||
key: str,
|
||||
data: NestedTensors,
|
||||
) -> Sequence[MultiModalFieldElem]:
|
||||
field_factory = self._field_factory(modality=modality, key=key)
|
||||
return [field_factory(data)] * self.batch_size
|
||||
|
||||
def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
|
||||
return batch[0]
|
||||
|
||||
|
||||
class MultiModalFieldConfig:
|
||||
|
||||
@staticmethod
|
||||
def batched(modality: str):
|
||||
"""
|
||||
Defines a field where an element in the batch is obtained by
|
||||
indexing into the first dimension of the underlying data.
|
||||
|
||||
Args:
|
||||
modality: The modality of the multi-modal item that uses this
|
||||
keyword argument.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block::
|
||||
|
||||
Input:
|
||||
Data: [[AAAA]
|
||||
[BBBB]
|
||||
[CCCC]]
|
||||
|
||||
Output:
|
||||
Element 1: [AAAA]
|
||||
Element 2: [BBBB]
|
||||
Element 3: [CCCC]
|
||||
"""
|
||||
return MultiModalFieldConfig(
|
||||
field_cls=MultiModalBatchedField,
|
||||
field=MultiModalBatchedField(),
|
||||
modality=modality,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def flat(modality: str, slices: Sequence[slice]):
|
||||
"""
|
||||
Defines a field where an element in the batch is obtained by
|
||||
slicing along the first dimension of the underlying data.
|
||||
|
||||
Args:
|
||||
modality: The modality of the multi-modal item that uses this
|
||||
keyword argument.
|
||||
slices: For each multi-modal item, a slice that is used to extract
|
||||
the data corresponding to it.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block::
|
||||
|
||||
Given:
|
||||
slices: [slice(0, 3), slice(3, 7), slice(7, 9)]
|
||||
|
||||
Input:
|
||||
Data: [AAABBBBCC]
|
||||
|
||||
Output:
|
||||
Element 1: [AAA]
|
||||
Element 2: [BBBB]
|
||||
Element 3: [CC]
|
||||
"""
|
||||
return MultiModalFieldConfig(
|
||||
field_cls=MultiModalFlatField,
|
||||
field=MultiModalFlatField(slices=slices),
|
||||
modality=modality,
|
||||
slices=slices,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def flat_from_sizes(modality: str, size_per_item: torch.Tensor):
|
||||
"""
|
||||
Defines a field where an element in the batch is obtained by
|
||||
slicing along the first dimension of the underlying data.
|
||||
|
||||
Args:
|
||||
modality: The modality of the multi-modal item that uses this
|
||||
keyword argument.
|
||||
size_per_item: For each multi-modal item, the size of the slice that
|
||||
is used to extract the data corresponding to it.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block::
|
||||
|
||||
Given:
|
||||
size_per_item: [3, 4, 2]
|
||||
|
||||
Input:
|
||||
Data: [AAABBBBCC]
|
||||
|
||||
Output:
|
||||
Element 1: [AAA]
|
||||
Element 2: [BBBB]
|
||||
Element 3: [CC]
|
||||
|
||||
See also:
|
||||
:func:`MultiModalFieldConfig.flat`
|
||||
"""
|
||||
|
||||
slice_idxs = [0, *accumulate(size_per_item)]
|
||||
slices = [
|
||||
slice(slice_idxs[i], slice_idxs[i + 1])
|
||||
@ -279,25 +441,52 @@ class MultiModalFieldConfig:
|
||||
|
||||
return MultiModalFieldConfig.flat(modality, slices)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
field_cls: type[BaseMultiModalField],
|
||||
modality: str,
|
||||
**field_config: Any,
|
||||
) -> None:
|
||||
@staticmethod
|
||||
def shared(modality: str, batch_size: int):
|
||||
"""
|
||||
Defines a field where an element in the batch is obtained by
|
||||
taking the entirety of the underlying data.
|
||||
|
||||
This means that the data is the same for each element in the batch.
|
||||
|
||||
Args:
|
||||
modality: The modality of the multi-modal item that uses this
|
||||
keyword argument.
|
||||
batch_size: The number of multi-modal items which share this data.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block::
|
||||
|
||||
Given:
|
||||
batch_size: 4
|
||||
|
||||
Input:
|
||||
Data: [XYZ]
|
||||
|
||||
Output:
|
||||
Element 1: [XYZ]
|
||||
Element 2: [XYZ]
|
||||
Element 3: [XYZ]
|
||||
Element 4: [XYZ]
|
||||
"""
|
||||
return MultiModalFieldConfig(
|
||||
field=MultiModalSharedField(batch_size),
|
||||
modality=modality,
|
||||
)
|
||||
|
||||
def __init__(self, field: BaseMultiModalField, modality: str) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.field_cls = field_cls
|
||||
self.field = field
|
||||
self.modality = modality
|
||||
self.field_config = field_config
|
||||
|
||||
def build_elems(
|
||||
self,
|
||||
key: str,
|
||||
batch: NestedTensors,
|
||||
) -> Sequence[MultiModalFieldElem]:
|
||||
field = self.field_cls(key=key, modality=self.modality)
|
||||
return field.build_elems(batch, **self.field_config) # type: ignore
|
||||
return self.field.build_elems(self.modality, key, batch)
|
||||
|
||||
|
||||
class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]):
|
||||
@ -308,11 +497,11 @@ class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]):
|
||||
|
||||
@staticmethod
|
||||
def from_elems(elems: Sequence[MultiModalFieldElem]):
|
||||
return MultiModalKwargsItem({elem.field.key: elem for elem in elems})
|
||||
return MultiModalKwargsItem({elem.key: elem for elem in elems})
|
||||
|
||||
@property
|
||||
def modality(self) -> str:
|
||||
modalities = {elem.field.modality for elem in self.data.values()}
|
||||
modalities = {elem.modality for elem in self.data.values()}
|
||||
assert len(modalities) == 1, f"Found different modalities={modalities}"
|
||||
return next(iter(modalities))
|
||||
|
||||
@ -372,7 +561,7 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
|
||||
elems_by_key[key].append(elem)
|
||||
|
||||
data = {
|
||||
key: elems[0].field.reduce(elems).data
|
||||
key: elems[0].field.reduce_data(elems)
|
||||
for key, elems in elems_by_key.items() if len(elems) > 0
|
||||
}
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user