Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2026-01-01 06:08:43 +08:00)
Avoid unnecessary multi-modal input data copy when len(batch) == 1 (#12722)
Signed-off-by: imkero <kerorek@outlook.com>
This commit is contained in:
parent 6469038b14
commit 62467a834a
@@ -212,6 +212,11 @@ class MultiModalBatchedField(BaseMultiModalField):
 
     def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
         if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
+            if len(batch) == 1:
+                # An optimization when `batch` contains only one tensor:
+                # - produces exactly the same result as `torch.stack(batch)`
+                # - achieves zero-copy if the tensor is contiguous
+                return batch[0].unsqueeze(0).contiguous()
             first_shape = batch[0].shape
             if all(elem.shape == first_shape for elem in batch):
                 return torch.stack(batch)
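To see why this fast path is safe, here is a small standalone check (a sketch, not part of the commit; the `batch` value is a made-up example): for a one-element list, `unsqueeze(0)` on a contiguous tensor produces the same result as `torch.stack` but shares the input's storage instead of copying it.

import torch

batch = [torch.randn(3, 4)]  # hypothetical single-item batch of a stacked field

stacked = torch.stack(batch)               # always allocates and copies
fast = batch[0].unsqueeze(0).contiguous()  # view; contiguous() is a no-op here

assert torch.equal(stacked, fast)                 # identical values and shape
assert fast.data_ptr() == batch[0].data_ptr()     # zero-copy: storage is shared
assert stacked.data_ptr() != batch[0].data_ptr()  # torch.stack made a new buffer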
@@ -235,6 +240,11 @@ class MultiModalFlatField(BaseMultiModalField):
 
     def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
         if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
+            if len(batch) == 1:
+                # An optimization when `batch` contains only one tensor:
+                # - produces exactly the same result as `torch.concat(batch)`
+                # - achieves zero-copy if the tensor is contiguous
+                return batch[0].contiguous()
             first_shape = batch[0].shape
             if all(elem.shape[1:] == first_shape[1:] for elem in batch):
                 return torch.concat(batch)
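The flat-field variant works the same way, except the singleton element already is the concatenation result, so no reshape is needed. A quick sketch under the same assumptions (illustrative input, contiguous tensor):

import torch

batch = [torch.arange(12).reshape(4, 3)]  # hypothetical single-item flat field

cat = torch.concat(batch)     # allocates a new tensor with the same contents
fast = batch[0].contiguous()  # already contiguous, so this returns the input itself

assert torch.equal(cat, fast)
assert fast.data_ptr() == batch[0].data_ptr()  # the fast path avoids the copy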
@@ -407,6 +417,12 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
             return stacked
 
         tensors_ = cast(list[torch.Tensor], stacked)
+        if len(tensors_) == 1:
+            # An optimization when `tensors_` contains only one tensor:
+            # - produces exactly the same result as `torch.stack(tensors_)`
+            # - achieves zero-copy if the tensor is contiguous
+            return tensors_[0].unsqueeze(0).contiguous()
+
         if any(t.shape != tensors_[0].shape for t in tensors_):
            # The tensors have incompatible shapes and can't be stacked.
            return tensors_
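Putting the branches together, the helper's behavior can be mimicked with a hypothetical standalone function (names follow the diff, but this is a sketch, not vLLM's actual method): take the zero-copy path for a singleton, fall back to the raw list for ragged shapes, and stack otherwise.

from typing import Union

import torch

def try_stack(tensors_: list[torch.Tensor]) -> Union[torch.Tensor, list[torch.Tensor]]:
    if len(tensors_) == 1:
        # Same result as torch.stack(tensors_); zero-copy when contiguous.
        return tensors_[0].unsqueeze(0).contiguous()
    if any(t.shape != tensors_[0].shape for t in tensors_):
        # Incompatible shapes can't be stacked; keep the raw list.
        return tensors_
    return torch.stack(tensors_)

print(try_stack([torch.zeros(2, 2)]).shape)               # torch.Size([1, 2, 2])
print(type(try_stack([torch.zeros(2), torch.zeros(3)])))  # <class 'list'>
print(try_stack([torch.zeros(2), torch.zeros(2)]).shape)  # torch.Size([2, 2])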