From 6e588da0f4b90e695a20779c3d5a079e56ad3a7b Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Thu, 22 May 2025 15:13:54 -0400 Subject: [PATCH] [Build/CI] Fix CUDA 11.8 build (#17679) Signed-off-by: Tyler Michael Smith Signed-off-by: Lucas Wilkinson Signed-off-by: Tyler Michael Smith Co-authored-by: Lucas Wilkinson --- CMakeLists.txt | 6 ++- csrc/moe/moe_ops.h | 4 +- csrc/moe/moe_permute_unpermute_op.cu | 43 ++++++++++++++++++- .../moe_permute_unpermute_kernel.cu | 10 +++-- csrc/moe/torch_bindings.cpp | 4 +- .../cutlass_w8a8/scaled_mm_entry.cu | 2 +- docker/Dockerfile | 16 ++++--- .../kernels/moe/test_moe_permute_unpermute.py | 4 +- .../layers/fused_moe/moe_permute_unpermute.py | 4 ++ 9 files changed, 78 insertions(+), 15 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index a6c54be9530b9..ffb801d62619d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -30,7 +30,11 @@ set(ignoreMe "${VLLM_PYTHON_PATH}") set(PYTHON_SUPPORTED_VERSIONS "3.9" "3.10" "3.11" "3.12") # Supported NVIDIA architectures. -set(CUDA_SUPPORTED_ARCHS "7.0;7.2;7.5;8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0") +if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL) + set(CUDA_SUPPORTED_ARCHS "7.0;7.2;7.5;8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0") +else() + set(CUDA_SUPPORTED_ARCHS "7.0;7.2;7.5;8.0;8.6;8.7;8.9;9.0") +endif() # Supported AMD GPU architectures. set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx942;gfx950;gfx1030;gfx1100;gfx1101;gfx1200;gfx1201") diff --git a/csrc/moe/moe_ops.h b/csrc/moe/moe_ops.h index 0bae119a7c460..8fda434d452f9 100644 --- a/csrc/moe/moe_ops.h +++ b/csrc/moe/moe_ops.h @@ -28,4 +28,6 @@ torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output, torch::Tensor num_tokens_post_pad, int64_t top_k, int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N, int64_t BLOCK_SIZE_K, int64_t bit); -#endif \ No newline at end of file +#endif + +bool moe_permute_unpermute_supported(); \ No newline at end of file diff --git a/csrc/moe/moe_permute_unpermute_op.cu b/csrc/moe/moe_permute_unpermute_op.cu index 76d5f0eab0218..9a7465261abfe 100644 --- a/csrc/moe/moe_permute_unpermute_op.cu +++ b/csrc/moe/moe_permute_unpermute_op.cu @@ -5,6 +5,9 @@ #include "permute_unpermute_kernels/dispatch.h" #include "core/registration.h" +// moe_permute kernels require at least CUDA 12.0 +#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12000) + void moe_permute( const torch::Tensor& input, // [n_token, hidden] const torch::Tensor& topk_weights, //[n_token, topk] @@ -127,7 +130,45 @@ void moe_unpermute( }); } +#else + +void moe_permute(const torch::Tensor& input, const torch::Tensor& topk_weights, + torch::Tensor& topk_ids, + const torch::Tensor& token_expert_indicies, + const std::optional& expert_map, + int64_t n_expert, int64_t n_local_expert, int64_t topk, + const std::optional& align_block_size, + torch::Tensor& permuted_input, + torch::Tensor& expert_first_token_offset, + torch::Tensor& src_row_id2dst_row_id_map, + torch::Tensor& m_indices) { + TORCH_CHECK(false, "moe_unpermute is not supported on CUDA < 12.0"); +} + +void moe_unpermute(const torch::Tensor& input, + const torch::Tensor& topk_weights, torch::Tensor& topk_ids, + const torch::Tensor& token_expert_indicies, + const std::optional& expert_map, + int64_t n_expert, int64_t n_local_expert, int64_t topk, + const std::optional& align_block_size, + torch::Tensor& permuted_input, + torch::Tensor& expert_first_token_offset, + torch::Tensor& src_row_id2dst_row_id_map, + torch::Tensor& m_indices) { + TORCH_CHECK(false, "moe_unpermute is not supported on CUDA < 12.0"); +} + +#endif + +bool moe_permute_unpermute_supported() { +#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12000) + return true; +#else + return false; +#endif +} + TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { m.impl("moe_permute", &moe_permute); m.impl("moe_unpermute", &moe_unpermute); -} \ No newline at end of file +} diff --git a/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.cu b/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.cu index aa353d0f0437f..de2c153882d93 100644 --- a/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.cu +++ b/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.cu @@ -1,6 +1,9 @@ #include "moe_permute_unpermute_kernel.h" +// moe_permute kernels require at least CUDA 12.0 +#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12000) + // CubKeyValueSorter definition begin CubKeyValueSorter::CubKeyValueSorter() : num_experts_(0), num_bits_(sizeof(int) * 8) {} @@ -131,9 +134,6 @@ __global__ void preprocessTopkIdKernel(int* topk_id_ptr, int size, int num_experts) { auto tidx = threadIdx.x; auto bidx = blockIdx.x; - auto lidx = tidx & 31; - auto widx = tidx >> 5; - auto warp_count = (blockDim.x + 31) >> 5; auto offset = bidx * blockDim.x; auto bound = min(offset + blockDim.x, size); extern __shared__ int smem_expert_map[]; @@ -226,4 +226,6 @@ void getMIndices(int64_t* expert_first_token_offset, expert_first_token_offset, align_expert_first_token_offset, m_indices, num_local_expert, align_block_size); } -} \ No newline at end of file +} + +#endif diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index 05f515e2e783b..7d35ec79ead48 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -77,7 +77,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { "Tensor topk_ids,Tensor src_row_id2dst_row_id_map, Tensor " "expert_first_token_offset, int n_expert, int n_local_expert,int " "topk, Tensor! hidden_states)->()"); - // conditionally compiled so impl registration is in source file + + m.def("moe_permute_unpermute_supported() -> bool"); + m.impl("moe_permute_unpermute_supported", &moe_permute_unpermute_supported); #endif } diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu index 3c258ddce61e6..e9b408fbf2ee0 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu @@ -123,7 +123,7 @@ bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability) { } bool cutlass_group_gemm_supported(int64_t cuda_device_capability) { - // CUTLASS groped FP8 kernels need at least CUDA 12.3 + // CUTLASS grouped FP8 kernels need at least CUDA 12.3 // and SM90 (Hopper) #if defined CUDA_VERSION diff --git a/docker/Dockerfile b/docker/Dockerfile index a35056f785879..cc3499d1f0a91 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -263,8 +263,11 @@ if [ "$TARGETPLATFORM" != "linux/arm64" ]; then \ export TORCH_CUDA_ARCH_LIST='7.5 8.0 8.9 9.0 10.0+PTX'; \ else \ export TORCH_CUDA_ARCH_LIST='7.5 8.0 8.9 9.0+PTX'; \ - fi && \ - export FLASHINFER_ENABLE_AOT=1; \ + fi; \ + CUDA_MAJOR="${CUDA_VERSION%%.*}"; \ + if [ "$CUDA_MAJOR" -lt 12 ]; then \ + export FLASHINFER_ENABLE_SM90=0; \ + fi; \ uv pip install --system --no-build-isolation "git+https://github.com/flashinfer-ai/flashinfer@21ea1d2545f74782b91eb8c08fd503ac4c0743fc" ; \ fi COPY examples examples @@ -275,7 +278,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \ . /etc/environment && \ uv pip list -# Although we build Flashinfer with AOT mode, there's still +# Even when we build Flashinfer with AOT mode, there's still # some issues w.r.t. JIT compilation. Therefore we need to # install build dependencies for JIT compilation. # TODO: Remove this once FlashInfer AOT wheel is fixed @@ -303,8 +306,11 @@ RUN --mount=type=cache,target=/root/.cache/uv \ uv pip install --system --no-build-isolation "git+https://github.com/state-spaces/mamba@v2.2.4" # install development dependencies (for testing) -RUN --mount=type=cache,target=/root/.cache/uv \ - uv pip install --system -r requirements/dev.txt +RUN --mount=type=cache,target=/root/.cache/uv \ + CUDA_MAJOR="${CUDA_VERSION%%.*}"; \ + if [ "$CUDA_MAJOR" -ge 12 ]; then \ + uv pip install --system -r requirements/dev.txt; \ + fi # install development dependencies (for testing) RUN --mount=type=cache,target=/root/.cache/uv \ diff --git a/tests/kernels/moe/test_moe_permute_unpermute.py b/tests/kernels/moe/test_moe_permute_unpermute.py index dfcd61f775870..10e6ac64df877 100644 --- a/tests/kernels/moe/test_moe_permute_unpermute.py +++ b/tests/kernels/moe/test_moe_permute_unpermute.py @@ -13,7 +13,7 @@ import torch from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe.layer import determine_expert_map from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import ( - moe_permute, moe_unpermute) + moe_permute, moe_permute_unpermute_supported, moe_unpermute) from vllm.platforms import current_platform NUM_EXPERTS = [16, 64] @@ -167,6 +167,8 @@ def torch_unpermute(permuted_hidden_states: torch.Tensor, def test_moe_permute_unpermute(n_token: int, n_hidden: int, topk: int, n_expert: int, ep_size: int, dtype: torch.dtype, align_block_size: Optional[int]): + if not moe_permute_unpermute_supported(): + pytest.skip("moe_permute_unpermute is not supported on this platform.") fill_invalid_expert = 0 ep_rank = np.random.randint(0, ep_size) expert_map = None diff --git a/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py b/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py index 270e7cf1298ab..cb396f26c96e0 100644 --- a/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py +++ b/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py @@ -182,3 +182,7 @@ def moe_unpermute( expert_first_token_offset, n_expert, n_local_expert, topk, hidden_states) return hidden_states + + +def moe_permute_unpermute_supported(): + return torch.ops._moe_C.moe_permute_unpermute_supported()