From 62467a834a94566e2d81a276817a20174b474151 Mon Sep 17 00:00:00 2001
From: Kero Liang
Date: Tue, 4 Feb 2025 21:03:19 +0800
Subject: [PATCH] Avoid unnecessary multi-modal input data copy when
 len(batch) == 1 (#12722)

Signed-off-by: imkero
---
 vllm/multimodal/inputs.py | 16 ++++++++++++++++
 1 file changed, 16 insertions(+)

diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py
index fe24c7282f3cf..8e4af7f88f911 100644
--- a/vllm/multimodal/inputs.py
+++ b/vllm/multimodal/inputs.py
@@ -212,6 +212,11 @@ class MultiModalBatchedField(BaseMultiModalField):
 
     def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
         if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
+            if len(batch) == 1:
+                # An optimization when `batch` contains only one tensor:
+                # - produces exactly the same result as `torch.stack(batch)`
+                # - achieves zero-copy if the tensor is contiguous
+                return batch[0].unsqueeze(0).contiguous()
             first_shape = batch[0].shape
             if all(elem.shape == first_shape for elem in batch):
                 return torch.stack(batch)
@@ -235,6 +240,11 @@ class MultiModalFlatField(BaseMultiModalField):
 
     def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
         if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
+            if len(batch) == 1:
+                # An optimization when `batch` contains only one tensor:
+                # - produces exactly the same result as `torch.concat(batch)`
+                # - achieves zero-copy if the tensor is contiguous
+                return batch[0].contiguous()
             first_shape = batch[0].shape
             if all(elem.shape[1:] == first_shape[1:] for elem in batch):
                 return torch.concat(batch)
@@ -407,6 +417,12 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
             return stacked
 
         tensors_ = cast(list[torch.Tensor], stacked)
+        if len(tensors_) == 1:
+            # An optimization when `tensors_` contains only one tensor:
+            # - produces exactly the same result as `torch.stack(tensors_)`
+            # - achieves zero-copy if the tensor is contiguous
+            return tensors_[0].unsqueeze(0).contiguous()
+
         if any(t.shape != tensors_[0].shape for t in tensors_):
             # The tensors have incompatible shapes and can't be stacked.
             return tensors_
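
Note (editor's addition, not part of the patch): a minimal standalone sketch of why the
single-element fast paths above avoid a copy while the generic torch.stack / torch.concat
paths do not. The tensor shape and values below are illustrative only.

import torch

t = torch.arange(6, dtype=torch.float32).reshape(2, 3)  # a contiguous tensor

# Fast path taken by MultiModalBatchedField._reduce_data when len(batch) == 1:
fast = t.unsqueeze(0).contiguous()
ref = torch.stack([t])  # generic path

assert torch.equal(fast, ref)            # identical result, shape (1, 2, 3)
assert fast.data_ptr() == t.data_ptr()   # unsqueeze is a view; .contiguous()
                                         # is a no-op here, so no copy is made
assert ref.data_ptr() != t.data_ptr()    # torch.stack allocated fresh memory

# Fast path taken by MultiModalFlatField._reduce_data when len(batch) == 1:
flat_fast = t.contiguous()               # returns `t` itself when contiguous
assert torch.equal(flat_fast, torch.concat([t]))
assert flat_fast.data_ptr() == t.data_ptr()

When the input tensor is not contiguous, .contiguous() still copies, so the fast paths
remain result-equivalent to the generic paths; the zero-copy benefit simply applies to
the common contiguous case, as the patch comments state.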