[BugFix] Ray with multiple nodes (#28873)

Signed-off-by: Julien Denize <julien.denize@mistral.ai>
This commit is contained in:
Julien Denize 2025-11-19 22:20:58 +01:00 committed by GitHub
parent 1607e664f0
commit cdeec2e606
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -205,14 +205,14 @@ class Worker(WorkerBase):
assert self.local_rank < torch.cuda.device_count(), (
f"DP adjusted local rank {self.local_rank} is out of bounds. "
)
visible_device_count = (
torch.cuda.device_count() if torch.cuda.is_available() else 0
)
assert self.parallel_config.local_world_size <= visible_device_count, (
f"local_world_size ({self.parallel_config.local_world_size}) must be "
f"less than or equal to the number of visible devices "
f"({visible_device_count})."
)
visible_device_count = (
torch.cuda.device_count() if torch.cuda.is_available() else 0
)
assert self.parallel_config.local_world_size <= visible_device_count, (
f"local_world_size ({self.parallel_config.local_world_size}) must "
f"be less than or equal to the number of visible devices "
f"({visible_device_count})."
)
self.device = torch.device(f"cuda:{self.local_rank}")
current_platform.set_device(self.device)