[Core] Batch multi-modal input using pinned memory (#19169)

Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com>
This commit is contained in:
Lukas Geiger 2025-06-10 07:44:59 +02:00 committed by GitHub
parent 1efef71645
commit 319cb1e351
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 18 additions and 7 deletions

View File

@ -680,7 +680,8 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
return self._items_by_modality.keys()
@staticmethod
def _try_stack(nested_tensors: NestedTensors) -> NestedTensors:
def _try_stack(nested_tensors: NestedTensors,
pin_memory: bool = False) -> NestedTensors:
"""
Stack the inner dimensions that have the same shape in
a nested list of tensors.
@ -697,7 +698,9 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
if isinstance(nested_tensors, (int, float)):
return torch.tensor(nested_tensors)
stacked = [MultiModalKwargs._try_stack(t) for t in nested_tensors]
stacked = [
MultiModalKwargs._try_stack(t, pin_memory) for t in nested_tensors
]
if not is_list_of(stacked, torch.Tensor, check="all"):
# Only tensors (not lists) can be stacked.
return stacked
@ -713,10 +716,16 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
# The tensors have incompatible shapes and can't be stacked.
return tensors_
return torch.stack(tensors_)
outputs = torch.empty(len(tensors_),
*tensors_[0].shape,
dtype=tensors_[0].dtype,
device=tensors_[0].device,
pin_memory=pin_memory)
return torch.stack(tensors_, out=outputs)
@staticmethod
def batch(inputs_list: list["MultiModalKwargs"]) -> BatchedTensorInputs:
def batch(inputs_list: list["MultiModalKwargs"],
pin_memory: bool = False) -> BatchedTensorInputs:
"""
Batch multiple inputs together into a dictionary.
@ -738,7 +747,7 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
item_lists[k].append(v)
return {
k: MultiModalKwargs._try_stack(item_list)
k: MultiModalKwargs._try_stack(item_list, pin_memory)
for k, item_list in item_lists.items()
}

View File

@ -962,7 +962,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
encoder_outputs = []
for grouped_mm_inputs in grouped_mm_inputs_list:
batched_mm_inputs = MultiModalKwargs.batch(grouped_mm_inputs)
batched_mm_inputs = MultiModalKwargs.batch(
grouped_mm_inputs, pin_memory=self.pin_memory)
batched_mm_inputs = MultiModalKwargs.as_kwargs(
batched_mm_inputs,
device=self.device,
@ -1989,7 +1990,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
).multi_modal_data
batched_dummy_mm_inputs = MultiModalKwargs.batch(
[dummy_mm_kwargs] * max_num_mm_items)
[dummy_mm_kwargs] * max_num_mm_items,
pin_memory=self.pin_memory)
batched_dummy_mm_inputs = MultiModalKwargs.as_kwargs(
batched_dummy_mm_inputs,
device=self.device,