diff --git a/CMakeLists.txt b/CMakeLists.txt index ba3991372fabe..cbac8b8397fa3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -86,6 +86,9 @@ find_package(Torch REQUIRED) # Supported NVIDIA architectures. # This check must happen after find_package(Torch) because that's when CMAKE_CUDA_COMPILER_VERSION gets defined if(DEFINED CMAKE_CUDA_COMPILER_VERSION AND + CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 13.0) + set(CUDA_SUPPORTED_ARCHS "7.5;8.0;8.6;8.7;8.9;9.0;10.0;11.0;12.0") +elseif(DEFINED CMAKE_CUDA_COMPILER_VERSION AND CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.8) set(CUDA_SUPPORTED_ARCHS "7.0;7.2;7.5;8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0") else() @@ -175,6 +178,15 @@ if(NVCC_THREADS AND VLLM_GPU_LANG STREQUAL "CUDA") list(APPEND VLLM_GPU_FLAGS "--threads=${NVCC_THREADS}") endif() +# +# Set compression mode for CUDA >=13.x. +# +if(VLLM_GPU_LANG STREQUAL "CUDA" AND + DEFINED CMAKE_CUDA_COMPILER_VERSION AND + CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 13.0) + list(APPEND VLLM_GPU_FLAGS "--compress-mode=size") +endif() + # # Set CUDA include flags for CXX compiler. # @@ -270,7 +282,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library") # Set CUTLASS_REVISION. Used for FetchContent. Also fixes some bogus messages when building. 
- set(CUTLASS_REVISION "v4.0.0" CACHE STRING "CUTLASS revision to use") + set(CUTLASS_REVISION "v4.2.1" CACHE STRING "CUTLASS revision to use") # Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided if (DEFINED ENV{VLLM_CUTLASS_SRC_DIR}) @@ -305,7 +317,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") "csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu" "csrc/quantization/fp4/nvfp4_quant_entry.cu" "csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu" - "csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu" "csrc/sparse/cutlass/sparse_scaled_mm_entry.cu" "csrc/cutlass_extensions/common.cpp" "csrc/quantization/w8a8/fp8/per_token_group_quant.cu" @@ -441,7 +452,11 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # The cutlass_scaled_mm kernels for Geforce Blackwell SM120 (c3x, i.e. CUTLASS 3.x) require # CUDA 12.8 or later - cuda_archs_loose_intersection(SCALED_MM_ARCHS "12.0;12.0a" "${CUDA_ARCHS}") + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0) + cuda_archs_loose_intersection(SCALED_MM_ARCHS "12.0f" "${CUDA_ARCHS}") + else() + cuda_archs_loose_intersection(SCALED_MM_ARCHS "12.0a" "${CUDA_ARCHS}") + endif() if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS) set(SRCS "csrc/quantization/w8a8/cutlass/scaled_mm_c3x_sm120.cu" @@ -471,7 +486,11 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # The cutlass_scaled_mm kernels for Blackwell SM100 (c3x, i.e. 
CUTLASS 3.x) # require CUDA 12.8 or later - cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a" "${CUDA_ARCHS}") + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0) + cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0f;11.0f;12.0f" "${CUDA_ARCHS}") + else() + cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a;10.3a;12.0a;12.1a" "${CUDA_ARCHS}") + endif() if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS) set(SRCS "csrc/quantization/w8a8/cutlass/scaled_mm_c3x_sm100.cu" @@ -551,7 +570,11 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # The nvfp4_scaled_mm_sm120 kernels for Geforce Blackwell SM120 require # CUDA 12.8 or later - cuda_archs_loose_intersection(FP4_ARCHS "12.0;12.0a" "${CUDA_ARCHS}") + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0) + cuda_archs_loose_intersection(FP4_ARCHS "12.0f" "${CUDA_ARCHS}") + else() + cuda_archs_loose_intersection(FP4_ARCHS "12.0a" "${CUDA_ARCHS}") + endif() if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND FP4_ARCHS) set(SRCS "csrc/quantization/fp4/nvfp4_quant_kernels.cu" @@ -570,7 +593,11 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") endif() # FP4 Archs and flags - cuda_archs_loose_intersection(FP4_ARCHS "10.0a" "${CUDA_ARCHS}") + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0) + cuda_archs_loose_intersection(FP4_ARCHS "10.0f;11.0f;12.0f" "${CUDA_ARCHS}") + else() + cuda_archs_loose_intersection(FP4_ARCHS "10.0a;10.1a;12.0a;12.1a" "${CUDA_ARCHS}") + endif() if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND FP4_ARCHS) set(SRCS "csrc/quantization/fp4/nvfp4_quant_kernels.cu" @@ -592,7 +619,11 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") endif() # CUTLASS MLA Archs and flags - cuda_archs_loose_intersection(MLA_ARCHS "10.0a" "${CUDA_ARCHS}") + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0) + cuda_archs_loose_intersection(MLA_ARCHS "10.0f;11.0f;12.0f" "${CUDA_ARCHS}") + else() + cuda_archs_loose_intersection(MLA_ARCHS 
"10.0a;10.1a;10.3a;12.0a;12.1a" "${CUDA_ARCHS}") + endif() if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND MLA_ARCHS) set(SRCS "csrc/attention/mla/sm100_cutlass_mla_kernel.cu") @@ -636,7 +667,11 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") endif() endif() - cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a" "${CUDA_ARCHS}") + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0) + cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0f" "${CUDA_ARCHS}") + else() + cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a" "${CUDA_ARCHS}") + endif() if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS) set(SRCS "csrc/quantization/w8a8/cutlass/moe/grouped_mm_c3x_sm100.cu") set_gencode_flags_for_srcs( @@ -657,7 +692,11 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") endif() # moe_data.cu is used by all CUTLASS MoE kernels. - cuda_archs_loose_intersection(CUTLASS_MOE_DATA_ARCHS "9.0a;10.0a" "${CUDA_ARCHS}") + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0) + cuda_archs_loose_intersection(CUTLASS_MOE_DATA_ARCHS "9.0a;10.0f;11.0f;12.0f" "${CUDA_ARCHS}") + else() + cuda_archs_loose_intersection(CUTLASS_MOE_DATA_ARCHS "9.0a;10.0a;10.1a;10.3a;12.0a;12.1a" "${CUDA_ARCHS}") + endif() if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND CUTLASS_MOE_DATA_ARCHS) set(SRCS "csrc/quantization/w8a8/cutlass/moe/moe_data.cu") set_gencode_flags_for_srcs( @@ -676,7 +715,11 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") endif() endif() - cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a" "${CUDA_ARCHS}") + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0) + cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0f;11.0f;12.0f" "${CUDA_ARCHS}") + else() + cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a;10.3a;12.0a;12.1a" "${CUDA_ARCHS}") + endif() if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS) set(SRCS "csrc/quantization/w8a8/cutlass/moe/blockwise_scaled_group_mm_sm100.cu") 
set_gencode_flags_for_srcs( diff --git a/cmake/utils.cmake b/cmake/utils.cmake index 8558976e2c392..f6a0d2b75be1a 100644 --- a/cmake/utils.cmake +++ b/cmake/utils.cmake @@ -310,13 +310,13 @@ function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_AR list(REMOVE_DUPLICATES _PTX_ARCHS) list(REMOVE_DUPLICATES _SRC_CUDA_ARCHS) - # if x.0a is in SRC_CUDA_ARCHS and x.0 is in CUDA_ARCHS then we should - # remove x.0a from SRC_CUDA_ARCHS and add x.0a to _CUDA_ARCHS + # If x.0a or x.0f is in SRC_CUDA_ARCHS and x.0 is in CUDA_ARCHS then we should + # remove x.0a or x.0f from SRC_CUDA_ARCHS and add x.0a or x.0f to _CUDA_ARCHS set(_CUDA_ARCHS) foreach(_arch ${_SRC_CUDA_ARCHS}) - if(_arch MATCHES "\\a$") + if(_arch MATCHES "[af]$") list(REMOVE_ITEM _SRC_CUDA_ARCHS "${_arch}") - string(REPLACE "a" "" _base "${_arch}") + string(REGEX REPLACE "[af]$" "" _base "${_arch}") if ("${_base}" IN_LIST TGT_CUDA_ARCHS) list(REMOVE_ITEM _TGT_CUDA_ARCHS "${_base}") list(APPEND _CUDA_ARCHS "${_arch}") diff --git a/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh index dbf79a0651159..e7bb061ba0244 100644 --- a/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh +++ b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh @@ -231,7 +231,7 @@ void cutlass_gemm_blockwise_sm100_fp8_dispatch(torch::Tensor& out, } else { cutlass_gemm_caller_blockwise, Int>, - Shape<_1, _1, _1>, cutlass::epilogue::NoSmemWarpSpecialized1Sm, + Shape<_1, _1, _1>, cutlass::epilogue::BlockwiseNoSmemWarpSpecialized1Sm, cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>>( out, a, b, a_scales, b_scales); } @@ -245,7 +245,7 @@ void cutlass_gemm_blockwise_sm100_fp8_dispatch(torch::Tensor& out, } else { cutlass_gemm_caller_blockwise, Int>, - Shape<_1, _1, _1>, cutlass::epilogue::NoSmemWarpSpecialized1Sm, + Shape<_1, _1, _1>, 
cutlass::epilogue::BlockwiseNoSmemWarpSpecialized1Sm, cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>>( out, a, b, a_scales, b_scales); } @@ -259,7 +259,7 @@ void cutlass_gemm_blockwise_sm100_fp8_dispatch(torch::Tensor& out, } else { cutlass_gemm_caller_blockwise, Int>, - Shape<_2, _1, _1>, cutlass::epilogue::NoSmemWarpSpecialized2Sm, + Shape<_2, _1, _1>, cutlass::epilogue::BlockwiseNoSmemWarpSpecialized2Sm, cutlass::gemm::KernelTmaWarpSpecializedBlockwise2SmSm100>>( out, a, b, a_scales, b_scales); } @@ -271,10 +271,10 @@ void cutlass_gemm_blockwise_sm100_fp8_dispatch(torch::Tensor& out, // TMA epilogue isn't compatible with Swap A/B cutlass_gemm_caller_blockwise, Int, Int>, - Shape<_1, _1, _1>, cutlass::epilogue::NoSmemWarpSpecialized1Sm, + Shape<_1, _1, _1>, cutlass::epilogue::BlockwiseNoSmemWarpSpecialized1Sm, cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100, true>>( out, a, b, a_scales, b_scales); } } -} // namespace vllm +} // namespace vllm diff --git a/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm100_fp8_dispatch.cuh b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm100_fp8_dispatch.cuh index 24564efbd21be..f876b7d9acd87 100644 --- a/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm100_fp8_dispatch.cuh +++ b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm100_fp8_dispatch.cuh @@ -133,4 +133,4 @@ void cutlass_scaled_mm_sm100_fp8_epilogue(torch::Tensor& out, } } -} // namespace vllm \ No newline at end of file +} // namespace vllm diff --git a/csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu b/csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu index 84843ee6e0949..04b64a35da376 100644 --- a/csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu +++ b/csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu @@ -67,8 +67,9 @@ void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a, std::optional const& bias); #endif -#if defined(ENABLE_SCALED_MM_SM90) && ENABLE_SCALED_MM_SM90 || \ - defined(ENABLE_SCALED_MM_SM100) && 
ENABLE_SCALED_MM_SM100 +#if defined(ENABLE_SCALED_MM_SM90) && ENABLE_SCALED_MM_SM90 || \ + defined(ENABLE_SCALED_MM_SM100) && ENABLE_SCALED_MM_SM100 || \ + defined(ENABLE_SCALED_MM_SM120) && ENABLE_SCALED_MM_SM120 void get_cutlass_moe_mm_data_caller( const torch::Tensor& topk_ids, torch::Tensor& expert_offsets, torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,