Avoid unnecessary multi-modal input data copy when len(batch) == 1 (#12722)

Signed-off-by: imkero <kerorek@outlook.com>
Author: Kero Liang, 2025-02-04 21:03:19 +08:00 (committed by GitHub)
parent 6469038b14
commit 62467a834a
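For context, a minimal PyTorch sketch (my own illustration, not part of this diff) of the trick the commit applies: for a single-element batch, `unsqueeze(0)` yields the same result as `torch.stack` but as a zero-copy view when the input is contiguous:

    import torch

    t = torch.randn(3, 4)                 # one contiguous multi-modal tensor

    stacked = torch.stack([t])            # always copies into a new buffer
    viewed = t.unsqueeze(0).contiguous()  # contiguous() is a no-op on an already-contiguous view

    assert torch.equal(stacked, viewed)        # same values, shape (1, 3, 4)
    assert viewed.data_ptr() == t.data_ptr()   # the view shares storage with `t`
    assert stacked.data_ptr() != t.data_ptr()  # stack allocated fresh memory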

@@ -212,6 +212,11 @@ class MultiModalBatchedField(BaseMultiModalField):
 
     def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
         if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
+            if len(batch) == 1:
+                # An optimization when `batch` contains only one tensor:
+                # - produces exactly the same result as `torch.stack(batch)`
+                # - achieves zero-copy if the tensor is contiguous
+                return batch[0].unsqueeze(0).contiguous()
             first_shape = batch[0].shape
             if all(elem.shape == first_shape for elem in batch):
                 return torch.stack(batch)
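A quick sanity check of this fast path (my own sketch, not from the commit): for a single-element batch the result matches `torch.stack` exactly, and the output reuses the input's storage:

    import torch

    batch = [torch.randn(2, 8)]  # single-element batch of one contiguous tensor

    fast = batch[0].unsqueeze(0).contiguous()  # new fast path
    slow = torch.stack(batch)                  # original path

    assert torch.equal(fast, slow)                 # identical result, shape (1, 2, 8)
    assert fast.data_ptr() == batch[0].data_ptr()  # fast path reuses the input buffer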
@@ -235,6 +240,11 @@ class MultiModalFlatField(BaseMultiModalField):
 
     def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
         if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
+            if len(batch) == 1:
+                # An optimization when `batch` contains only one tensor:
+                # - produces exactly the same result as `torch.concat(batch)`
+                # - achieves zero-copy if the tensor is contiguous
+                return batch[0].contiguous()
             first_shape = batch[0].shape
             if all(elem.shape[1:] == first_shape[1:] for elem in batch):
                 return torch.concat(batch)
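The flat-field variant can be checked the same way; in this sketch (again mine, assuming plain PyTorch semantics), concatenating a single tensor is the identity, so `contiguous()` alone suffices and avoids the copy:

    import torch

    batch = [torch.randn(5, 16)]  # single-element batch

    fast = batch[0].contiguous()  # new fast path: contiguous() is a no-op here
    slow = torch.concat(batch)    # original path

    assert torch.equal(fast, slow)                 # same values, same shape (5, 16)
    assert fast.data_ptr() == batch[0].data_ptr()  # no copy for a contiguous input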
@@ -407,6 +417,12 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
             return stacked
 
         tensors_ = cast(list[torch.Tensor], stacked)
+        if len(tensors_) == 1:
+            # An optimization when `tensors_` contains only one tensor:
+            # - produces exactly the same result as `torch.stack(tensors_)`
+            # - achieves zero-copy if the tensor is contiguous
+            return tensors_[0].unsqueeze(0).contiguous()
+
         if any(t.shape != tensors_[0].shape for t in tensors_):
             # The tensors have incompatible shapes and can't be stacked.
             return tensors_
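As for the fallback branch, a short illustration (my own, not part of the commit) of why mismatched shapes force returning the raw list: `torch.stack` rejects them outright:

    import torch

    tensors_ = [torch.randn(2, 4), torch.randn(3, 4)]  # incompatible shapes

    try:
        result = torch.stack(tensors_)  # raises: stack expects each tensor to be equal size
    except RuntimeError:
        result = tensors_               # hence the code returns the list unstacked

    assert isinstance(result, list) and len(result) == 2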