mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-19 23:57:03 +08:00
[Misc]Replace cuda hard code with current_platform in Ray (#14668)
Signed-off-by: noemotiovon <757486878@qq.com>
This commit is contained in:
parent
6c6dcd8611
commit
cebc22f3b6
@ -87,9 +87,8 @@ try:
|
|||||||
# TODO(swang): This is needed right now because Ray Compiled Graph
|
# TODO(swang): This is needed right now because Ray Compiled Graph
|
||||||
# executes on a background thread, so we need to reset torch's
|
# executes on a background thread, so we need to reset torch's
|
||||||
# current device.
|
# current device.
|
||||||
import torch
|
|
||||||
if not self.compiled_dag_cuda_device_set:
|
if not self.compiled_dag_cuda_device_set:
|
||||||
torch.cuda.set_device(self.worker.device)
|
current_platform.set_device(self.worker.device)
|
||||||
self.compiled_dag_cuda_device_set = True
|
self.compiled_dag_cuda_device_set = True
|
||||||
|
|
||||||
output = self.worker._execute_model_spmd(execute_model_req,
|
output = self.worker._execute_model_spmd(execute_model_req,
|
||||||
@ -113,8 +112,7 @@ try:
|
|||||||
# Not needed
|
# Not needed
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
import torch
|
current_platform.set_device(self.worker.device)
|
||||||
torch.cuda.set_device(self.worker.device)
|
|
||||||
|
|
||||||
self.compiled_dag_cuda_device_set = True
|
self.compiled_dag_cuda_device_set = True
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user