mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-12 09:07:12 +08:00
Avoid unnecessary multi-modal input data copy when len(batch) == 1 (#12722)
Signed-off-by: imkero <kerorek@outlook.com>
This commit is contained in:
parent
6469038b14
commit
62467a834a
@ -212,6 +212,11 @@ class MultiModalBatchedField(BaseMultiModalField):
|
|||||||
|
|
||||||
def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
|
def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
|
||||||
if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
|
if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
|
||||||
|
if len(batch) == 1:
|
||||||
|
# An optimization when `batch` contains only one tensor:
|
||||||
|
# - produce exactly same result as `torch.stack(batch)`
|
||||||
|
# - will achieve zero-copy if the tensor is contiguous
|
||||||
|
return batch[0].unsqueeze(0).contiguous()
|
||||||
first_shape = batch[0].shape
|
first_shape = batch[0].shape
|
||||||
if all(elem.shape == first_shape for elem in batch):
|
if all(elem.shape == first_shape for elem in batch):
|
||||||
return torch.stack(batch)
|
return torch.stack(batch)
|
||||||
@ -235,6 +240,11 @@ class MultiModalFlatField(BaseMultiModalField):
|
|||||||
|
|
||||||
def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
|
def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
|
||||||
if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
|
if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
|
||||||
|
if len(batch) == 1:
|
||||||
|
# An optimization when `batch` contains only one tensor:
|
||||||
|
# - produce exactly same result as `torch.concat(batch)`
|
||||||
|
# - will achieve zero-copy if the tensor is contiguous
|
||||||
|
return batch[0].contiguous()
|
||||||
first_shape = batch[0].shape
|
first_shape = batch[0].shape
|
||||||
if all(elem.shape[1:] == first_shape[1:] for elem in batch):
|
if all(elem.shape[1:] == first_shape[1:] for elem in batch):
|
||||||
return torch.concat(batch)
|
return torch.concat(batch)
|
||||||
@ -407,6 +417,12 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
|
|||||||
return stacked
|
return stacked
|
||||||
|
|
||||||
tensors_ = cast(list[torch.Tensor], stacked)
|
tensors_ = cast(list[torch.Tensor], stacked)
|
||||||
|
if len(tensors_) == 1:
|
||||||
|
# An optimization when `tensors_` contains only one tensor:
|
||||||
|
# - produce exactly same result as `torch.stack(tensors_)`
|
||||||
|
# - will achieve zero-copy if the tensor is contiguous
|
||||||
|
return tensors_[0].unsqueeze(0).contiguous()
|
||||||
|
|
||||||
if any(t.shape != tensors_[0].shape for t in tensors_):
|
if any(t.shape != tensors_[0].shape for t in tensors_):
|
||||||
# The tensors have incompatible shapes and can't be stacked.
|
# The tensors have incompatible shapes and can't be stacked.
|
||||||
return tensors_
|
return tensors_
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user