Add TPU to device config

This commit is contained in:
Woosuk Kwon 2024-04-01 08:23:44 +00:00
parent 02e614d922
commit 38e3d33a62

View File

@ -596,6 +596,9 @@ class DeviceConfig:
# Some device types require processing inputs on CPU
if self.device_type in ["neuron"]:
self.device = torch.device("cpu")
elif self.device_type in ["tpu"]:
# Will be set by `xm.xla_device()`
self.device = None
else:
# Set device with device type
self.device = torch.device(self.device_type)