From a425bd9a9af6b32e0e93b2787d6682a1a5133983 Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Tue, 26 Sep 2023 10:21:08 -0700
Subject: [PATCH] [Setup] Enable `TORCH_CUDA_ARCH_LIST` for selecting target GPUs (#1074)

---
 setup.py | 111 ++++++++++++++++++++++++++++++++++++-------------------
 1 file changed, 73 insertions(+), 38 deletions(-)

diff --git a/setup.py b/setup.py
index 047ee8d0e894..b7c0d9071fec 100644
--- a/setup.py
+++ b/setup.py
@@ -3,6 +3,7 @@ import os
 import re
 import subprocess
 from typing import List, Set
+import warnings
 
 from packaging.version import parse, Version
 import setuptools
@@ -11,6 +12,9 @@ from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME
 
 ROOT_DIR = os.path.dirname(__file__)
 
+# Supported NVIDIA GPU architectures.
+SUPPORTED_ARCHS = ["7.0", "7.5", "8.0", "8.6", "8.9", "9.0"]
+
 # Compiler flags.
 CXX_FLAGS = ["-g", "-O2", "-std=c++17"]
 # TODO(woosuk): Should we use -O3?
@@ -38,51 +42,82 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version:
     return nvcc_cuda_version
 
 
-# Collect the compute capabilities of all available GPUs.
-device_count = torch.cuda.device_count()
-compute_capabilities: Set[int] = set()
-for i in range(device_count):
-    major, minor = torch.cuda.get_device_capability(i)
-    if major < 7:
-        raise RuntimeError(
-            "GPUs with compute capability less than 7.0 are not supported.")
-    compute_capabilities.add(major * 10 + minor)
+def get_torch_arch_list() -> Set[str]:
+    # TORCH_CUDA_ARCH_LIST can have one or more architectures,
+    # e.g. "8.0" or "7.5,8.0,8.6+PTX". Here, the "8.6+PTX" option asks the
+    # compiler to additionally include PTX code that can be runtime-compiled
+    # and executed on the 8.6 or newer architectures. While the PTX code will
+    # not give the best performance on the newer architectures, it provides
+    # forward compatibility.
+    valid_arch_strs = SUPPORTED_ARCHS + [s + "+PTX" for s in SUPPORTED_ARCHS]
+    arch_list = os.environ.get("TORCH_CUDA_ARCH_LIST", None)
+    if arch_list is None:
+        return set()
+
+    # List are separated by ; or space.
+    arch_list = arch_list.replace(" ", ";").split(";")
+    for arch in arch_list:
+        if arch not in valid_arch_strs:
+            raise ValueError(
+                f"Unsupported CUDA arch ({arch}). "
+                f"Valid CUDA arch strings are: {valid_arch_strs}.")
+    return set(arch_list)
+
+
+# First, check the TORCH_CUDA_ARCH_LIST environment variable.
+compute_capabilities = get_torch_arch_list()
+if not compute_capabilities:
+    # If TORCH_CUDA_ARCH_LIST is not defined or empty, target all available
+    # GPUs on the current machine.
+    device_count = torch.cuda.device_count()
+    for i in range(device_count):
+        major, minor = torch.cuda.get_device_capability(i)
+        if major < 7:
+            raise RuntimeError(
+                "GPUs with compute capability below 7.0 are not supported.")
+        compute_capabilities.add(f"{major}.{minor}")
+
+nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME)
+if not compute_capabilities:
+    # If no GPU is specified nor available, add all supported architectures
+    # based on the NVCC CUDA version.
+    compute_capabilities = set(SUPPORTED_ARCHS)
+    if nvcc_cuda_version < Version("11.1"):
+        compute_capabilities.remove("8.6")
+    if nvcc_cuda_version < Version("11.8"):
+        compute_capabilities.remove("8.9")
+        compute_capabilities.remove("9.0")
 
 # Validate the NVCC CUDA version.
-nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME)
 if nvcc_cuda_version < Version("11.0"):
     raise RuntimeError("CUDA 11.0 or higher is required to build the package.")
-if 86 in compute_capabilities and nvcc_cuda_version < Version("11.1"):
-    raise RuntimeError(
-        "CUDA 11.1 or higher is required for GPUs with compute capability 8.6."
-    )
-if 89 in compute_capabilities and nvcc_cuda_version < Version("11.8"):
-    # CUDA 11.8 is required to generate the code targeting compute capability 8.9.
-    # However, GPUs with compute capability 8.9 can also run the code generated by
-    # the previous versions of CUDA 11 and targeting compute capability 8.0.
-    # Therefore, if CUDA 11.8 is not available, we target compute capability 8.0
-    # instead of 8.9.
-    compute_capabilities.remove(89)
-    compute_capabilities.add(80)
-if 90 in compute_capabilities and nvcc_cuda_version < Version("11.8"):
-    raise RuntimeError(
-        "CUDA 11.8 or higher is required for GPUs with compute capability 9.0."
-    )
-
-# If no GPU is available, add all supported compute capabilities.
-if not compute_capabilities:
-    compute_capabilities = {70, 75, 80}
-    if nvcc_cuda_version >= Version("11.1"):
-        compute_capabilities.add(86)
-    if nvcc_cuda_version >= Version("11.8"):
-        compute_capabilities.add(89)
-        compute_capabilities.add(90)
+if nvcc_cuda_version < Version("11.1"):
+    if any(cc.startswith("8.6") for cc in compute_capabilities):
+        raise RuntimeError(
+            "CUDA 11.1 or higher is required for compute capability 8.6.")
+if nvcc_cuda_version < Version("11.8"):
+    if any(cc.startswith("8.9") for cc in compute_capabilities):
+        # CUDA 11.8 is required to generate the code targeting compute capability 8.9.
+        # However, GPUs with compute capability 8.9 can also run the code generated by
+        # the previous versions of CUDA 11 and targeting compute capability 8.0.
+        # Therefore, if CUDA 11.8 is not available, we target compute capability 8.0
+        # instead of 8.9.
+        warnings.warn(
+            "CUDA 11.8 or higher is required for compute capability 8.9. "
+            "Targeting compute capability 8.0 instead.")
+        compute_capabilities = set(cc for cc in compute_capabilities
+                                   if not cc.startswith("8.9"))
+        compute_capabilities.add("8.0+PTX")
+    if any(cc.startswith("9.0") for cc in compute_capabilities):
+        raise RuntimeError(
+            "CUDA 11.8 or higher is required for compute capability 9.0.")
 
 # Add target compute capabilities to NVCC flags.
 for capability in compute_capabilities:
-    NVCC_FLAGS += [
-        "-gencode", f"arch=compute_{capability},code=sm_{capability}"
-    ]
+    num = capability[0] + capability[2]
+    NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=sm_{num}"]
+    if capability.endswith("+PTX"):
+        NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=compute_{num}"]
 
 # Use NVCC threads to parallelize the build.
 if nvcc_cuda_version >= Version("11.2"):
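
For illustration, here is a small standalone sketch of the selection logic this patch introduces (not part of the diff; the helper name `gencode_flags`, the demo script, and the example arch list value are made up for the demo). It shows how an entry such as "8.6+PTX" in `TORCH_CUDA_ARCH_LIST` is parsed and turned into NVCC `-gencode` flags:

    import os
    from typing import List, Set

    SUPPORTED_ARCHS = ["7.0", "7.5", "8.0", "8.6", "8.9", "9.0"]

    def get_torch_arch_list() -> Set[str]:
        # Same parsing as the patched setup.py: entries are separated by
        # ';' or spaces, and each must be a supported arch, optionally "+PTX".
        valid = SUPPORTED_ARCHS + [s + "+PTX" for s in SUPPORTED_ARCHS]
        arch_list = os.environ.get("TORCH_CUDA_ARCH_LIST", None)
        if arch_list is None:
            return set()
        arch_list = arch_list.replace(" ", ";").split(";")
        for arch in arch_list:
            if arch not in valid:
                raise ValueError(f"Unsupported CUDA arch ({arch}).")
        return set(arch_list)

    def gencode_flags(compute_capabilities: Set[str]) -> List[str]:
        # Mirror of the flag-generation loop: "8.6" maps to sm_86, and a
        # "+PTX" suffix additionally embeds PTX code so newer GPUs can
        # JIT-compile the kernels (forward compatibility).
        flags: List[str] = []
        for capability in sorted(compute_capabilities):
            num = capability[0] + capability[2]
            flags += ["-gencode", f"arch=compute_{num},code=sm_{num}"]
            if capability.endswith("+PTX"):
                flags += ["-gencode", f"arch=compute_{num},code=compute_{num}"]
        return flags

    if __name__ == "__main__":
        # e.g. invoked as: TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" python sketch.py
        os.environ.setdefault("TORCH_CUDA_ARCH_LIST", "8.0;8.6+PTX")
        print(gencode_flags(get_torch_arch_list()))
        # ['-gencode', 'arch=compute_80,code=sm_80',
        #  '-gencode', 'arch=compute_86,code=sm_86',
        #  '-gencode', 'arch=compute_86,code=compute_86']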