mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-15 01:37:03 +08:00
[Core] Batch multi modal input using pinned memory (#19169)
Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com>
This commit is contained in:
parent
1efef71645
commit
319cb1e351
@ -680,7 +680,8 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
|
||||
return self._items_by_modality.keys()
|
||||
|
||||
@staticmethod
|
||||
def _try_stack(nested_tensors: NestedTensors) -> NestedTensors:
|
||||
def _try_stack(nested_tensors: NestedTensors,
|
||||
pin_memory: bool = False) -> NestedTensors:
|
||||
"""
|
||||
Stack the inner dimensions that have the same shape in
|
||||
a nested list of tensors.
|
||||
@ -697,7 +698,9 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
|
||||
if isinstance(nested_tensors, (int, float)):
|
||||
return torch.tensor(nested_tensors)
|
||||
|
||||
stacked = [MultiModalKwargs._try_stack(t) for t in nested_tensors]
|
||||
stacked = [
|
||||
MultiModalKwargs._try_stack(t, pin_memory) for t in nested_tensors
|
||||
]
|
||||
if not is_list_of(stacked, torch.Tensor, check="all"):
|
||||
# Only tensors (not lists) can be stacked.
|
||||
return stacked
|
||||
@ -713,10 +716,16 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
|
||||
# The tensors have incompatible shapes and can't be stacked.
|
||||
return tensors_
|
||||
|
||||
return torch.stack(tensors_)
|
||||
outputs = torch.empty(len(tensors_),
|
||||
*tensors_[0].shape,
|
||||
dtype=tensors_[0].dtype,
|
||||
device=tensors_[0].device,
|
||||
pin_memory=pin_memory)
|
||||
return torch.stack(tensors_, out=outputs)
|
||||
|
||||
@staticmethod
|
||||
def batch(inputs_list: list["MultiModalKwargs"]) -> BatchedTensorInputs:
|
||||
def batch(inputs_list: list["MultiModalKwargs"],
|
||||
pin_memory: bool = False) -> BatchedTensorInputs:
|
||||
"""
|
||||
Batch multiple inputs together into a dictionary.
|
||||
|
||||
@ -738,7 +747,7 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
|
||||
item_lists[k].append(v)
|
||||
|
||||
return {
|
||||
k: MultiModalKwargs._try_stack(item_list)
|
||||
k: MultiModalKwargs._try_stack(item_list, pin_memory)
|
||||
for k, item_list in item_lists.items()
|
||||
}
|
||||
|
||||
|
||||
@ -962,7 +962,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
|
||||
encoder_outputs = []
|
||||
for grouped_mm_inputs in grouped_mm_inputs_list:
|
||||
batched_mm_inputs = MultiModalKwargs.batch(grouped_mm_inputs)
|
||||
batched_mm_inputs = MultiModalKwargs.batch(
|
||||
grouped_mm_inputs, pin_memory=self.pin_memory)
|
||||
batched_mm_inputs = MultiModalKwargs.as_kwargs(
|
||||
batched_mm_inputs,
|
||||
device=self.device,
|
||||
@ -1989,7 +1990,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
).multi_modal_data
|
||||
|
||||
batched_dummy_mm_inputs = MultiModalKwargs.batch(
|
||||
[dummy_mm_kwargs] * max_num_mm_items)
|
||||
[dummy_mm_kwargs] * max_num_mm_items,
|
||||
pin_memory=self.pin_memory)
|
||||
batched_dummy_mm_inputs = MultiModalKwargs.as_kwargs(
|
||||
batched_dummy_mm_inputs,
|
||||
device=self.device,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user