From ba8ae1d84f66dd804a97182350fee6ffcadf0faf Mon Sep 17 00:00:00 2001 From: bnellnm <49004751+bnellnm@users.noreply.github.com> Date: Wed, 20 Mar 2024 13:06:56 -0400 Subject: [PATCH] Check for _is_cuda() in compute_num_jobs (#3481) --- setup.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/setup.py b/setup.py index 88787334be21a..67575a0e04bf0 100644 --- a/setup.py +++ b/setup.py @@ -61,12 +61,12 @@ class cmake_build_ext(build_ext): except AttributeError: num_jobs = os.cpu_count() - nvcc_cuda_version = get_nvcc_cuda_version() - if nvcc_cuda_version >= Version("11.2"): - nvcc_threads = int(os.getenv("NVCC_THREADS", 8)) - num_jobs = max(1, round(num_jobs / (nvcc_threads / 4))) - else: - nvcc_threads = None + nvcc_threads = None + if _is_cuda(): + nvcc_cuda_version = get_nvcc_cuda_version() + if nvcc_cuda_version >= Version("11.2"): + nvcc_threads = int(os.getenv("NVCC_THREADS", 8)) + num_jobs = max(1, round(num_jobs / (nvcc_threads / 4))) return num_jobs, nvcc_threads