diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py
--- a/vllm/multimodal/inputs.py
+++ b/vllm/multimodal/inputs.py
@@ -680,7 +680,8 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
         return self._items_by_modality.keys()
 
     @staticmethod
-    def _try_stack(nested_tensors: NestedTensors) -> NestedTensors:
+    def _try_stack(nested_tensors: NestedTensors,
+                   pin_memory: bool = False) -> NestedTensors:
         """
         Stack the inner dimensions that have the same shape in
         a nested list of tensors.
@@ -697,7 +698,9 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
         if isinstance(nested_tensors, (int, float)):
             return torch.tensor(nested_tensors)
 
-        stacked = [MultiModalKwargs._try_stack(t) for t in nested_tensors]
+        stacked = [
+            MultiModalKwargs._try_stack(t, pin_memory) for t in nested_tensors
+        ]
         if not is_list_of(stacked, torch.Tensor, check="all"):
             # Only tensors (not lists) can be stacked.
             return stacked
@@ -713,10 +716,20 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
             # The tensors have incompatible shapes and can't be stacked.
             return tensors_
 
-        return torch.stack(tensors_)
+        # Pinned (page-locked) memory is a host-memory feature: asking
+        # torch.empty for it on a non-CPU device is rejected, so only honor
+        # `pin_memory` when the tensors being stacked live on the CPU.
+        first = tensors_[0]
+        outputs = torch.empty(len(tensors_),
+                              *first.shape,
+                              dtype=first.dtype,
+                              device=first.device,
+                              pin_memory=pin_memory and first.is_cpu)
+        return torch.stack(tensors_, out=outputs)
 
     @staticmethod
-    def batch(inputs_list: list["MultiModalKwargs"]) -> BatchedTensorInputs:
+    def batch(inputs_list: list["MultiModalKwargs"],
+              pin_memory: bool = False) -> BatchedTensorInputs:
         """
         Batch multiple inputs together into a dictionary.
 
@@ -738,7 +751,7 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
                 item_lists[k].append(v)
 
         return {
-            k: MultiModalKwargs._try_stack(item_list)
+            k: MultiModalKwargs._try_stack(item_list, pin_memory)
             for k, item_list in item_lists.items()
         }
 
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 175404efe0455..b1bc727e1e8ea 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -962,7 +962,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
 
         encoder_outputs = []
         for grouped_mm_inputs in grouped_mm_inputs_list:
-            batched_mm_inputs = MultiModalKwargs.batch(grouped_mm_inputs)
+            batched_mm_inputs = MultiModalKwargs.batch(
+                grouped_mm_inputs, pin_memory=self.pin_memory)
             batched_mm_inputs = MultiModalKwargs.as_kwargs(
                 batched_mm_inputs,
                 device=self.device,
@@ -1989,7 +1990,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             ).multi_modal_data
 
             batched_dummy_mm_inputs = MultiModalKwargs.batch(
-                [dummy_mm_kwargs] * max_num_mm_items)
+                [dummy_mm_kwargs] * max_num_mm_items,
+                pin_memory=self.pin_memory)
             batched_dummy_mm_inputs = MultiModalKwargs.as_kwargs(
                 batched_dummy_mm_inputs,
                 device=self.device,