[VLM] Use shared field to pass token ids to model

This commit is contained in:
Cyrus Leung 2025-02-06 05:30:46 +08:00 committed by GitHub
parent 3b2005e1db
commit a4ce74c14a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 235 additions and 46 deletions

View File

@ -564,8 +564,7 @@ class InternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
# Since there may be extra tokens in the feature placeholders,
# we need to pass the image token ID to the model to select the
# tokens to merge from the vision encoder outputs
processed_outputs["image_token_id"] = [image_token_id
] * len(image_data)
processed_outputs["image_token_id"] = torch.tensor(image_token_id)
return processed_outputs
@ -575,13 +574,14 @@ class InternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
image_num_patches = hf_inputs.get("image_num_patches", torch.empty(0))
num_images = len(image_num_patches)
return dict(
pixel_values_flat=MultiModalFieldConfig.flat_from_sizes(
"image", image_num_patches),
image_num_patches=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.batched("image"),
image_token_id=MultiModalFieldConfig.batched("image"),
image_token_id=MultiModalFieldConfig.shared("image", num_images),
)
def _get_prompt_replacements(

View File

@ -4,6 +4,7 @@ from abc import ABC, abstractmethod
from collections import UserDict, defaultdict
from collections.abc import Mapping, Sequence
from dataclasses import dataclass
from functools import partial
from itertools import accumulate
from typing import (TYPE_CHECKING, Any, Literal, Optional, TypedDict, TypeVar,
Union, cast, final)
@ -164,51 +165,112 @@ A dictionary containing nested tensors which have been batched via
@dataclass(frozen=True)
class MultiModalFieldElem:
"""Contains metadata and data of an item in :class:`MultiModalKwargs`."""
field: "BaseMultiModalField"
"""
Represents a keyword argument corresponding to a multi-modal item
in :class:`MultiModalKwargs`.
"""
modality: str
"""
The modality of the corresponding multi-modal item.
Each multi-modal item can consist of multiple keyword arguments.
"""
key: str
"""
The key of this field in :class:`MultiModalKwargs`,
i.e. the name of the keyword argument to be passed to the model.
"""
data: NestedTensors
"""
The tensor data of this field in :class:`MultiModalKwargs`,
i.e. the value of the keyword argument to be passed to the model.
"""
field: "BaseMultiModalField"
"""
Defines how to combine the tensor data of this field with others
in order to batch multi-modal items together for model inference.
"""
def __eq__(self, other: object) -> bool:
if not isinstance(other, self.__class__):
return False
return (self.field == other.field
and nested_tensors_equal(self.data, other.data))
return ((self.modality, self.key) == (other.modality, other.key)
and nested_tensors_equal(self.data, other.data)
and type(self.field) == type(other.field)) # noqa: E721
@dataclass(frozen=True)
class BaseMultiModalField(ABC):
"""Abstract base class for a field in :class:`MultiModalKwargs`."""
key: str
modality: str
"""
Defines how to interpret tensor data belonging to a keyword argument in
:class:`MultiModalKwargs` for multiple multi-modal items, and vice versa.
"""
def _field_factory(self, *, modality: str, key: str):
f = partial(
MultiModalFieldElem,
modality=modality,
key=key,
field=self,
)
# Allow passing data as positional argument
def factory(data: NestedTensors) -> MultiModalFieldElem:
return f(data=data)
return factory
@abstractmethod
def build_elems(
self,
modality: str,
key: str,
data: NestedTensors,
) -> Sequence[MultiModalFieldElem]:
"""
Construct :class:`MultiModalFieldElem` instances to represent
the provided data.
This is the inverse of :meth:`reduce_data`.
"""
raise NotImplementedError
@abstractmethod
def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
raise NotImplementedError
def _build_elem(self, data: NestedTensors) -> MultiModalFieldElem:
return MultiModalFieldElem(self, data)
def reduce_data(self, elems: list[MultiModalFieldElem]) -> NestedTensors:
"""
Merge the data from multiple instances of :class:`MultiModalFieldElem`.
def reduce(self, batch: list[MultiModalFieldElem]) -> MultiModalFieldElem:
"""Merge multiple instances of :class:`MultiModalFieldElem` together."""
fields = [item.field for item in batch]
if len(set(fields)) > 1:
raise ValueError(f"Cannot merge different {fields=}")
This is the inverse of :meth:`build_elems`.
"""
field_types = [type(item.field) for item in elems]
if len(set(field_types)) > 1:
raise ValueError(f"Cannot merge different {field_types=}")
data = self._reduce_data([item.data for item in batch])
return self._build_elem(data)
return self._reduce_data([item.data for item in elems])
@dataclass(frozen=True)
class MultiModalBatchedField(BaseMultiModalField):
"""
A :class:`BaseMultiModalField` implementation where an element in the batch
is obtained by indexing into the first dimension of the underlying data.
See also:
:func:`MultiModalFieldConfig.batched`
"""
def build_elems(self, batch: NestedTensors) -> list[MultiModalFieldElem]:
return [self._build_elem(item) for item in batch]
def build_elems(
self,
modality: str,
key: str,
data: NestedTensors,
) -> Sequence[MultiModalFieldElem]:
field_factory = self._field_factory(modality=modality, key=key)
return [field_factory(item) for item in data]
def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
@ -227,16 +289,20 @@ class MultiModalBatchedField(BaseMultiModalField):
@dataclass(frozen=True)
class MultiModalFlatField(BaseMultiModalField):
"""
A :class:`BaseMultiModalField` implementation where an element in the batch
is obtained by slicing along the first dimension of the underlying data.
See also:
:func:`MultiModalFieldConfig.flat`
:func:`MultiModalFieldConfig.flat_from_sizes`
"""
slices: Sequence[slice]
def build_elems(
self,
batch: NestedTensors,
slices: Sequence[slice],
) -> list[MultiModalFieldElem]:
return [self._build_elem(batch[slice_]) for slice_ in slices]
modality: str,
key: str,
data: NestedTensors,
) -> Sequence[MultiModalFieldElem]:
field_factory = self._field_factory(modality=modality, key=key)
return [field_factory(data[s]) for s in self.slices]
def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
@ -252,25 +318,121 @@ class MultiModalFlatField(BaseMultiModalField):
return [e for elem in batch for e in elem]
@dataclass(frozen=True)
class MultiModalSharedField(BaseMultiModalField):
"""
See also:
:func:`MultiModalFieldConfig.shared`
"""
batch_size: int
def build_elems(
self,
modality: str,
key: str,
data: NestedTensors,
) -> Sequence[MultiModalFieldElem]:
field_factory = self._field_factory(modality=modality, key=key)
return [field_factory(data)] * self.batch_size
def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
return batch[0]
class MultiModalFieldConfig:
@staticmethod
def batched(modality: str):
"""
Defines a field where an element in the batch is obtained by
indexing into the first dimension of the underlying data.
Args:
modality: The modality of the multi-modal item that uses this
keyword argument.
Example:
.. code-block::
Input:
Data: [[AAAA]
[BBBB]
[CCCC]]
Output:
Element 1: [AAAA]
Element 2: [BBBB]
Element 3: [CCCC]
"""
return MultiModalFieldConfig(
field_cls=MultiModalBatchedField,
field=MultiModalBatchedField(),
modality=modality,
)
@staticmethod
def flat(modality: str, slices: Sequence[slice]):
"""
Defines a field where an element in the batch is obtained by
slicing along the first dimension of the underlying data.
Args:
modality: The modality of the multi-modal item that uses this
keyword argument.
slices: For each multi-modal item, a slice that is used to extract
the data corresponding to it.
Example:
.. code-block::
Given:
slices: [slice(0, 3), slice(3, 7), slice(7, 9)]
Input:
Data: [AAABBBBCC]
Output:
Element 1: [AAA]
Element 2: [BBBB]
Element 3: [CC]
"""
return MultiModalFieldConfig(
field_cls=MultiModalFlatField,
field=MultiModalFlatField(slices=slices),
modality=modality,
slices=slices,
)
@staticmethod
def flat_from_sizes(modality: str, size_per_item: torch.Tensor):
"""
Defines a field where an element in the batch is obtained by
slicing along the first dimension of the underlying data.
Args:
modality: The modality of the multi-modal item that uses this
keyword argument.
slices: For each multi-modal item, the size of the slice that
is used to extract the data corresponding to it.
Example:
.. code-block::
Given:
size_per_item: [3, 4, 2]
Input:
Data: [AAABBBBCC]
Output:
Element 1: [AAA]
Element 2: [BBBB]
Element 3: [CC]
See also:
:func:`MultiModalFieldConfig.flat`
"""
slice_idxs = [0, *accumulate(size_per_item)]
slices = [
slice(slice_idxs[i], slice_idxs[i + 1])
@ -279,25 +441,52 @@ class MultiModalFieldConfig:
return MultiModalFieldConfig.flat(modality, slices)
def __init__(
self,
field_cls: type[BaseMultiModalField],
modality: str,
**field_config: Any,
) -> None:
@staticmethod
def shared(modality: str, batch_size: int):
"""
Defines a field where an element in the batch is obtained by
taking the entirety of the underlying data.
This means that the data is the same for each element in the batch.
Args:
modality: The modality of the multi-modal item that uses this
keyword argument.
batch_size: The number of multi-modal items which share this data.
Example:
.. code-block::
Given:
batch_size: 4
Input:
Data: [XYZ]
Output:
Element 1: [XYZ]
Element 2: [XYZ]
Element 3: [XYZ]
Element 4: [XYZ]
"""
return MultiModalFieldConfig(
field=MultiModalSharedField(batch_size),
modality=modality,
)
def __init__(self, field: BaseMultiModalField, modality: str) -> None:
super().__init__()
self.field_cls = field_cls
self.field = field
self.modality = modality
self.field_config = field_config
def build_elems(
self,
key: str,
batch: NestedTensors,
) -> Sequence[MultiModalFieldElem]:
field = self.field_cls(key=key, modality=self.modality)
return field.build_elems(batch, **self.field_config) # type: ignore
return self.field.build_elems(self.modality, key, batch)
class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]):
@ -308,11 +497,11 @@ class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]):
@staticmethod
def from_elems(elems: Sequence[MultiModalFieldElem]):
return MultiModalKwargsItem({elem.field.key: elem for elem in elems})
return MultiModalKwargsItem({elem.key: elem for elem in elems})
@property
def modality(self) -> str:
modalities = {elem.field.modality for elem in self.data.values()}
modalities = {elem.modality for elem in self.data.values()}
assert len(modalities) == 1, f"Found different modalities={modalities}"
return next(iter(modalities))
@ -372,7 +561,7 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
elems_by_key[key].append(elem)
data = {
key: elems[0].field.reduce(elems).data
key: elems[0].field.reduce_data(elems)
for key, elems in elems_by_key.items() if len(elems) > 0
}