diff --git a/vllm/config.py b/vllm/config.py
index 8fc8ae6b7dfc5..9684cea813134 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -883,7 +883,7 @@ class ParallelConfig:
         from vllm.executor import ray_utils
         backend = "mp"
         ray_found = ray_utils.ray_is_available()
-        if (torch.cuda.is_available()
+        if (current_platform.is_cuda()
             and cuda_device_count_stateless() < self.world_size):
             if not ray_found:
                 raise ValueError("Unable to load Ray which is "
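
For reference, a minimal sketch of the condition this hunk rewrites, assuming vLLM's `vllm.platforms.current_platform` helper and `vllm.utils.cuda_device_count_stateless` (the names the hunk itself references); the standalone `needs_ray` wrapper below is hypothetical and only illustrates the check, not the full `ParallelConfig` logic. The switch from `torch.cuda.is_available()` to `current_platform.is_cuda()` presumably keys the branch off vLLM's platform abstraction rather than the torch build, so that, for example, ROCm builds (where `torch.cuda.is_available()` also returns True) do not take the CUDA-specific path.

# Illustrative sketch only; assumes vLLM's real helpers are importable.
# The needs_ray() wrapper is hypothetical and not part of the patched file.
from vllm.platforms import current_platform
from vllm.utils import cuda_device_count_stateless


def needs_ray(world_size: int) -> bool:
    """Return True when the local node cannot satisfy `world_size` ranks."""
    # The platform abstraction distinguishes CUDA from other backends
    # without relying on how torch was built or initializing CUDA state.
    if (current_platform.is_cuda()
            and cuda_device_count_stateless() < world_size):
        # Fewer visible CUDA devices than ranks: multi-node execution is
        # required, which the "mp" backend cannot provide (hence the
        # ValueError in the patched code when Ray is unavailable).
        return True
    return False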