diff --git a/vllm/worker/tpu_worker.py b/vllm/worker/tpu_worker.py index 90f59b5d038ad..44fa3aed5816d 100644 --- a/vllm/worker/tpu_worker.py +++ b/vllm/worker/tpu_worker.py @@ -271,7 +271,10 @@ def _make_src_to_dst( mapping: List[Tuple[int, int]], src_device: Union[torch.device, str], dst_device: Union[torch.device, str], -) -> Tuple[torch.Tensor, torch.Tensor]: +) -> Optional[Tuple[torch.Tensor, torch.Tensor]]: + if not mapping: + return None + src_indices = [i for i, _ in mapping] dst_indices = [i for _, i in mapping] src_indices = torch.tensor(src_indices,