[Misc]Replace cuda hard code with current_platform in Ray (#14668)

Signed-off-by: noemotiovon <757486878@qq.com>
This commit is contained in:
Chenguang Li 2025-05-25 11:26:31 +08:00 committed by GitHub
parent 6c6dcd8611
commit cebc22f3b6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -87,9 +87,8 @@ try:
# TODO(swang): This is needed right now because Ray Compiled Graph
# executes on a background thread, so we need to reset torch's
# current device.
import torch
if not self.compiled_dag_cuda_device_set:
torch.cuda.set_device(self.worker.device)
current_platform.set_device(self.worker.device)
self.compiled_dag_cuda_device_set = True
output = self.worker._execute_model_spmd(execute_model_req,
@ -113,8 +112,7 @@ try:
# Not needed
pass
else:
import torch
torch.cuda.set_device(self.worker.device)
current_platform.set_device(self.worker.device)
self.compiled_dag_cuda_device_set = True