[Core][MM] Add mechanism to configure multimodal fields which should stay on CPU (#28168)

Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com>
Author: Lukas Geiger <lukas.geiger94@gmail.com>
Date: 2025-11-07 12:14:29 +00:00 (committed by GitHub)
commit e0919f331d (parent 8e19d470af)
7 changed files with 68 additions and 74 deletions
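
The motivation for the new mechanism: fields such as `image_grid_thw` are only ever read on the host (to compute split sizes and sequence lengths), so copying them to the accelerator just forces a synchronization when the model reads the values back. A standalone illustration of that difference, not part of the diff and assuming nothing beyond plain PyTorch:

import torch

# Eight images, each a 1 x 32 x 32 patch grid (hypothetical values).
grid_thw_cpu = torch.tensor([[1, 32, 32]] * 8, dtype=torch.long)
# Host-side arithmetic only; nothing to synchronize.
sizes_cpu = (grid_thw_cpu.prod(-1) // 4).tolist()

if torch.cuda.is_available():
    grid_thw_gpu = grid_thw_cpu.to("cuda", non_blocking=True)
    # Reading the values back (.tolist()) must wait for the copy and any
    # queued kernels to finish: an implicit device synchronization.
    sizes_gpu = (grid_thw_gpu.prod(-1) // 4).tolist()
    assert sizes_gpu == sizes_cpu

With the grid tensors kept on the CPU, the Qwen models below can return to the straightforward `grid_thw.prod(-1)` form without reintroducing a sync.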


@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from collections.abc import Callable, Iterable, Mapping, MutableSequence
+from collections.abc import Callable, Iterable, Mapping, MutableSequence, Set
 from typing import (
     TYPE_CHECKING,
     ClassVar,
@@ -81,6 +81,11 @@ class SupportsMultiModal(Protocol):
     `vllm.multimodal.utils.group_mm_kwargs_by_modality` to use.
     """
 
+    multimodal_cpu_fields: ClassVar[Set[str]] = frozenset()
+    """
+    A set indicating CPU-only multimodal fields.
+    """
+
     @classmethod
     def get_placeholder_str(cls, modality: str, i: int) -> str | None:
         """


@@ -1090,6 +1090,7 @@ class Qwen2_5_VLForConditionalGeneration(
     SupportsMRoPE,
 ):
     merge_by_field_config = True
+    multimodal_cpu_fields = {"image_grid_thw", "video_grid_thw"}
 
     packed_modules_mapping = {
         "qkv_proj": ["q_proj", "k_proj", "v_proj"],
@@ -1364,13 +1365,8 @@ class Qwen2_5_VLForConditionalGeneration(
             image_embeds = self.visual(pixel_values, grid_thw=grid_thw_list)
 
         # Split concatenated embeddings for each image item.
-        # Using prod on grid_thw_list instead of grid_thw.prod avoids CUDA sync
         merge_size = self.visual.spatial_merge_size
-        sizes = (
-            torch.tensor(grid_thw_list, dtype=torch.long).prod(-1)
-            // (merge_size * merge_size)
-        ).tolist()
+        sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist()
 
         return image_embeds.split(sizes)
 
     def _postprocess_image_embeds_evs(
@@ -1430,12 +1426,7 @@ class Qwen2_5_VLForConditionalGeneration(
         # Split concatenated embeddings for each video item.
         merge_size = self.visual.spatial_merge_size
-        # Using prod on grid_thw_list instead of grid_thw.prod avoids CUDA sync
-        sizes = (
-            torch.tensor(grid_thw_list, dtype=torch.long).prod(-1)
-            // (merge_size * merge_size)
-        ).tolist()
+        sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist()
 
         return video_embeds.split(sizes)
 
     def _postprocess_video_embeds_evs(
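
The simplified `sizes` expression above works because `image_grid_thw`/`video_grid_thw` now arrive as CPU tensors, so `prod()` and `tolist()` are plain host-side arithmetic. A self-contained sketch with made-up grid values:

import torch

merge_size = 2  # assumed spatial merge size
# (t, h, w) per image: a 1x8x8 and a 1x4x8 patch grid (hypothetical values).
grid_thw = torch.tensor([[1, 8, 8], [1, 4, 8]], dtype=torch.long)

# Tokens per image after spatial merging: t * h * w // merge_size**2.
sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist()
assert sizes == [16, 8]

# Split the concatenated embeddings back into per-image chunks.
image_embeds = torch.randn(sum(sizes), 4)
per_image = image_embeds.split(sizes)
assert [e.shape[0] for e in per_image] == [16, 8]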


@@ -798,21 +798,27 @@ class Qwen2VisionTransformer(nn.Module):
     def forward(
         self,
         x: torch.Tensor,
-        grid_thw: list[list[int]],
+        grid_thw: torch.Tensor | list[list[int]],
     ) -> torch.Tensor:
         # patchify
         x = x.to(device=self.device, dtype=self.dtype)
         x = self.patch_embed(x)
 
+        if isinstance(grid_thw, list):
+            grid_thw_list = grid_thw
+            grid_thw = torch.tensor(grid_thw, dtype=torch.int32)
+        else:
+            grid_thw_list = grid_thw.tolist()
+
         # compute position embedding
-        rotary_pos_emb = self.rot_pos_emb(grid_thw)
+        rotary_pos_emb = self.rot_pos_emb(grid_thw_list)
 
         # compute cu_seqlens
-        grid_thw_ = torch.tensor(grid_thw, device=x.device, dtype=torch.long)
         cu_seqlens = torch.repeat_interleave(
-            grid_thw_[:, 1] * grid_thw_[:, 2], grid_thw_[:, 0]
+            grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
         ).cumsum(dim=0, dtype=torch.int32)
-        cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)
+        cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens])
+        cu_seqlens = cu_seqlens.to(self.device, non_blocking=True)
 
         # transformers
         x = x.unsqueeze(1)
@@ -1211,6 +1217,7 @@ class Qwen2VLForConditionalGeneration(
     nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE
 ):
     merge_by_field_config = True
+    multimodal_cpu_fields = {"image_grid_thw", "video_grid_thw"}
 
     # To ensure correct weight loading and mapping.
     hf_to_vllm_mapper = WeightsMapper(
@@ -1458,7 +1465,6 @@ class Qwen2VLForConditionalGeneration(
     ) -> tuple[torch.Tensor, ...]:
         grid_thw = image_input["image_grid_thw"]
         assert grid_thw.ndim == 2
-        grid_thw_list = grid_thw.tolist()
 
         if image_input["type"] == "image_embeds":
             image_embeds = image_input["image_embeds"]
@@ -1467,18 +1473,14 @@ class Qwen2VLForConditionalGeneration(
             if self.use_data_parallel:
                 return run_dp_sharded_mrope_vision_model(
-                    self.visual, pixel_values, grid_thw_list, rope_type="rope_3d"
+                    self.visual, pixel_values, grid_thw.tolist(), rope_type="rope_3d"
                 )
             else:
-                image_embeds = self.visual(pixel_values, grid_thw=grid_thw_list)
+                image_embeds = self.visual(pixel_values, grid_thw=grid_thw)
 
         # Split concatenated embeddings for each image item.
         merge_size = self.visual.spatial_merge_size
-        sizes = (
-            torch.tensor(grid_thw_list, dtype=torch.long).prod(-1)
-            // (merge_size * merge_size)
-        ).tolist()
+        sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist()
 
         return image_embeds.split(sizes)
 
     def _process_video_input(
@@ -1486,26 +1488,22 @@ class Qwen2VLForConditionalGeneration(
     ) -> tuple[torch.Tensor, ...]:
         grid_thw = video_input["video_grid_thw"]
         assert grid_thw.ndim == 2
-        grid_thw_list = grid_thw.tolist()
 
         if video_input["type"] == "video_embeds":
             video_embeds = video_input["video_embeds"]
         else:
             pixel_values_videos = video_input["pixel_values_videos"]
             if self.use_data_parallel:
+                grid_thw_list = grid_thw.tolist()
                 return run_dp_sharded_mrope_vision_model(
                     self.visual, pixel_values_videos, grid_thw_list, rope_type="rope_3d"
                 )
             else:
-                video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw_list)
+                video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw)
 
         # Split concatenated embeddings for each video item.
         merge_size = self.visual.spatial_merge_size
-        sizes = (
-            torch.tensor(grid_thw_list, dtype=torch.long).prod(-1)
-            // (merge_size * merge_size)
-        ).tolist()
+        sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist()
 
         return video_embeds.split(sizes)
 
     def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
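
The reworked `cu_seqlens` computation in `Qwen2VisionTransformer.forward` builds the sequence boundaries on the CPU copy of `grid_thw` and copies only the small result to the device afterwards. A self-contained sketch using hypothetical (t, h, w) grids:

import torch

# Two items: item 0 is one frame of 4x6 patches, item 1 is two frames of
# 2x3 patches (made-up values). grid_thw stays a CPU tensor.
grid_thw = torch.tensor([[1, 4, 6], [2, 2, 3]], dtype=torch.int32)

cu_seqlens = torch.repeat_interleave(
    grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
).cumsum(dim=0, dtype=torch.int32)
cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens])

# One sequence of 24 patches, then two sequences of 6 patches each.
assert cu_seqlens.tolist() == [0, 24, 30, 36]
# In the model the finished CPU result is then moved to the device once:
# cu_seqlens = cu_seqlens.to(self.device, non_blocking=True)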


@@ -414,16 +414,10 @@ class Qwen3_VisionTransformer(nn.Module):
     def device(self) -> torch.device:
         return self.patch_embed.proj.weight.device
 
-    def rot_pos_emb(self, grid_thw):
+    def rot_pos_emb(self, grid_thw: list[list[int]]):
         pos_ids = []
-        # Support both Tensor and list inputs for DP path
-        if isinstance(grid_thw, list):
-            grid_list = grid_thw
-            max_grid_size = max(max(h, w) for _, h, w in grid_list)
-        else:
-            grid_list = grid_thw.tolist()
-            max_grid_size = int(grid_thw[:, 1:].max().item())
-        for t, h, w in grid_list:
+        max_grid_size = max(max(h, w) for _, h, w in grid_thw)
+        for t, h, w in grid_thw:
             hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
             hpos_ids = hpos_ids.reshape(
                 h // self.spatial_merge_size,
@@ -527,24 +521,25 @@ class Qwen3_VisionTransformer(nn.Module):
     def forward(
         self,
         x: torch.Tensor,
-        grid_thw: list[list[int]],
+        grid_thw: torch.Tensor | list[list[int]],
     ) -> torch.Tensor:
         hidden_states = x.to(device=self.device, dtype=self.dtype, non_blocking=True)
         hidden_states = self.patch_embed(hidden_states)
 
-        pos_embeds = self.fast_pos_embed_interpolate(grid_thw)
+        if isinstance(grid_thw, list):
+            grid_thw_list = grid_thw
+            grid_thw = torch.tensor(grid_thw, dtype=torch.int32)
+        else:
+            grid_thw_list = grid_thw.tolist()
+
+        pos_embeds = self.fast_pos_embed_interpolate(grid_thw_list)
         hidden_states = hidden_states + pos_embeds
 
-        rotary_pos_emb = self.rot_pos_emb(grid_thw)
+        rotary_pos_emb = self.rot_pos_emb(grid_thw_list)
         rotary_pos_emb = rotary_pos_emb.to(hidden_states.device, non_blocking=True)
 
-        grid_thw_tensor = torch.tensor(grid_thw, dtype=torch.int32)
         cu_seqlens = torch.repeat_interleave(
-            grid_thw_tensor[:, 1] * grid_thw_tensor[:, 2], grid_thw_tensor[:, 0]
-        ).cumsum(
-            dim=0,
-            dtype=grid_thw_tensor.dtype if torch.jit.is_tracing() else torch.int32,
-        )
+            grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
+        ).cumsum(dim=0, dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32)
         cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens])
 
         hidden_states = hidden_states.unsqueeze(1)
@@ -1177,6 +1172,7 @@ class Qwen3VLForConditionalGeneration(
     nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE
 ):
     merge_by_field_config = True
+    multimodal_cpu_fields = {"image_grid_thw", "video_grid_thw"}
 
     packed_modules_mapping = {
         "qkv_proj": [
@@ -1356,7 +1352,6 @@ class Qwen3VLForConditionalGeneration(
     ) -> tuple[torch.Tensor, ...]:
         grid_thw = image_input["image_grid_thw"]
         assert grid_thw.ndim == 2
-        grid_thw_list = grid_thw.tolist()
 
         if image_input["type"] == "image_embeds":
             image_embeds = image_input["image_embeds"].type(self.visual.dtype)
@@ -1364,18 +1359,14 @@ class Qwen3VLForConditionalGeneration(
             pixel_values = image_input["pixel_values"].type(self.visual.dtype)
             if self.use_data_parallel:
                 return run_dp_sharded_mrope_vision_model(
-                    self.visual, pixel_values, grid_thw_list, rope_type="rope_3d"
+                    self.visual, pixel_values, grid_thw.tolist(), rope_type="rope_3d"
                 )
             else:
-                image_embeds = self.visual(pixel_values, grid_thw=grid_thw_list)
+                image_embeds = self.visual(pixel_values, grid_thw=grid_thw)
 
         # Split concatenated embeddings for each image item.
-        # Using prod on grid_thw_list instead of grid_thw.prod avoids CUDA sync
         merge_size = self.visual.spatial_merge_size
-        sizes = (
-            torch.tensor(grid_thw_list, dtype=torch.long).prod(-1)
-            // (merge_size * merge_size)
-        ).tolist()
+        sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist()
 
         return image_embeds.split(sizes)
 
     def _process_video_input(
@@ -1383,7 +1374,6 @@ class Qwen3VLForConditionalGeneration(
     ) -> tuple[torch.Tensor, ...]:
         grid_thw = video_input["video_grid_thw"]
         assert grid_thw.ndim == 2
-        grid_thw_list = grid_thw.tolist()
 
         if video_input["type"] == "video_embeds":
             video_embeds = video_input["video_embeds"].type(self.visual.dtype)
@@ -1392,19 +1382,16 @@ class Qwen3VLForConditionalGeneration(
                 self.visual.dtype
             )
             if self.use_data_parallel:
+                grid_thw_list = grid_thw.tolist()
                 return run_dp_sharded_mrope_vision_model(
                     self.visual, pixel_values_videos, grid_thw_list, rope_type="rope_3d"
                 )
             else:
-                video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw_list)
+                video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw)
 
         # Split concatenated embeddings for each video item.
-        # Using prod on grid_thw_list instead of grid_thw.prod avoids CUDA sync
         merge_size = self.visual.spatial_merge_size
-        sizes = (
-            torch.tensor(grid_thw_list, dtype=torch.long).prod(-1)
-            // (merge_size * merge_size)
-        ).tolist()
+        sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist()
 
         return video_embeds.split(sizes)
 
     def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
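
Both vision transformers now accept `grid_thw` either as a CPU tensor (the regular path, thanks to `multimodal_cpu_fields`) or as a plain list (the data-parallel sharded path). The normalization at the top of `forward` can be read in isolation as follows; `normalize_grid_thw` is an illustrative name, not a helper that exists in the diff:

import torch


def normalize_grid_thw(
    grid_thw: torch.Tensor | list[list[int]],
) -> tuple[torch.Tensor, list[list[int]]]:
    # DP-sharded callers pass a list[list[int]]; the regular path passes the
    # CPU tensor kept by multimodal_cpu_fields. Downstream code wants both views.
    if isinstance(grid_thw, list):
        grid_thw_list = grid_thw
        grid_thw = torch.tensor(grid_thw, dtype=torch.int32)
    else:
        grid_thw_list = grid_thw.tolist()
    return grid_thw, grid_thw_list


tensor_form, list_form = normalize_grid_thw([[1, 4, 6]])
assert tensor_form.shape == (1, 3) and list_form == [[1, 4, 6]]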


@@ -3,7 +3,7 @@
 import asyncio
 import atexit
-from collections.abc import Iterable
+from collections.abc import Iterable, Set
 from concurrent.futures import ThreadPoolExecutor
 from itertools import groupby
 from pathlib import Path
@@ -402,6 +402,7 @@ def group_mm_kwargs_by_modality(
     device: torch.types.Device = None,
     pin_memory: bool = False,
     merge_by_field_config: bool | None = None,
+    multimodal_cpu_fields: Set[str] = frozenset(),
 ) -> Iterable[tuple[str, int, BatchedTensorInputs]]:
     """Group consecutive `MultiModalKwargsItem`s from `mm_kwargs` with the same
     modality together into the same `MultiModalKwargs` instance.
@@ -443,12 +444,17 @@ def group_mm_kwargs_by_modality(
         )
 
         if device is not None:
-            mm_kwargs_group = json_map_leaves(
-                lambda x: x.to(device=device, non_blocking=True)
-                if isinstance(x, torch.Tensor)
-                else x,
-                mm_kwargs_group,
-            )
+            mm_kwargs_group = {
+                k: json_map_leaves(
+                    lambda x: x.to(device=device, non_blocking=True)
+                    if isinstance(x, torch.Tensor)
+                    else x,
+                    v,
+                )
+                if k not in multimodal_cpu_fields
+                else v
+                for k, v in mm_kwargs_group.items()
+            }
         else:
             mm_kwargs_group = MultiModalKwargs.as_kwargs(
                 MultiModalKwargs.batch(
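
The change to `group_mm_kwargs_by_modality` replaces the blanket `json_map_leaves` transfer with a per-key comprehension so that excluded keys keep their CPU tensors. The same pattern reduced to a flat dict of tensors (the real code maps over nested structures, and `move_group_to_device` is an illustrative name):

from collections.abc import Set

import torch


def move_group_to_device(
    group: dict[str, torch.Tensor],
    device: torch.device,
    cpu_fields: Set[str] = frozenset(),
) -> dict[str, torch.Tensor]:
    # Move every field to the target device except the ones the model wants
    # to keep on the CPU.
    return {
        k: v.to(device=device, non_blocking=True) if k not in cpu_fields else v
        for k, v in group.items()
    }


group = {
    "pixel_values": torch.randn(2, 3, 4, 4),
    "image_grid_thw": torch.tensor([[1, 2, 2], [1, 2, 2]]),
}
moved = move_group_to_device(
    group, torch.device("cpu"), cpu_fields={"image_grid_thw"}
)
assert moved["image_grid_thw"] is group["image_grid_thw"]  # left untouched

The GPU and TPU model runner hunks below only thread `model.multimodal_cpu_fields` through to this helper; the filtering itself happens here.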


@@ -938,6 +938,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             device=self.device,
             pin_memory=self.pin_memory,
             merge_by_field_config=model.merge_by_field_config,
+            multimodal_cpu_fields=model.multimodal_cpu_fields,
         ):
             mm_kwargs_combined.update(mm_kwargs_group)
@@ -1768,6 +1769,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             device=self.device,
             pin_memory=self.pin_memory,
             merge_by_field_config=model.merge_by_field_config,
+            multimodal_cpu_fields=model.multimodal_cpu_fields,
         ):
             curr_group_outputs = []
@@ -1794,6 +1796,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 device=self.device,
                 pin_memory=self.pin_memory,
                 merge_by_field_config=model.merge_by_field_config,
+                multimodal_cpu_fields=model.multimodal_cpu_fields,
             )
         )
@@ -1936,6 +1939,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             device=self.device,
             pin_memory=self.pin_memory,
             merge_by_field_config=model.merge_by_field_config,
+            multimodal_cpu_fields=model.multimodal_cpu_fields,
         ):
             # Add the grouped features to encoder_features dict
             # This allows the model to receive them as kwargs (e.g.,
@@ -3292,6 +3296,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 device=self.device,
                 pin_memory=self.pin_memory,
                 merge_by_field_config=model.merge_by_field_config,
+                multimodal_cpu_fields=model.multimodal_cpu_fields,
             )
         )


@@ -952,6 +952,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             device=self.device,
             pin_memory=self.pin_memory,
             merge_by_field_config=model.merge_by_field_config,
+            multimodal_cpu_fields=model.multimodal_cpu_fields,
         ):
             # Run the encoder.
             # `curr_group_outputs` is either of the following:
@@ -2037,6 +2038,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 device=self.device,
                 pin_memory=self.pin_memory,
                 merge_by_field_config=model.merge_by_field_config,
+                multimodal_cpu_fields=model.multimodal_cpu_fields,
             )
         )