[BugFix] Avoid initializing CUDA too early (#3487)

This commit is contained in:
Nick Hill 2024-03-18 23:05:20 -07:00 committed by GitHub
parent ef65dcfa6f
commit 7341c77d69
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -577,12 +577,12 @@ class DeviceConfig:
def __init__(self, device: str = "auto") -> None:
if device == "auto":
# Automated device type detection
if torch.cuda.is_available():
self.device_type = "cuda"
elif is_neuron():
if is_neuron():
self.device_type = "neuron"
else:
raise RuntimeError("No supported device detected.")
# We don't call torch.cuda.is_available() here to
# avoid initializing CUDA before workers are forked
self.device_type = "cuda"
else:
# Device type is assigned explicitly
self.device_type = device