diff --git a/vllm/multimodal/utils.py b/vllm/multimodal/utils.py index 3fad11a2cb4d..aa61bcc11f9f 100644 --- a/vllm/multimodal/utils.py +++ b/vllm/multimodal/utils.py @@ -444,7 +444,9 @@ def group_mm_kwargs_by_modality( if device is not None: mm_kwargs_group = json_map_leaves( - lambda x: x.to(device=device) if isinstance(x, torch.Tensor) else x, + lambda x: x.to(device=device, non_blocking=True) + if isinstance(x, torch.Tensor) + else x, mm_kwargs_group, ) else: