diff --git a/vllm/executor/ray_utils.py b/vllm/executor/ray_utils.py index 37cc07bfbb36a..7bc98a16f041d 100644 --- a/vllm/executor/ray_utils.py +++ b/vllm/executor/ray_utils.py @@ -87,9 +87,8 @@ try: # TODO(swang): This is needed right now because Ray Compiled Graph # executes on a background thread, so we need to reset torch's # current device. - import torch if not self.compiled_dag_cuda_device_set: - torch.cuda.set_device(self.worker.device) + current_platform.set_device(self.worker.device) self.compiled_dag_cuda_device_set = True output = self.worker._execute_model_spmd(execute_model_req, @@ -113,8 +112,7 @@ try: # Not needed pass else: - import torch - torch.cuda.set_device(self.worker.device) + current_platform.set_device(self.worker.device) self.compiled_dag_cuda_device_set = True