[Core][MM] Add mechanism to configure multimodal fields which should stay on CPU (#28168)
Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com>
Commit: e0919f331d
Parent: 8e19d470af
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

-from collections.abc import Callable, Iterable, Mapping, MutableSequence
+from collections.abc import Callable, Iterable, Mapping, MutableSequence, Set
 from typing import (
     TYPE_CHECKING,
     ClassVar,

@@ -81,6 +81,11 @@ class SupportsMultiModal(Protocol):
     `vllm.multimodal.utils.group_mm_kwargs_by_modality` to use.
     """

+    multimodal_cpu_fields: ClassVar[Set[str]] = frozenset()
+    """
+    A set indicating CPU-only multimodal fields.
+    """
+
     @classmethod
     def get_placeholder_str(cls, modality: str, i: int) -> str | None:
         """
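The new attribute defaults to an empty frozenset, so existing models are unaffected; a model opts in by naming the multimodal kwargs that should skip the host-to-device copy. A minimal sketch of such an opt-in, mirroring the Qwen models changed below (the class name MyVLModel is hypothetical and the snippet omits the rest of the model implementation):

import torch.nn as nn

from vllm.model_executor.models.interfaces import SupportsMultiModal


class MyVLModel(nn.Module, SupportsMultiModal):
    """Hypothetical model that keeps its grid metadata on the CPU."""

    merge_by_field_config = True
    # group_mm_kwargs_by_modality leaves these kwargs on the host; every other
    # multimodal tensor is still moved to the target device.
    multimodal_cpu_fields = {"image_grid_thw", "video_grid_thw"}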
@@ -1090,6 +1090,7 @@ class Qwen2_5_VLForConditionalGeneration(
     SupportsMRoPE,
 ):
     merge_by_field_config = True
+    multimodal_cpu_fields = {"image_grid_thw", "video_grid_thw"}

     packed_modules_mapping = {
         "qkv_proj": ["q_proj", "k_proj", "v_proj"],

@@ -1364,13 +1365,8 @@ class Qwen2_5_VLForConditionalGeneration(
             image_embeds = self.visual(pixel_values, grid_thw=grid_thw_list)

         # Split concatenated embeddings for each image item.
-        # Using prod on grid_thw_list instead of grid_thw.prod avoids CUDA sync
         merge_size = self.visual.spatial_merge_size
-        sizes = (
-            torch.tensor(grid_thw_list, dtype=torch.long).prod(-1)
-            // (merge_size * merge_size)
-        ).tolist()
-
+        sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist()
         return image_embeds.split(sizes)

     def _postprocess_image_embeds_evs(

@@ -1430,12 +1426,7 @@ class Qwen2_5_VLForConditionalGeneration(

         # Split concatenated embeddings for each video item.
         merge_size = self.visual.spatial_merge_size
-        # Using prod on grid_thw_list instead of grid_thw.prod avoids CUDA sync
-        sizes = (
-            torch.tensor(grid_thw_list, dtype=torch.long).prod(-1)
-            // (merge_size * merge_size)
-        ).tolist()
-
+        sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist()
         return video_embeds.split(sizes)

     def _postprocess_video_embeds_evs(
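These two hunks drop the workaround rather than the optimization: since image_grid_thw and video_grid_thw are now listed in multimodal_cpu_fields, grid_thw arrives here as a CPU tensor, so grid_thw.prod(-1) runs on the host and .tolist() no longer forces the CUDA synchronization that the deleted comment was guarding against. A small standalone sketch of the size computation under that assumption (the grid values are illustrative):

import torch

# Per-item (t, h, w) patch grids, kept on the CPU via multimodal_cpu_fields.
grid_thw = torch.tensor([[1, 32, 32], [1, 16, 48]], dtype=torch.long)
merge_size = 2  # spatial merge size of the vision tower

# Pure host-side arithmetic: no device-to-host copy, hence no CUDA sync.
sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist()
assert sizes == [256, 192]  # number of merged patches per image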
@@ -798,21 +798,27 @@ class Qwen2VisionTransformer(nn.Module):
     def forward(
         self,
         x: torch.Tensor,
-        grid_thw: list[list[int]],
+        grid_thw: torch.Tensor | list[list[int]],
     ) -> torch.Tensor:
         # patchify
         x = x.to(device=self.device, dtype=self.dtype)
         x = self.patch_embed(x)

+        if isinstance(grid_thw, list):
+            grid_thw_list = grid_thw
+            grid_thw = torch.tensor(grid_thw, dtype=torch.int32)
+        else:
+            grid_thw_list = grid_thw.tolist()
+
         # compute position embedding
-        rotary_pos_emb = self.rot_pos_emb(grid_thw)
+        rotary_pos_emb = self.rot_pos_emb(grid_thw_list)

         # compute cu_seqlens
-        grid_thw_ = torch.tensor(grid_thw, device=x.device, dtype=torch.long)
         cu_seqlens = torch.repeat_interleave(
-            grid_thw_[:, 1] * grid_thw_[:, 2], grid_thw_[:, 0]
+            grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
         ).cumsum(dim=0, dtype=torch.int32)
-        cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)
+        cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens])
+        cu_seqlens = cu_seqlens.to(self.device, non_blocking=True)

         # transformers
         x = x.unsqueeze(1)
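The vision tower now accepts the grid either as the CPU tensor produced by the new mechanism or as a plain nested list, which the data-parallel sharding path still passes; it normalizes to both forms up front (the list feeds rot_pos_emb, the CPU tensor feeds the cu_seqlens arithmetic) and only copies the final cu_seqlens to the device. The Qwen3 vision tower below gets the same treatment. A hedged sketch of that normalization idiom in isolation (the helper name _normalize_grid_thw is illustrative, not part of the diff):

import torch


def _normalize_grid_thw(
    grid_thw: torch.Tensor | list[list[int]],
) -> tuple[torch.Tensor, list[list[int]]]:
    """Return the grid both as a CPU tensor and as a plain Python list."""
    if isinstance(grid_thw, list):
        return torch.tensor(grid_thw, dtype=torch.int32), grid_thw
    return grid_thw, grid_thw.tolist()


grid_tensor, grid_list = _normalize_grid_thw([[1, 32, 32], [1, 16, 48]])
assert grid_list == grid_tensor.tolist()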
@@ -1211,6 +1217,7 @@ class Qwen2VLForConditionalGeneration(
     nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE
 ):
     merge_by_field_config = True
+    multimodal_cpu_fields = {"image_grid_thw", "video_grid_thw"}

     # To ensure correct weight loading and mapping.
     hf_to_vllm_mapper = WeightsMapper(

@@ -1458,7 +1465,6 @@ class Qwen2VLForConditionalGeneration(
     ) -> tuple[torch.Tensor, ...]:
         grid_thw = image_input["image_grid_thw"]
         assert grid_thw.ndim == 2
-        grid_thw_list = grid_thw.tolist()

         if image_input["type"] == "image_embeds":
             image_embeds = image_input["image_embeds"]

@@ -1467,18 +1473,14 @@ class Qwen2VLForConditionalGeneration(

             if self.use_data_parallel:
                 return run_dp_sharded_mrope_vision_model(
-                    self.visual, pixel_values, grid_thw_list, rope_type="rope_3d"
+                    self.visual, pixel_values, grid_thw.tolist(), rope_type="rope_3d"
                 )
             else:
-                image_embeds = self.visual(pixel_values, grid_thw=grid_thw_list)
+                image_embeds = self.visual(pixel_values, grid_thw=grid_thw)

         # Split concatenated embeddings for each image item.
         merge_size = self.visual.spatial_merge_size
-        sizes = (
-            torch.tensor(grid_thw_list, dtype=torch.long).prod(-1)
-            // (merge_size * merge_size)
-        ).tolist()
-
+        sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist()
         return image_embeds.split(sizes)

     def _process_video_input(

@@ -1486,26 +1488,22 @@ class Qwen2VLForConditionalGeneration(
     ) -> tuple[torch.Tensor, ...]:
         grid_thw = video_input["video_grid_thw"]
         assert grid_thw.ndim == 2
-        grid_thw_list = grid_thw.tolist()

         if video_input["type"] == "video_embeds":
             video_embeds = video_input["video_embeds"]
         else:
             pixel_values_videos = video_input["pixel_values_videos"]
             if self.use_data_parallel:
+                grid_thw_list = grid_thw.tolist()
                 return run_dp_sharded_mrope_vision_model(
                     self.visual, pixel_values_videos, grid_thw_list, rope_type="rope_3d"
                 )
             else:
-                video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw_list)
+                video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw)

         # Split concatenated embeddings for each video item.
         merge_size = self.visual.spatial_merge_size
-        sizes = (
-            torch.tensor(grid_thw_list, dtype=torch.long).prod(-1)
-            // (merge_size * merge_size)
-        ).tolist()
-
+        sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist()
         return video_embeds.split(sizes)

     def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
@@ -414,16 +414,10 @@ class Qwen3_VisionTransformer(nn.Module):
     def device(self) -> torch.device:
         return self.patch_embed.proj.weight.device

-    def rot_pos_emb(self, grid_thw):
+    def rot_pos_emb(self, grid_thw: list[list[int]]):
         pos_ids = []
-        # Support both Tensor and list inputs for DP path
-        if isinstance(grid_thw, list):
-            grid_list = grid_thw
-            max_grid_size = max(max(h, w) for _, h, w in grid_list)
-        else:
-            grid_list = grid_thw.tolist()
-            max_grid_size = int(grid_thw[:, 1:].max().item())
-        for t, h, w in grid_list:
+        max_grid_size = max(max(h, w) for _, h, w in grid_thw)
+        for t, h, w in grid_thw:
             hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
             hpos_ids = hpos_ids.reshape(
                 h // self.spatial_merge_size,

@@ -527,24 +521,25 @@ class Qwen3_VisionTransformer(nn.Module):
     def forward(
         self,
         x: torch.Tensor,
-        grid_thw: list[list[int]],
+        grid_thw: torch.Tensor | list[list[int]],
     ) -> torch.Tensor:
         hidden_states = x.to(device=self.device, dtype=self.dtype, non_blocking=True)
         hidden_states = self.patch_embed(hidden_states)

-        pos_embeds = self.fast_pos_embed_interpolate(grid_thw)
+        if isinstance(grid_thw, list):
+            grid_thw_list = grid_thw
+            grid_thw = torch.tensor(grid_thw, dtype=torch.int32)
+        else:
+            grid_thw_list = grid_thw.tolist()
+
+        pos_embeds = self.fast_pos_embed_interpolate(grid_thw_list)
         hidden_states = hidden_states + pos_embeds
-        rotary_pos_emb = self.rot_pos_emb(grid_thw)
+        rotary_pos_emb = self.rot_pos_emb(grid_thw_list)
         rotary_pos_emb = rotary_pos_emb.to(hidden_states.device, non_blocking=True)

-        grid_thw_tensor = torch.tensor(grid_thw, dtype=torch.int32)
-
         cu_seqlens = torch.repeat_interleave(
-            grid_thw_tensor[:, 1] * grid_thw_tensor[:, 2], grid_thw_tensor[:, 0]
-        ).cumsum(
-            dim=0,
-            dtype=grid_thw_tensor.dtype if torch.jit.is_tracing() else torch.int32,
-        )
+            grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
+        ).cumsum(dim=0, dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32)
         cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens])

         hidden_states = hidden_states.unsqueeze(1)

@@ -1177,6 +1172,7 @@ class Qwen3VLForConditionalGeneration(
     nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE
 ):
     merge_by_field_config = True
+    multimodal_cpu_fields = {"image_grid_thw", "video_grid_thw"}

     packed_modules_mapping = {
         "qkv_proj": [

@@ -1356,7 +1352,6 @@ class Qwen3VLForConditionalGeneration(
     ) -> tuple[torch.Tensor, ...]:
         grid_thw = image_input["image_grid_thw"]
         assert grid_thw.ndim == 2
-        grid_thw_list = grid_thw.tolist()

         if image_input["type"] == "image_embeds":
             image_embeds = image_input["image_embeds"].type(self.visual.dtype)

@@ -1364,18 +1359,14 @@ class Qwen3VLForConditionalGeneration(
             pixel_values = image_input["pixel_values"].type(self.visual.dtype)
             if self.use_data_parallel:
                 return run_dp_sharded_mrope_vision_model(
-                    self.visual, pixel_values, grid_thw_list, rope_type="rope_3d"
+                    self.visual, pixel_values, grid_thw.tolist(), rope_type="rope_3d"
                 )
             else:
-                image_embeds = self.visual(pixel_values, grid_thw=grid_thw_list)
+                image_embeds = self.visual(pixel_values, grid_thw=grid_thw)

         # Split concatenated embeddings for each image item.
-        # Using prod on grid_thw_list instead of grid_thw.prod avoids CUDA sync
         merge_size = self.visual.spatial_merge_size
-        sizes = (
-            torch.tensor(grid_thw_list, dtype=torch.long).prod(-1)
-            // (merge_size * merge_size)
-        ).tolist()
+        sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist()
         return image_embeds.split(sizes)

     def _process_video_input(

@@ -1383,7 +1374,6 @@ class Qwen3VLForConditionalGeneration(
     ) -> tuple[torch.Tensor, ...]:
         grid_thw = video_input["video_grid_thw"]
         assert grid_thw.ndim == 2
-        grid_thw_list = grid_thw.tolist()

         if video_input["type"] == "video_embeds":
             video_embeds = video_input["video_embeds"].type(self.visual.dtype)

@@ -1392,19 +1382,16 @@ class Qwen3VLForConditionalGeneration(
                 self.visual.dtype
             )
             if self.use_data_parallel:
+                grid_thw_list = grid_thw.tolist()
                 return run_dp_sharded_mrope_vision_model(
                     self.visual, pixel_values_videos, grid_thw_list, rope_type="rope_3d"
                 )
             else:
-                video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw_list)
+                video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw)

         # Split concatenated embeddings for each video item.
-        # Using prod on grid_thw_list instead of grid_thw.prod avoids CUDA sync
         merge_size = self.visual.spatial_merge_size
-        sizes = (
-            torch.tensor(grid_thw_list, dtype=torch.long).prod(-1)
-            // (merge_size * merge_size)
-        ).tolist()
+        sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist()
         return video_embeds.split(sizes)

     def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
@@ -3,7 +3,7 @@

 import asyncio
 import atexit
-from collections.abc import Iterable
+from collections.abc import Iterable, Set
 from concurrent.futures import ThreadPoolExecutor
 from itertools import groupby
 from pathlib import Path

@@ -402,6 +402,7 @@ def group_mm_kwargs_by_modality(
     device: torch.types.Device = None,
     pin_memory: bool = False,
     merge_by_field_config: bool | None = None,
+    multimodal_cpu_fields: Set[str] = frozenset(),
 ) -> Iterable[tuple[str, int, BatchedTensorInputs]]:
     """Group consecutive `MultiModalKwargsItem`s from `mm_kwargs` with the same
     modality together into the same `MultiModalKwargs` instance.

@@ -443,12 +444,17 @@ def group_mm_kwargs_by_modality(
             )

             if device is not None:
-                mm_kwargs_group = json_map_leaves(
-                    lambda x: x.to(device=device, non_blocking=True)
-                    if isinstance(x, torch.Tensor)
-                    else x,
-                    mm_kwargs_group,
-                )
+                mm_kwargs_group = {
+                    k: json_map_leaves(
+                        lambda x: x.to(device=device, non_blocking=True)
+                        if isinstance(x, torch.Tensor)
+                        else x,
+                        v,
+                    )
+                    if k not in multimodal_cpu_fields
+                    else v
+                    for k, v in mm_kwargs_group.items()
+                }
         else:
             mm_kwargs_group = MultiModalKwargs.as_kwargs(
                 MultiModalKwargs.batch(
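Because the merged kwargs group is a plain dict keyed by field name at this point, the device transfer can now be decided per key: fields named in multimodal_cpu_fields are passed through untouched, everything else is still copied with non_blocking=True. A reduced, standalone sketch of the same pattern on a flat dict of tensors (the helper name and example keys are illustrative, not vLLM API):

import torch
from collections.abc import Set


def move_to_device_except(
    kwargs: dict[str, torch.Tensor],
    device: torch.device,
    cpu_fields: Set[str] = frozenset(),
) -> dict[str, torch.Tensor]:
    """Copy every tensor to `device`, except the keys named in `cpu_fields`."""
    return {
        k: v if k in cpu_fields else v.to(device=device, non_blocking=True)
        for k, v in kwargs.items()
    }


# Example: pixel values go to the accelerator, grid metadata stays on the host.
batch = {
    "pixel_values": torch.randn(2, 3, 336, 336),
    "image_grid_thw": torch.tensor([[1, 24, 24], [1, 24, 24]]),
}
if torch.cuda.is_available():
    batch = move_to_device_except(
        batch, torch.device("cuda"), cpu_fields={"image_grid_thw"}
    )
    assert batch["image_grid_thw"].device.type == "cpu"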
@@ -938,6 +938,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             device=self.device,
             pin_memory=self.pin_memory,
             merge_by_field_config=model.merge_by_field_config,
+            multimodal_cpu_fields=model.multimodal_cpu_fields,
         ):
             mm_kwargs_combined.update(mm_kwargs_group)

@@ -1768,6 +1769,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             device=self.device,
             pin_memory=self.pin_memory,
             merge_by_field_config=model.merge_by_field_config,
+            multimodal_cpu_fields=model.multimodal_cpu_fields,
         ):
             curr_group_outputs = []

@@ -1794,6 +1796,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 device=self.device,
                 pin_memory=self.pin_memory,
                 merge_by_field_config=model.merge_by_field_config,
+                multimodal_cpu_fields=model.multimodal_cpu_fields,
             )
         )

@@ -1936,6 +1939,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             device=self.device,
             pin_memory=self.pin_memory,
             merge_by_field_config=model.merge_by_field_config,
+            multimodal_cpu_fields=model.multimodal_cpu_fields,
         ):
             # Add the grouped features to encoder_features dict
             # This allows the model to receive them as kwargs (e.g.,

@@ -3292,6 +3296,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 device=self.device,
                 pin_memory=self.pin_memory,
                 merge_by_field_config=model.merge_by_field_config,
+                multimodal_cpu_fields=model.multimodal_cpu_fields,
             )
         )
@@ -952,6 +952,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             device=self.device,
             pin_memory=self.pin_memory,
             merge_by_field_config=model.merge_by_field_config,
+            multimodal_cpu_fields=model.multimodal_cpu_fields,
         ):
             # Run the encoder.
             # `curr_group_outputs` is either of the following:

@@ -2037,6 +2038,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 device=self.device,
                 pin_memory=self.pin_memory,
                 merge_by_field_config=model.merge_by_field_config,
+                multimodal_cpu_fields=model.multimodal_cpu_fields,
             )
         )