diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py
index 33c9043405ca..b634c7ec7d67 100644
--- a/vllm/model_executor/models/interfaces.py
+++ b/vllm/model_executor/models/interfaces.py
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

-from collections.abc import Callable, Iterable, Mapping, MutableSequence
+from collections.abc import Callable, Iterable, Mapping, MutableSequence, Set
 from typing import (
     TYPE_CHECKING,
     ClassVar,
@@ -81,6 +81,11 @@ class SupportsMultiModal(Protocol):
     `vllm.multimodal.utils.group_mm_kwargs_by_modality` to use.
     """

+    multimodal_cpu_fields: ClassVar[Set[str]] = frozenset()
+    """
+    Names of multimodal fields that should stay on the CPU (not moved to the device).
+    """
+
     @classmethod
     def get_placeholder_str(cls, modality: str, i: int) -> str | None:
         """
diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py
index a90cfe96414b..d337f1606943 100644
--- a/vllm/model_executor/models/qwen2_5_vl.py
+++ b/vllm/model_executor/models/qwen2_5_vl.py
@@ -1090,6 +1090,7 @@ class Qwen2_5_VLForConditionalGeneration(
     SupportsMRoPE,
 ):
     merge_by_field_config = True
+    multimodal_cpu_fields = {"image_grid_thw", "video_grid_thw"}

     packed_modules_mapping = {
         "qkv_proj": ["q_proj", "k_proj", "v_proj"],
@@ -1364,13 +1365,8 @@ class Qwen2_5_VLForConditionalGeneration(
             image_embeds = self.visual(pixel_values, grid_thw=grid_thw_list)

         # Split concatenated embeddings for each image item.
-        # Using prod on grid_thw_list instead of grid_thw.prod avoids CUDA sync
         merge_size = self.visual.spatial_merge_size
-        sizes = (
-            torch.tensor(grid_thw_list, dtype=torch.long).prod(-1)
-            // (merge_size * merge_size)
-        ).tolist()
-
+        sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist()
         return image_embeds.split(sizes)

     def _postprocess_image_embeds_evs(
@@ -1430,12 +1426,7 @@ class Qwen2_5_VLForConditionalGeneration(

         # Split concatenated embeddings for each video item.
         merge_size = self.visual.spatial_merge_size
-        # Using prod on grid_thw_list instead of grid_thw.prod avoids CUDA sync
-        sizes = (
-            torch.tensor(grid_thw_list, dtype=torch.long).prod(-1)
-            // (merge_size * merge_size)
-        ).tolist()
-
+        sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist()
         return video_embeds.split(sizes)

     def _postprocess_video_embeds_evs(
diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py
index 1ec12bdb55df..9206ac8f9d03 100644
--- a/vllm/model_executor/models/qwen2_vl.py
+++ b/vllm/model_executor/models/qwen2_vl.py
@@ -798,21 +798,27 @@ class Qwen2VisionTransformer(nn.Module):
     def forward(
         self,
         x: torch.Tensor,
-        grid_thw: list[list[int]],
+        grid_thw: torch.Tensor | list[list[int]],
     ) -> torch.Tensor:
         # patchify
         x = x.to(device=self.device, dtype=self.dtype)
         x = self.patch_embed(x)

+        if isinstance(grid_thw, list):
+            grid_thw_list = grid_thw
+            grid_thw = torch.tensor(grid_thw, dtype=torch.int32)
+        else:
+            grid_thw_list = grid_thw.tolist()
+
         # compute position embedding
-        rotary_pos_emb = self.rot_pos_emb(grid_thw)
+        rotary_pos_emb = self.rot_pos_emb(grid_thw_list)

         # compute cu_seqlens
-        grid_thw_ = torch.tensor(grid_thw, device=x.device, dtype=torch.long)
         cu_seqlens = torch.repeat_interleave(
-            grid_thw_[:, 1] * grid_thw_[:, 2], grid_thw_[:, 0]
+            grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
         ).cumsum(dim=0, dtype=torch.int32)
-        cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)
+        cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens])
+        cu_seqlens = cu_seqlens.to(self.device, non_blocking=True)

         # transformers
         x = x.unsqueeze(1)
@@ -1211,6 +1217,7 @@ class Qwen2VLForConditionalGeneration(
     nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE
 ):
     merge_by_field_config = True
+    multimodal_cpu_fields = {"image_grid_thw", "video_grid_thw"}

     # To ensure correct weight loading and mapping.
     hf_to_vllm_mapper = WeightsMapper(
@@ -1458,7 +1465,6 @@ class Qwen2VLForConditionalGeneration(
     ) -> tuple[torch.Tensor, ...]:
         grid_thw = image_input["image_grid_thw"]
         assert grid_thw.ndim == 2
-        grid_thw_list = grid_thw.tolist()

         if image_input["type"] == "image_embeds":
             image_embeds = image_input["image_embeds"]
@@ -1467,18 +1473,14 @@ class Qwen2VLForConditionalGeneration(

             if self.use_data_parallel:
                 return run_dp_sharded_mrope_vision_model(
-                    self.visual, pixel_values, grid_thw_list, rope_type="rope_3d"
+                    self.visual, pixel_values, grid_thw.tolist(), rope_type="rope_3d"
                 )
             else:
-                image_embeds = self.visual(pixel_values, grid_thw=grid_thw_list)
+                image_embeds = self.visual(pixel_values, grid_thw=grid_thw)

         # Split concatenated embeddings for each image item.
         merge_size = self.visual.spatial_merge_size
-        sizes = (
-            torch.tensor(grid_thw_list, dtype=torch.long).prod(-1)
-            // (merge_size * merge_size)
-        ).tolist()
-
+        sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist()
         return image_embeds.split(sizes)

     def _process_video_input(
@@ -1486,26 +1488,22 @@ class Qwen2VLForConditionalGeneration(
     ) -> tuple[torch.Tensor, ...]:
         grid_thw = video_input["video_grid_thw"]
         assert grid_thw.ndim == 2
-        grid_thw_list = grid_thw.tolist()

         if video_input["type"] == "video_embeds":
             video_embeds = video_input["video_embeds"]
         else:
             pixel_values_videos = video_input["pixel_values_videos"]
             if self.use_data_parallel:
+                grid_thw_list = grid_thw.tolist()
                 return run_dp_sharded_mrope_vision_model(
                     self.visual, pixel_values_videos, grid_thw_list, rope_type="rope_3d"
                 )
             else:
-                video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw_list)
+                video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw)

         # Split concatenated embeddings for each video item.
         merge_size = self.visual.spatial_merge_size
-        sizes = (
-            torch.tensor(grid_thw_list, dtype=torch.long).prod(-1)
-            // (merge_size * merge_size)
-        ).tolist()
-
+        sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist()
         return video_embeds.split(sizes)

     def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py
index d611580c7182..2d8f431bb8fa 100644
--- a/vllm/model_executor/models/qwen3_vl.py
+++ b/vllm/model_executor/models/qwen3_vl.py
@@ -414,16 +414,10 @@ class Qwen3_VisionTransformer(nn.Module):
     def device(self) -> torch.device:
         return self.patch_embed.proj.weight.device

-    def rot_pos_emb(self, grid_thw):
+    def rot_pos_emb(self, grid_thw: list[list[int]]):
         pos_ids = []
-        # Support both Tensor and list inputs for DP path
-        if isinstance(grid_thw, list):
-            grid_list = grid_thw
-            max_grid_size = max(max(h, w) for _, h, w in grid_list)
-        else:
-            grid_list = grid_thw.tolist()
-            max_grid_size = int(grid_thw[:, 1:].max().item())
-        for t, h, w in grid_list:
+        max_grid_size = max(max(h, w) for _, h, w in grid_thw)
+        for t, h, w in grid_thw:
             hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
             hpos_ids = hpos_ids.reshape(
                 h // self.spatial_merge_size,
@@ -527,24 +521,25 @@ class Qwen3_VisionTransformer(nn.Module):
     def forward(
         self,
         x: torch.Tensor,
-        grid_thw: list[list[int]],
+        grid_thw: torch.Tensor | list[list[int]],
     ) -> torch.Tensor:
         hidden_states = x.to(device=self.device, dtype=self.dtype, non_blocking=True)
         hidden_states = self.patch_embed(hidden_states)

-        pos_embeds = self.fast_pos_embed_interpolate(grid_thw)
+        if isinstance(grid_thw, list):
+            grid_thw_list = grid_thw
+            grid_thw = torch.tensor(grid_thw, dtype=torch.int32)
+        else:
+            grid_thw_list = grid_thw.tolist()
+
+        pos_embeds = self.fast_pos_embed_interpolate(grid_thw_list)
         hidden_states = hidden_states + pos_embeds
-        rotary_pos_emb = self.rot_pos_emb(grid_thw)
+        rotary_pos_emb = self.rot_pos_emb(grid_thw_list)
         rotary_pos_emb = rotary_pos_emb.to(hidden_states.device, non_blocking=True)

-        grid_thw_tensor = torch.tensor(grid_thw, dtype=torch.int32)
-
         cu_seqlens = torch.repeat_interleave(
-            grid_thw_tensor[:, 1] * grid_thw_tensor[:, 2], grid_thw_tensor[:, 0]
-        ).cumsum(
-            dim=0,
-            dtype=grid_thw_tensor.dtype if torch.jit.is_tracing() else torch.int32,
-        )
+            grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
+        ).cumsum(dim=0, dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32)
         cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens])

         hidden_states = hidden_states.unsqueeze(1)
@@ -1177,6 +1172,7 @@ class Qwen3VLForConditionalGeneration(
     nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE
 ):
     merge_by_field_config = True
+    multimodal_cpu_fields = {"image_grid_thw", "video_grid_thw"}

     packed_modules_mapping = {
         "qkv_proj": [
@@ -1356,7 +1352,6 @@ class Qwen3VLForConditionalGeneration(
     ) -> tuple[torch.Tensor, ...]:
         grid_thw = image_input["image_grid_thw"]
         assert grid_thw.ndim == 2
-        grid_thw_list = grid_thw.tolist()

         if image_input["type"] == "image_embeds":
             image_embeds = image_input["image_embeds"].type(self.visual.dtype)
@@ -1364,18 +1359,14 @@ class Qwen3VLForConditionalGeneration(
             pixel_values = image_input["pixel_values"].type(self.visual.dtype)
             if self.use_data_parallel:
                 return run_dp_sharded_mrope_vision_model(
-                    self.visual, pixel_values, grid_thw_list, rope_type="rope_3d"
+                    self.visual, pixel_values, grid_thw.tolist(), rope_type="rope_3d"
                 )
             else:
-                image_embeds = self.visual(pixel_values, grid_thw=grid_thw_list)
+                image_embeds = self.visual(pixel_values, grid_thw=grid_thw)

         # Split concatenated embeddings for each image item.
-        # Using prod on grid_thw_list instead of grid_thw.prod avoids CUDA sync
         merge_size = self.visual.spatial_merge_size
-        sizes = (
-            torch.tensor(grid_thw_list, dtype=torch.long).prod(-1)
-            // (merge_size * merge_size)
-        ).tolist()
+        sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist()
         return image_embeds.split(sizes)

     def _process_video_input(
@@ -1383,7 +1374,6 @@ class Qwen3VLForConditionalGeneration(
     ) -> tuple[torch.Tensor, ...]:
         grid_thw = video_input["video_grid_thw"]
         assert grid_thw.ndim == 2
-        grid_thw_list = grid_thw.tolist()

         if video_input["type"] == "video_embeds":
             video_embeds = video_input["video_embeds"].type(self.visual.dtype)
@@ -1392,19 +1382,16 @@ class Qwen3VLForConditionalGeneration(
                 self.visual.dtype
             )
             if self.use_data_parallel:
+                grid_thw_list = grid_thw.tolist()
                 return run_dp_sharded_mrope_vision_model(
                     self.visual, pixel_values_videos, grid_thw_list, rope_type="rope_3d"
                 )
             else:
-                video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw_list)
+                video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw)

         # Split concatenated embeddings for each video item.
-        # Using prod on grid_thw_list instead of grid_thw.prod avoids CUDA sync
         merge_size = self.visual.spatial_merge_size
-        sizes = (
-            torch.tensor(grid_thw_list, dtype=torch.long).prod(-1)
-            // (merge_size * merge_size)
-        ).tolist()
+        sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist()
         return video_embeds.split(sizes)

     def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
diff --git a/vllm/multimodal/utils.py b/vllm/multimodal/utils.py
index aa61bcc11f9f..3f55c46ca334 100644
--- a/vllm/multimodal/utils.py
+++ b/vllm/multimodal/utils.py
@@ -3,7 +3,7 @@

 import asyncio
 import atexit
-from collections.abc import Iterable
+from collections.abc import Iterable, Set
 from concurrent.futures import ThreadPoolExecutor
 from itertools import groupby
 from pathlib import Path
@@ -402,6 +402,7 @@ def group_mm_kwargs_by_modality(
     device: torch.types.Device = None,
     pin_memory: bool = False,
     merge_by_field_config: bool | None = None,
+    multimodal_cpu_fields: Set[str] = frozenset(),
 ) -> Iterable[tuple[str, int, BatchedTensorInputs]]:
     """Group consecutive `MultiModalKwargsItem`s from `mm_kwargs` with the same
     modality together into the same `MultiModalKwargs` instance.
@@ -443,12 +444,17 @@ def group_mm_kwargs_by_modality(
             )

             if device is not None:
-                mm_kwargs_group = json_map_leaves(
-                    lambda x: x.to(device=device, non_blocking=True)
-                    if isinstance(x, torch.Tensor)
-                    else x,
-                    mm_kwargs_group,
-                )
+                mm_kwargs_group = {
+                    k: json_map_leaves(
+                        lambda x: x.to(device=device, non_blocking=True)
+                        if isinstance(x, torch.Tensor)
+                        else x,
+                        v,
+                    )
+                    if k not in multimodal_cpu_fields
+                    else v
+                    for k, v in mm_kwargs_group.items()
+                }
         else:
             mm_kwargs_group = MultiModalKwargs.as_kwargs(
                 MultiModalKwargs.batch(
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 91015ad4379c..91c8efc17feb 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -938,6 +938,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             device=self.device,
             pin_memory=self.pin_memory,
             merge_by_field_config=model.merge_by_field_config,
+            multimodal_cpu_fields=model.multimodal_cpu_fields,
         ):
             mm_kwargs_combined.update(mm_kwargs_group)

@@ -1768,6 +1769,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             device=self.device,
             pin_memory=self.pin_memory,
             merge_by_field_config=model.merge_by_field_config,
+            multimodal_cpu_fields=model.multimodal_cpu_fields,
         ):
             curr_group_outputs = []

@@ -1794,6 +1796,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                     device=self.device,
                     pin_memory=self.pin_memory,
                     merge_by_field_config=model.merge_by_field_config,
+                    multimodal_cpu_fields=model.multimodal_cpu_fields,
                 )
             )

@@ -1936,6 +1939,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             device=self.device,
             pin_memory=self.pin_memory,
             merge_by_field_config=model.merge_by_field_config,
+            multimodal_cpu_fields=model.multimodal_cpu_fields,
         ):
             # Add the grouped features to encoder_features dict
             # This allows the model to receive them as kwargs (e.g.,
@@ -3292,6 +3296,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 device=self.device,
                 pin_memory=self.pin_memory,
                 merge_by_field_config=model.merge_by_field_config,
+                multimodal_cpu_fields=model.multimodal_cpu_fields,
             )
         )

diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py
index 0e34504a5e26..26816ce0f209 100644
--- a/vllm/v1/worker/tpu_model_runner.py
+++ b/vllm/v1/worker/tpu_model_runner.py
@@ -952,6 +952,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             device=self.device,
             pin_memory=self.pin_memory,
             merge_by_field_config=model.merge_by_field_config,
+            multimodal_cpu_fields=model.multimodal_cpu_fields,
         ):
             # Run the encoder.
             # `curr_group_outputs` is either of the following:
@@ -2037,6 +2038,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 device=self.device,
                 pin_memory=self.pin_memory,
                 merge_by_field_config=model.merge_by_field_config,
+                multimodal_cpu_fields=model.multimodal_cpu_fields,
             )
         )
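Note (not part of the diff above): a minimal standalone sketch of the keep-on-CPU filtering that `multimodal_cpu_fields` adds to `group_mm_kwargs_by_modality`. The helper name `move_group_to_device` and the toy tensors are illustrative assumptions, not vLLM APIs; only the filtering logic mirrors the patch.

import torch


def move_group_to_device(
    mm_kwargs_group: dict[str, torch.Tensor],
    device: torch.device,
    multimodal_cpu_fields: frozenset[str] = frozenset(),
) -> dict[str, torch.Tensor]:
    # Hypothetical helper: fields named in `multimodal_cpu_fields` stay on the
    # host; every other tensor is copied to `device` (non_blocking, as in the
    # patched group_mm_kwargs_by_modality).
    return {
        k: v if k in multimodal_cpu_fields else v.to(device=device, non_blocking=True)
        for k, v in mm_kwargs_group.items()
    }


if __name__ == "__main__":
    group = {
        "pixel_values": torch.randn(4, 1176),
        "image_grid_thw": torch.tensor([[1, 2, 2]], dtype=torch.int64),
    }
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    moved = move_group_to_device(group, device, frozenset({"image_grid_thw"}))
    # The grid sizes remain on the CPU, so the later `.prod(-1)` / `.tolist()`
    # calls in the Qwen2/2.5/3-VL models no longer force a device sync.
    assert moved["image_grid_thw"].device.type == "cpu"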