[Model] Move multimodal_cpu_fields definition to field config (#30181)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Cyrus Leung 2025-12-06 21:40:02 +08:00 committed by GitHub
parent 21bb323542
commit 671427efbf
15 changed files with 141 additions and 95 deletions

View File

@ -28,7 +28,7 @@ def _dummy_elem(modality: str, key: str, size: int):
modality=modality,
key=key,
data=torch.empty((size,), dtype=torch.int8),
field=MultiModalSharedField(1),
field=MultiModalSharedField(batch_size=1),
)

View File

@ -51,7 +51,7 @@ def _dummy_elem(
modality=modality,
key=key,
data=data,
field=MultiModalSharedField(1),
field=MultiModalSharedField(batch_size=1),
)

View File

@ -104,22 +104,31 @@ class MyRequest(msgspec.Struct):
def test_multimodal_kwargs():
e1 = MultiModalFieldElem(
"audio", "a0", torch.zeros(1000, dtype=torch.bfloat16), MultiModalBatchedField()
"audio",
"a0",
torch.zeros(1000, dtype=torch.bfloat16),
MultiModalBatchedField(),
)
e2 = MultiModalFieldElem(
"video",
"v0",
[torch.zeros(1000, dtype=torch.int8) for _ in range(4)],
MultiModalFlatField([[slice(1, 2, 3), slice(4, 5, 6)], [slice(None, 2)]], 0),
MultiModalFlatField(
slices=[[slice(1, 2, 3), slice(4, 5, 6)], [slice(None, 2)]],
dim=0,
),
)
e3 = MultiModalFieldElem(
"image", "i0", torch.zeros(1000, dtype=torch.int32), MultiModalSharedField(4)
"image",
"i0",
torch.zeros(1000, dtype=torch.int32),
MultiModalSharedField(batch_size=4),
)
e4 = MultiModalFieldElem(
"image",
"i1",
torch.zeros(1000, dtype=torch.int32),
MultiModalFlatField([slice(1, 2, 3), slice(4, 5, 6)], 2),
MultiModalFlatField(slices=[slice(1, 2, 3), slice(4, 5, 6)], dim=2),
)
audio = MultiModalKwargsItem.from_elems([e1])
video = MultiModalKwargsItem.from_elems([e2])
@ -138,8 +147,8 @@ def test_multimodal_kwargs():
total_len = sum(memoryview(x).cast("B").nbytes for x in encoded)
# expected total encoding length, should be 14306, +-20 for minor changes
assert 14275 <= total_len <= 14325
# expected total encoding length, should be 14395, +-20 for minor changes
assert 14375 <= total_len <= 14425
decoded = decoder.decode(encoded).mm[0]
assert isinstance(decoded, MultiModalKwargsItems)

View File

@ -787,10 +787,10 @@ class Glm4vVisionTransformer(nn.Module):
def forward(
self,
x: torch.Tensor,
grid_thw: list[list[int]],
grid_thw: torch.Tensor | list[list[int]],
) -> torch.Tensor:
# Convert grid_thw to tensor (always expecting list format now)
grid_thw = torch.tensor(grid_thw, device=x.device, dtype=torch.long)
if isinstance(grid_thw, list):
grid_thw = torch.tensor(grid_thw, dtype=torch.int32)
# patchify
x = x.to(device=self.device, dtype=self.dtype)
@ -805,7 +805,8 @@ class Glm4vVisionTransformer(nn.Module):
cu_seqlens = torch.repeat_interleave(
grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
).cumsum(dim=0, dtype=torch.int32)
cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)
cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens])
cu_seqlens = cu_seqlens.to(self.device, non_blocking=True)
# pre-compute max_seqlen for attn mask to reduce cuMemcpy operations
max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens)
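
For reference, a small standalone sketch of the cu_seqlens construction above (not part of the diff; the grid_thw values are made up and rows are assumed to be [t, h, w]): the per-image sequence length is h * w repeated t times, the cumulative sum gives the end offsets, and torch.cat with new_zeros(1) prepends the leading zero that F.pad previously added.

import torch

# Hypothetical grid_thw with two items; rows are [t, h, w].
grid_thw = torch.tensor([[1, 4, 6], [2, 2, 2]])

seqlens = torch.repeat_interleave(
    grid_thw[:, 1] * grid_thw[:, 2],  # h * w per item
    grid_thw[:, 0],                   # repeated t times
)                                                              # tensor([24, 4, 4])
cu_seqlens = seqlens.cumsum(dim=0, dtype=torch.int32)          # tensor([24, 28, 32])
cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens])  # tensor([0, 24, 28, 32])
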
@ -1548,7 +1549,6 @@ class Glm4vForConditionalGeneration(
) -> tuple[torch.Tensor, ...]:
grid_thw = image_input["image_grid_thw"]
assert grid_thw.ndim == 2
grid_thw_list = grid_thw.tolist()
if image_input["type"] == "image_embeds":
image_embeds = image_input["image_embeds"].type(self.visual.dtype)
@ -1559,12 +1559,10 @@ class Glm4vForConditionalGeneration(
self.visual, pixel_values, grid_thw.tolist(), rope_type="rope_3d"
)
else:
image_embeds = self.visual(pixel_values, grid_thw=grid_thw.tolist())
image_embeds = self.visual(pixel_values, grid_thw=grid_thw)
merge_size = self.visual.spatial_merge_size
sizes = (
torch.tensor(grid_thw_list, dtype=torch.long).prod(-1)
// (merge_size * merge_size)
).tolist()
sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist()
return image_embeds.split(sizes)
def _process_video_input(
@ -1572,7 +1570,6 @@ class Glm4vForConditionalGeneration(
) -> tuple[torch.Tensor, ...]:
grid_thw = video_input["video_grid_thw"]
assert grid_thw.ndim == 2
grid_thw_list = grid_thw.tolist()
if video_input["type"] == "video_embeds":
video_embeds = video_input["video_embeds"].type(self.visual.dtype)
@ -1588,15 +1585,11 @@ class Glm4vForConditionalGeneration(
rope_type="rope_3d",
)
else:
video_embeds = self.visual(
pixel_values_videos, grid_thw=grid_thw.tolist()
)
video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw)
# Split concatenated embeddings for each video item.
merge_size = self.visual.spatial_merge_size
sizes = (
torch.tensor(grid_thw_list, dtype=torch.long).prod(-1)
// (merge_size * merge_size)
).tolist()
sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist()
return video_embeds.split(sizes)
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
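
As a sanity check on the simplified sizes computation above, here is a standalone toy example (not part of the diff; grid_thw rows are assumed to be [t, h, w] in patch units): each item contributes t * h * w patch embeddings, reduced by the spatial merge factor merge_size ** 2, and the resulting list drives split().

import torch

grid_thw = torch.tensor([[1, 4, 4], [2, 6, 8]], dtype=torch.long)
merge_size = 2

# t * h * w patches per item, divided by the spatial merge factor.
sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist()  # [4, 24]

# Splitting concatenated embeddings back into per-item chunks:
image_embeds = torch.zeros(sum(sizes), 16)
chunks = image_embeds.split(sizes)
assert [c.shape[0] for c in chunks] == sizes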

View File

@ -563,7 +563,7 @@ def _hunyuan_vl_field_config(hf_inputs: Mapping[str, torch.Tensor]):
return dict(
pixel_values=MultiModalFieldConfig.flat_from_sizes("image", image_grid_sizes),
image_embeds=MultiModalFieldConfig.flat_from_sizes("image", image_grid_sizes),
image_grid_thw=MultiModalFieldConfig.batched("image"),
image_grid_thw=MultiModalFieldConfig.batched("image", keep_on_cpu=True),
)
@ -786,8 +786,6 @@ class HunYuanVLForConditionalGeneration(
SupportsQuant,
SupportsXDRoPE,
):
multimodal_cpu_fields = {"image_grid_thw"}
# To ensure correct weight loading and mapping.
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={

View File

@ -84,9 +84,9 @@ class SupportsMultiModal(Protocol):
`vllm.multimodal.utils.group_mm_kwargs_by_modality` to use.
"""
multimodal_cpu_fields: ClassVar[Set[str]] = frozenset()
multimodal_cpu_fields: ClassVar[Set[str] | None] = None
"""
A set indicating CPU-only multimodal fields.
[DEPRECATED] A set indicating CPU-only multimodal fields.
"""
_processor_factory: ClassVar[_ProcessorFactories]
@ -279,6 +279,15 @@ def supports_multimodal(
"please remove the override from your model."
)
multimodal_cpu_fields = getattr(model, "multimodal_cpu_fields", None)
if multimodal_cpu_fields is not None:
raise ValueError(
"`multimodal_cpu_fields` is no longer effective, "
"please set `keep_on_cpu=True` in `MultiModalFieldConfig` "
"(refer to https://github.com/vllm-project/vllm/pull/30181), "
"and then remove the override from your model."
)
return res
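
A hedged migration sketch for model authors hitting this error (the model class and field-config function names are hypothetical; keep_on_cpu is the flag introduced in this diff): instead of declaring multimodal_cpu_fields on the model class, mark the field in the processor's field config.

from vllm.multimodal.inputs import MultiModalFieldConfig

# Before #30181 (no longer supported):
# class MyVLModel(nn.Module, SupportsMultiModal):
#     multimodal_cpu_fields = {"image_grid_thw"}

# After: the field config keeps the grid tensor on the CPU.
def _my_field_config(hf_inputs):
    return dict(
        pixel_values=MultiModalFieldConfig.batched("image"),
        image_grid_thw=MultiModalFieldConfig.batched("image", keep_on_cpu=True),
    )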

View File

@ -201,8 +201,6 @@ class OpenCUADummyInputsBuilder(Qwen2VLDummyInputsBuilder):
dummy_inputs=OpenCUADummyInputsBuilder,
)
class OpenCUAForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
multimodal_cpu_fields = {"image_grid_thw"}
packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"],

View File

@ -1039,8 +1039,6 @@ class Qwen2_5_VLForConditionalGeneration(
SupportsMultiModalPruning,
SupportsMRoPE,
):
multimodal_cpu_fields = {"image_grid_thw", "video_grid_thw"}
packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"],

View File

@ -811,14 +811,14 @@ def _create_qwen2vl_field_factory(
image_embeds=MultiModalFieldConfig.flat_from_sizes(
"image", image_embed_grid_sizes
),
image_grid_thw=MultiModalFieldConfig.batched("image"),
image_grid_thw=MultiModalFieldConfig.batched("image", keep_on_cpu=True),
pixel_values_videos=MultiModalFieldConfig.flat_from_sizes(
"video", video_grid_sizes
),
video_embeds=MultiModalFieldConfig.flat_from_sizes(
"video", video_embed_grid_sizes
),
video_grid_thw=MultiModalFieldConfig.batched("video"),
video_grid_thw=MultiModalFieldConfig.batched("video", keep_on_cpu=True),
)
return _qwen2vl_field_config
@ -1131,8 +1131,6 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo])
class Qwen2VLForConditionalGeneration(
nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE
):
multimodal_cpu_fields = {"image_grid_thw", "video_grid_thw"}
# To ensure correct weight loading and mapping.
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
@ -1393,9 +1391,11 @@ class Qwen2VLForConditionalGeneration(
else:
pixel_values_videos = video_input["pixel_values_videos"]
if self.use_data_parallel:
grid_thw_list = grid_thw.tolist()
return run_dp_sharded_mrope_vision_model(
self.visual, pixel_values_videos, grid_thw_list, rope_type="rope_3d"
self.visual,
pixel_values_videos,
grid_thw.tolist(),
rope_type="rope_3d",
)
else:
video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw)

View File

@ -984,14 +984,14 @@ class Qwen3VLMultiModalProcessor(BaseMultiModalProcessor[Qwen3VLProcessingInfo])
image_embeds=MultiModalFieldConfig.flat_from_sizes(
"image", image_grid_sizes
),
image_grid_thw=MultiModalFieldConfig.batched("image"),
image_grid_thw=MultiModalFieldConfig.batched("image", keep_on_cpu=True),
pixel_values_videos=MultiModalFieldConfig.flat_from_sizes(
"video", video_grid_sizes
),
video_embeds=MultiModalFieldConfig.flat_from_sizes(
"video", video_grid_sizes
),
video_grid_thw=MultiModalFieldConfig.batched("video"),
video_grid_thw=MultiModalFieldConfig.batched("video", keep_on_cpu=True),
)
def _get_prompt_updates(
@ -1190,8 +1190,6 @@ class Qwen3VLForConditionalGeneration(
SupportsMRoPE,
SupportsEagle3,
):
multimodal_cpu_fields = {"image_grid_thw", "video_grid_thw"}
packed_modules_mapping = {
"qkv_proj": [
"q_proj",

View File

@ -3,7 +3,7 @@
from abc import ABC, abstractmethod
from collections import UserDict, defaultdict
from collections.abc import Mapping, Sequence, Set
from collections.abc import Mapping, Sequence
from dataclasses import dataclass
from functools import partial
from itertools import accumulate
@ -223,6 +223,23 @@ def nested_tensors_equal(a: NestedTensors, b: NestedTensors) -> bool:
return a == b
def _nested_tensors_h2d(
tensors: NestedTensors,
device: torch.types.Device,
) -> NestedTensors:
if device is None:
return tensors
return json_map_leaves(
(
lambda x: x.to(device=device, non_blocking=True)
if isinstance(x, torch.Tensor)
else x
),
tensors,
)
BatchedTensorInputs: TypeAlias = dict[str, NestedTensors]
"""
A dictionary containing nested tensors which have been batched via
@ -334,7 +351,7 @@ class MultiModalFieldElem:
) # noqa: E721
@dataclass(frozen=True)
@dataclass(frozen=True, kw_only=True)
class BaseMultiModalField(ABC):
"""
Defines how to interpret tensor data belonging to a keyword argument in
@ -342,6 +359,12 @@ class BaseMultiModalField(ABC):
multi-modal items, and vice versa.
"""
keep_on_cpu: bool = False
"""
If `True`, then this field is excluded from being moved to the accelerator
when `MultiModalKwargsItems.get_data()` is called to batch the data.
"""
def _field_factory(self, *, modality: str, key: str):
f = partial(
MultiModalFieldElem,
@ -386,6 +409,7 @@ class BaseMultiModalField(ABC):
self,
elems: list[MultiModalFieldElem],
*,
device: torch.types.Device = None,
pin_memory: bool = False,
) -> NestedTensors:
"""
@ -399,11 +423,17 @@ class BaseMultiModalField(ABC):
if len(set(field_types)) > 1:
raise ValueError(f"Cannot merge different {field_types=}")
if device is not None and self.keep_on_cpu:
device = "cpu"
if pin_memory and self.keep_on_cpu:
pin_memory = False
batch = [elem.data for elem in elems]
return self._reduce_data(batch, pin_memory=pin_memory)
out = self._reduce_data(batch, pin_memory=pin_memory)
return _nested_tensors_h2d(out, device=device)
@dataclass(frozen=True)
@dataclass(frozen=True, kw_only=True)
class MultiModalBatchedField(BaseMultiModalField):
"""
Info:
@ -445,7 +475,7 @@ class MultiModalBatchedField(BaseMultiModalField):
return batch
@dataclass(frozen=True)
@dataclass(frozen=True, kw_only=True)
class MultiModalFlatField(BaseMultiModalField):
"""
Info:
@ -505,7 +535,7 @@ class MultiModalFlatField(BaseMultiModalField):
return [e for elem in batch for e in elem]
@dataclass(frozen=True)
@dataclass(frozen=True, kw_only=True)
class MultiModalSharedField(BaseMultiModalField):
"""
Info:
@ -532,9 +562,10 @@ class MultiModalSharedField(BaseMultiModalField):
return batch[0]
@dataclass(frozen=True)
class MultiModalFieldConfig:
@staticmethod
def batched(modality: str):
def batched(modality: str, *, keep_on_cpu: bool = False):
"""
Defines a field where an element in the batch is obtained by
indexing into the first dimension of the underlying data.
@ -542,6 +573,7 @@ class MultiModalFieldConfig:
Args:
modality: The modality of the multi-modal item that uses this
keyword argument.
keep_on_cpu: Whether to keep this field on the CPU for the model inputs.
Example:
@ -558,7 +590,7 @@ class MultiModalFieldConfig:
```
"""
return MultiModalFieldConfig(
field=MultiModalBatchedField(),
field=MultiModalBatchedField(keep_on_cpu=keep_on_cpu),
modality=modality,
)
@ -567,6 +599,8 @@ class MultiModalFieldConfig:
modality: str,
slices: Sequence[slice] | Sequence[Sequence[slice]],
dim: int = 0,
*,
keep_on_cpu: bool = False,
):
"""
Defines a field where an element in the batch is obtained by
@ -579,6 +613,7 @@ class MultiModalFieldConfig:
slices (dim>0) that is used to extract the data corresponding
to it.
dim: The dimension to extract data, default to 0.
keep_on_cpu: Whether to keep this field on the CPU for the model inputs.
Example:
@ -613,12 +648,22 @@ class MultiModalFieldConfig:
```
"""
return MultiModalFieldConfig(
field=MultiModalFlatField(slices=slices, dim=dim),
field=MultiModalFlatField(
slices=slices,
dim=dim,
keep_on_cpu=keep_on_cpu,
),
modality=modality,
)
@staticmethod
def flat_from_sizes(modality: str, size_per_item: "torch.Tensor", dim: int = 0):
def flat_from_sizes(
modality: str,
size_per_item: "torch.Tensor",
dim: int = 0,
*,
keep_on_cpu: bool = False,
):
"""
Defines a field where an element in the batch is obtained by
slicing along the first dimension of the underlying data.
@ -629,6 +674,7 @@ class MultiModalFieldConfig:
size_per_item: For each multi-modal item, the size of the slice
that is used to extract the data corresponding to it.
dim: The dimension to slice, default to 0.
keep_on_cpu: Whether to keep this field on the CPU for the model inputs.
Example:
@ -676,10 +722,20 @@ class MultiModalFieldConfig:
for i in range(len(size_per_item))
]
return MultiModalFieldConfig.flat(modality, slices, dim=dim)
return MultiModalFieldConfig.flat(
modality,
slices,
dim=dim,
keep_on_cpu=keep_on_cpu,
)
@staticmethod
def shared(modality: str, batch_size: int):
def shared(
modality: str,
batch_size: int,
*,
keep_on_cpu: bool = False,
):
"""
Defines a field where an element in the batch is obtained by
taking the entirety of the underlying data.
@ -690,6 +746,7 @@ class MultiModalFieldConfig:
modality: The modality of the multi-modal item that uses this
keyword argument.
batch_size: The number of multi-modal items which share this data.
keep_on_cpu: Whether to keep this field on the CPU for the model inputs.
Example:
@ -708,18 +765,15 @@ class MultiModalFieldConfig:
```
"""
return MultiModalFieldConfig(
field=MultiModalSharedField(batch_size),
field=MultiModalSharedField(
batch_size=batch_size,
keep_on_cpu=keep_on_cpu,
),
modality=modality,
)
def __init__(self, field: BaseMultiModalField, modality: str) -> None:
super().__init__()
self.field = field
self.modality = modality
def __repr__(self) -> str:
return f"MultiModalFieldConfig(field={self.field}, modality={self.modality})"
field: BaseMultiModalField
modality: str
def build_elems(
self,
@ -744,7 +798,7 @@ class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]):
modality=modality,
key="dummy",
data=torch.empty(nbytes, dtype=torch.uint8),
field=MultiModalSharedField(1),
field=MultiModalSharedField(batch_size=1),
)
return MultiModalKwargsItem.from_elems([mm_elem])
@ -844,7 +898,6 @@ class MultiModalKwargsItems(UserDict[str, Sequence[_I]]):
*,
device: torch.types.Device = None,
pin_memory: bool = False,
cpu_fields: Set[str] = frozenset(),
) -> BatchedTensorInputs:
"""Construct a dictionary of keyword arguments to pass to the model."""
elems_by_key = defaultdict[str, list[MultiModalFieldElem]](list)
@ -859,21 +912,14 @@ class MultiModalKwargsItems(UserDict[str, Sequence[_I]]):
elems_by_key[key].append(elem)
data = {
key: elems[0].field.reduce_data(elems, pin_memory=pin_memory)
key: elems[0].field.reduce_data(
elems,
device=device,
pin_memory=pin_memory,
)
for key, elems in elems_by_key.items()
}
if device is not None:
for k in data.keys() - cpu_fields:
data[k] = json_map_leaves(
(
lambda x: x.to(device=device, non_blocking=True)
if isinstance(x, torch.Tensor)
else x
),
data[k],
)
return data
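
A minimal sketch of the resulting per-field behaviour (assuming the constructors shown in this diff, that the batched field stacks same-shape elements into a single tensor, and that a CUDA device may or may not be available): a field built with keep_on_cpu=True ignores the target device passed down from get_data(), while other fields are moved as before.

import torch
from vllm.multimodal.inputs import MultiModalBatchedField, MultiModalFieldElem

def _elem(key, data, field):
    return MultiModalFieldElem(modality="image", key=key, data=data, field=field)

cpu_field = MultiModalBatchedField(keep_on_cpu=True)
gpu_field = MultiModalBatchedField()

grid_elems = [_elem("image_grid_thw", torch.tensor([1, 4, 4]), cpu_field) for _ in range(2)]
pixel_elems = [_elem("pixel_values", torch.rand(3, 16), gpu_field) for _ in range(2)]

device = "cuda:0" if torch.cuda.is_available() else "cpu"
grid_batch = cpu_field.reduce_data(grid_elems, device=device)    # stays on the CPU
pixel_batch = gpu_field.reduce_data(pixel_elems, device=device)  # moved to `device`
assert grid_batch.device.type == "cpu"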

View File

@ -413,7 +413,7 @@ def group_mm_kwargs_by_modality(
device: torch.types.Device = None,
pin_memory: bool = False,
merge_by_field_config: bool | None = None,
multimodal_cpu_fields: Set[str] = frozenset(),
multimodal_cpu_fields: Set[str] | None = None,
) -> Generator[tuple[str, int, BatchedTensorInputs], None, None]:
"""Group consecutive `MultiModalKwargsItem`s from `mm_kwargs` with the same
modality together into the same `MultiModalKwargs` instance.
@ -431,6 +431,11 @@ def group_mm_kwargs_by_modality(
"The `merge_by_field_config` argument of `group_mm_kwargs_by_modality` "
"is deprecated and will be removed in v0.13."
)
if multimodal_cpu_fields is not None:
logger.warning_once(
"The `multimodal_cpu_fields` argument of `group_mm_kwargs_by_modality` "
"is deprecated and will be removed in v0.13."
)
from vllm.multimodal.inputs import MultiModalKwargsItems
@ -440,7 +445,6 @@ def group_mm_kwargs_by_modality(
mm_kwargs_data = mm_kwargs_items.get_data(
device=device,
pin_memory=pin_memory,
cpu_fields=multimodal_cpu_fields,
)
yield modality, len(items_lst), mm_kwargs_data
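
With the multimodal_cpu_fields/cpu_fields plumbing gone, callers only pass the target device and pin-memory flag; a sketch of a typical call (mm_kwargs is assumed to hold MultiModalKwargsItem objects collected by a model runner, as in the runner changes below):

from vllm.multimodal.utils import group_mm_kwargs_by_modality

mm_kwargs = []  # would hold MultiModalKwargsItem objects collected by the runner

for modality, num_items, mm_kwargs_group in group_mm_kwargs_by_modality(
    mm_kwargs,
    device="cuda:0",
    pin_memory=True,
):
    # Fields marked keep_on_cpu=True arrive on the CPU; everything else on cuda:0.
    ...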

View File

@ -269,10 +269,11 @@ class MsgpackEncoder:
name = MMF_CLASS_TO_FACTORY.get(field.__class__)
if not name:
raise TypeError(f"Unsupported field type: {field.__class__}")
# We just need to copy all of the field values in order
# which will be then used to reconstruct the field.
field_values = (getattr(field, f.name) for f in dataclasses.fields(field))
return name, *field_values
factory_kw = {f.name: getattr(field, f.name) for f in dataclasses.fields(field)}
return name, factory_kw
class MsgpackDecoder:
@ -392,15 +393,15 @@ class MsgpackDecoder:
obj["data"] = self._decode_nested_tensors(obj["data"])
# Reconstruct the field processor using MultiModalFieldConfig
factory_meth_name, *field_args = obj["field"]
factory_meth_name, factory_kw = obj["field"]
factory_meth = getattr(MultiModalFieldConfig, factory_meth_name)
# Special case: decode the union "slices" field of
# MultiModalFlatField
if factory_meth_name == "flat":
field_args[0] = self._decode_nested_slices(field_args[0])
factory_kw["slices"] = self._decode_nested_slices(factory_kw["slices"])
obj["field"] = factory_meth(None, *field_args).field
obj["field"] = factory_meth("", **factory_kw).field
return MultiModalFieldElem(**obj)
def _decode_nested_tensors(self, obj: Any) -> NestedTensors:
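
A hedged sketch of the new round trip (the import path and equality semantics are assumptions based on the names in this diff): the encoder stores the field's dataclass values as a name-to-value mapping, and the decoder rebuilds the field through the matching MultiModalFieldConfig factory with a placeholder modality.

import dataclasses
from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalFlatField

field = MultiModalFlatField(slices=[slice(0, 3), slice(3, 5)], dim=0, keep_on_cpu=True)

# Encoder side: keyword dict instead of positional values.
factory_kw = {f.name: getattr(field, f.name) for f in dataclasses.fields(field)}

# Decoder side: rebuild via the "flat" factory; the modality is irrelevant here.
rebuilt = MultiModalFieldConfig.flat("", **factory_kw).field
assert rebuilt == field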

View File

@ -1097,7 +1097,6 @@ class GPUModelRunner(
device=self.device,
pin_memory=self.pin_memory,
merge_by_field_config=model.merge_by_field_config,
multimodal_cpu_fields=model.multimodal_cpu_fields,
):
mm_kwargs_combined.update(mm_kwargs_group)
@ -2109,7 +2108,6 @@ class GPUModelRunner(
mm_kwargs,
device=self.device,
pin_memory=self.pin_memory,
multimodal_cpu_fields=model.multimodal_cpu_fields,
):
curr_group_outputs: list[torch.Tensor] = []
@ -2135,7 +2133,6 @@ class GPUModelRunner(
[video_mm_kwargs_item],
device=self.device,
pin_memory=self.pin_memory,
multimodal_cpu_fields=model.multimodal_cpu_fields,
)
)
@ -3887,14 +3884,12 @@ class GPUModelRunner(
dummy_mm_item = dummy_mm_data[modality][0]
dummy_mm_items = [dummy_mm_item] * max_items_per_batch
model = cast(SupportsMultiModal, self.model)
return next(
mm_kwargs_group
for _, _, mm_kwargs_group in group_mm_kwargs_by_modality(
dummy_mm_items,
device=self.device,
pin_memory=self.pin_memory,
multimodal_cpu_fields=model.multimodal_cpu_fields,
)
)

View File

@ -969,7 +969,6 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
mm_kwargs,
device=self.device,
pin_memory=self.pin_memory,
multimodal_cpu_fields=model.multimodal_cpu_fields,
):
# Run the encoder.
# `curr_group_outputs` is either of the following:
@ -2050,14 +2049,12 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
dummy_mm_item = dummy_mm_data[modality][0]
dummy_mm_items = [dummy_mm_item] * max_items_per_batch
model = cast(SupportsMultiModal, self.model)
return next(
grouped_mm_kwargs
for _, _, grouped_mm_kwargs in group_mm_kwargs_by_modality(
dummy_mm_items,
device=self.device,
pin_memory=self.pin_memory,
multimodal_cpu_fields=model.multimodal_cpu_fields,
)
)