mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-10 09:42:15 +08:00
[BugFix] Avoid initializing CUDA too early (#3487)
This commit is contained in:
parent
ef65dcfa6f
commit
7341c77d69
@ -577,12 +577,12 @@ class DeviceConfig:
|
|||||||
def __init__(self, device: str = "auto") -> None:
|
def __init__(self, device: str = "auto") -> None:
|
||||||
if device == "auto":
|
if device == "auto":
|
||||||
# Automated device type detection
|
# Automated device type detection
|
||||||
if torch.cuda.is_available():
|
if is_neuron():
|
||||||
self.device_type = "cuda"
|
|
||||||
elif is_neuron():
|
|
||||||
self.device_type = "neuron"
|
self.device_type = "neuron"
|
||||||
else:
|
else:
|
||||||
raise RuntimeError("No supported device detected.")
|
# We don't call torch.cuda.is_available() here to
|
||||||
|
# avoid initializing CUDA before workers are forked
|
||||||
|
self.device_type = "cuda"
|
||||||
else:
|
else:
|
||||||
# Device type is assigned explicitly
|
# Device type is assigned explicitly
|
||||||
self.device_type = device
|
self.device_type = device
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user