From 8a3cd90af534c39425ebfdfd295eea0a4582d541 Mon Sep 17 00:00:00 2001 From: Xin Yang <105740670+xyang16@users.noreply.github.com> Date: Mon, 25 Aug 2025 11:47:52 -0700 Subject: [PATCH] [Kernel] Add fused grouped_topk kernel for MoE (#23274) Signed-off-by: Xin Yang Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com> --- CMakeLists.txt | 4 +- csrc/moe/grouped_topk_kernels.cu | 757 ++++++++++++++++++ csrc/moe/moe_ops.h | 5 + csrc/moe/torch_bindings.cpp | 6 + tests/kernels/moe/test_grouped_topk.py | 76 ++ vllm/_custom_ops.py | 11 + vllm/envs.py | 6 + .../layers/fused_moe/fused_moe.py | 46 +- 8 files changed, 909 insertions(+), 2 deletions(-) create mode 100644 csrc/moe/grouped_topk_kernels.cu create mode 100644 tests/kernels/moe/test_grouped_topk.py diff --git a/CMakeLists.txt b/CMakeLists.txt index aca42c3fe555..b0ed4a284db9 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -817,7 +817,9 @@ set(VLLM_MOE_EXT_SRC "csrc/moe/topk_softmax_kernels.cu") if(VLLM_GPU_LANG STREQUAL "CUDA") - list(APPEND VLLM_MOE_EXT_SRC "csrc/moe/moe_wna16.cu") + list(APPEND VLLM_MOE_EXT_SRC + "csrc/moe/moe_wna16.cu" + "csrc/moe/grouped_topk_kernels.cu") endif() if(VLLM_GPU_LANG STREQUAL "CUDA") diff --git a/csrc/moe/grouped_topk_kernels.cu b/csrc/moe/grouped_topk_kernels.cu new file mode 100644 index 000000000000..78f7b3cc1aa2 --- /dev/null +++ b/csrc/moe/grouped_topk_kernels.cu @@ -0,0 +1,757 @@ +/* + * Adapted from + * https://github.com/NVIDIA/TensorRT-LLM/blob/v0.21.0/cpp/tensorrt_llm/kernels/noAuxTcKernels.cu + * Copyright (c) 2025, The vLLM team. + * SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & + * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include +#include +#include +#include +namespace cg = cooperative_groups; + +namespace vllm { +namespace moe { + +constexpr unsigned FULL_WARP_MASK = 0xffffffff; +constexpr int32_t WARP_SIZE = 32; +constexpr int32_t BLOCK_SIZE = 512; +constexpr int32_t NUM_WARPS_PER_BLOCK = BLOCK_SIZE / WARP_SIZE; + +namespace warp_topk { + +template +__host__ __device__ constexpr T round_up_to_multiple_of(T len) { + if (len == 0) { + return 0; + } + return ((len - 1) / size + 1) * size; +} + +template +constexpr __host__ __device__ bool isPowerOf2(T v) { + return (v && !(v & (v - 1))); +} + +template +__forceinline__ __device__ bool is_better_than(T val, T baseline) { + return (val > baseline && greater) || (val < baseline && !greater); +} + +template +__forceinline__ __device__ bool is_better_than(T val, T baseline, idxT index, + idxT baseline_index) { + bool res = (val > baseline && greater) || (val < baseline && !greater); + if (val == baseline) { + res = (index < baseline_index && greater) || + (index < baseline_index && !greater); + } + return res; +} + +template +int calc_smem_size_for_block_wide(int num_of_warp, int64_t k) { + int64_t cache_topk = (sizeof(T) + sizeof(idxT)) * num_of_warp * k; + int64_t n = std::max(num_of_warp / 2 * k, num_of_warp * WARP_SIZE); + return max(cache_topk, + round_up_to_multiple_of<256>(n * sizeof(T)) + n * sizeof(idxT)); +} + +template +struct BitonicMerge { + // input should be a bitonic sequence, and sort it to be a monotonic sequence + __device__ static void merge(T* __restrict__ val_arr, + idxT* __restrict__ idx_arr) { + static_assert(isPowerOf2(size)); + static_assert(size >= 2 * WARP_SIZE); + constexpr int arr_len = size / WARP_SIZE; + + constexpr int stride = arr_len / 2; + for (int i = 0; i < stride; ++i) { + int const other_i = i + stride; + T& val = val_arr[i]; + T& other_val = val_arr[other_i]; + bool is_better; + if constexpr (is_stable) { + is_better = is_better_than(val, other_val, idx_arr[i], + idx_arr[other_i]); + } else { + is_better = is_better_than(val, other_val); + } + + if (is_better) { + T tmp = val; + val = other_val; + other_val = tmp; + + idxT tmp2 = idx_arr[i]; + idx_arr[i] = idx_arr[other_i]; + idx_arr[other_i] = tmp2; + } + } + + BitonicMerge::merge( + val_arr, idx_arr); + BitonicMerge::merge( + val_arr + arr_len / 2, idx_arr + arr_len / 2); + } +}; + +template +struct BitonicSort { + __device__ static void sort(T* __restrict__ val_arr, + idxT* __restrict__ idx_arr) { + static_assert(isPowerOf2(size)); + static_assert(size >= 2 * WARP_SIZE); + constexpr int arr_len = size / WARP_SIZE; + + BitonicSort::sort(val_arr, idx_arr); + BitonicSort::sort( + val_arr + arr_len / 2, idx_arr + arr_len / 2); + BitonicMerge::merge( + val_arr, idx_arr); + } +}; + +template +struct BitonicSort<32, ascending, T, idxT, is_stable> { + __device__ static void sort(T* __restrict__ val_arr, + idxT* __restrict__ idx_arr) { + int const lane = threadIdx.x % WARP_SIZE; + + // ascending doesn't matter before merging since all we need is a bitonic + // sequence + for (int stage = 0; stage < 4; ++stage) { + for (int stride = (1 << stage); stride > 0; stride /= 2) { + bool reverse = (lane >> stage) & 2; + bool is_second = lane & stride; + + T other = __shfl_xor_sync(FULL_WARP_MASK, *val_arr, stride); + idxT other_idx = __shfl_xor_sync(FULL_WARP_MASK, *idx_arr, stride); + + bool is_better; + if constexpr (is_stable) { + if constexpr (ascending) { + is_better = ((*val_arr > other) || + ((*val_arr == other) && (*idx_arr < other_idx))) != + (reverse != is_second); + } else { + is_better = ((*val_arr > other) || + ((*val_arr == other) && (*idx_arr > other_idx))) != + (reverse != is_second); + } + } else { + is_better = (*val_arr != other && + (*val_arr > other) != (reverse != is_second)); + } + if (is_better) { + *val_arr = other; + *idx_arr = other_idx; + } + } + } + + BitonicMerge<32, ascending, ascending, T, idxT, is_stable>::merge(val_arr, + idx_arr); + } +}; + +template +struct BitonicMerge<32, ascending, reverse, T, idxT, is_stable> { + __device__ static void merge(T* __restrict__ val_arr, + idxT* __restrict__ idx_arr) { + int const lane = threadIdx.x % WARP_SIZE; + for (int stride = WARP_SIZE / 2; stride > 0; stride /= 2) { + bool is_second = lane & stride; + T& val = *val_arr; + T other = __shfl_xor_sync(FULL_WARP_MASK, val, stride); + idxT& idx = *idx_arr; + idxT other_idx = __shfl_xor_sync(FULL_WARP_MASK, idx, stride); + + bool is_better; + if constexpr (is_stable) { + if constexpr (ascending) { + is_better = ((*val_arr > other) || + ((*val_arr == other) && (*idx_arr < other_idx))) == + (reverse != is_second); // for min + } else { + is_better = ((*val_arr > other) || + ((*val_arr == other) && (*idx_arr > other_idx))) == + (reverse != is_second); // for max + } + } else { + is_better = + (val != other && ((val > other) == (ascending != is_second))); + } + + if (is_better) { + val = other; + idx = other_idx; + } + } + } +}; + +template +class WarpSort { + public: + __device__ WarpSort(idxT k, T dummy) + : lane_(threadIdx.x % WARP_SIZE), k_(k), dummy_(dummy) { + static_assert(capacity >= WARP_SIZE && isPowerOf2(capacity)); + + for (int i = 0; i < max_arr_len_; ++i) { + val_arr_[i] = dummy_; + idx_arr_[i] = 0; + } + } + + // load and merge k sorted values + __device__ void load_sorted(T const* __restrict__ in, + idxT const* __restrict__ in_idx, idxT start) { + idxT idx = start + WARP_SIZE - 1 - lane_; + for (int i = max_arr_len_ - 1; i >= 0; --i, idx += WARP_SIZE) { + if (idx < start + k_) { + T t = in[idx]; + bool is_better; + if constexpr (is_stable) { + is_better = + is_better_than(t, val_arr_[i], in_idx[idx], idx_arr_[i]); + } else { + is_better = is_better_than(t, val_arr_[i]); + } + if (is_better) { + val_arr_[i] = t; + idx_arr_[i] = in_idx[idx]; + } + } + } + + BitonicMerge::merge( + val_arr_, idx_arr_); + } + + __device__ void dump(T* __restrict__ out, idxT* __restrict__ out_idx) const { + for (int i = 0; i < max_arr_len_; ++i) { + idxT out_i = i * WARP_SIZE + lane_; + if (out_i < k_) { + out[out_i] = val_arr_[i]; + out_idx[out_i] = idx_arr_[i]; + } + } + } + + __device__ void dumpIdx(idxT* __restrict__ out_idx) const { + for (int i = 0; i < max_arr_len_; ++i) { + idxT out_i = i * WARP_SIZE + lane_; + if (out_i < k_) { + out_idx[out_i] = idx_arr_[i]; + } + } + } + + protected: + static constexpr int max_arr_len_ = capacity / WARP_SIZE; + + T val_arr_[max_arr_len_]; + idxT idx_arr_[max_arr_len_]; + + int const lane_; + idxT const k_; + T const dummy_; + +}; // end class WarpSort + +template +class WarpSelect : public WarpSort { + public: + __device__ WarpSelect(idxT k, T dummy) + : WarpSort(k, dummy), + k_th_(dummy), + k_th_lane_((k - 1) % WARP_SIZE) { + extern __shared__ char smem_buf[]; // extern __shared__ T smem_buf[]; + + int const num_of_warp = blockDim.x / WARP_SIZE; + int const warp_id = threadIdx.x / WARP_SIZE; + val_smem_ = reinterpret_cast(smem_buf); + val_smem_ += warp_id * WARP_SIZE; + idx_smem_ = reinterpret_cast( + smem_buf + + round_up_to_multiple_of<256>(num_of_warp * sizeof(T) * WARP_SIZE)); + idx_smem_ += warp_id * WARP_SIZE; + } + + __device__ void add(T const* in, idxT start, idxT end) { + idxT const end_for_fullwarp = + round_up_to_multiple_of(end - start) + start; + for (idxT i = start + lane_; i < end_for_fullwarp; i += WARP_SIZE) { + T val = (i < end) ? in[i] : dummy_; + add(val, i); + } + } + + __device__ void add(T val, idxT idx) { + bool do_add; + if constexpr (is_stable) { + do_add = is_better_than(val, k_th_, idx, k_th_idx_); + } else { + do_add = is_better_than(val, k_th_); + } + + uint32_t mask = __ballot_sync(FULL_WARP_MASK, do_add); + if (mask == 0) { + return; + } + + int pos = smem_buf_len_ + __popc(mask & ((0x1u << lane_) - 1)); + if (do_add && pos < WARP_SIZE) { + val_smem_[pos] = val; + idx_smem_[pos] = idx; + do_add = false; + } + smem_buf_len_ += __popc(mask); + if (smem_buf_len_ >= WARP_SIZE) { + __syncwarp(); + merge_buf_(val_smem_[lane_], idx_smem_[lane_]); + smem_buf_len_ -= WARP_SIZE; + } + if (do_add) { + pos -= WARP_SIZE; + val_smem_[pos] = val; + idx_smem_[pos] = idx; + } + __syncwarp(); + } + + __device__ void done() { + if (smem_buf_len_) { + T val = (lane_ < smem_buf_len_) ? val_smem_[lane_] : dummy_; + idxT idx = (lane_ < smem_buf_len_) ? idx_smem_[lane_] : 0; + merge_buf_(val, idx); + } + + // after done(), smem is used for merging results among warps + __syncthreads(); + } + + private: + __device__ void set_k_th_() { + k_th_ = __shfl_sync(FULL_WARP_MASK, val_arr_[max_arr_len_ - 1], k_th_lane_); + if constexpr (is_stable) { + k_th_idx_ = + __shfl_sync(FULL_WARP_MASK, idx_arr_[max_arr_len_ - 1], k_th_lane_); + } + } + + __device__ void merge_buf_(T val, idxT idx) { + BitonicSort::sort(&val, &idx); + + T& old = val_arr_[max_arr_len_ - 1]; + + bool is_better; + if constexpr (is_stable) { + is_better = + is_better_than(val, old, idx, idx_arr_[max_arr_len_ - 1]); + } else { + is_better = is_better_than(val, old); + } + + if (is_better) { + old = val; + idx_arr_[max_arr_len_ - 1] = idx; + } + + BitonicMerge::merge( + val_arr_, idx_arr_); + + set_k_th_(); + } + + using WarpSort::max_arr_len_; + using WarpSort::val_arr_; + using WarpSort::idx_arr_; + using WarpSort::lane_; + using WarpSort::k_; + using WarpSort::dummy_; + + T* val_smem_; + idxT* idx_smem_; + int smem_buf_len_ = 0; + + T k_th_; + idxT k_th_idx_; + int const k_th_lane_; +}; // end class WarpSelect +} // namespace warp_topk + +template +__device__ inline T_OUT cuda_cast(T_IN val) { + return val; +} + +template <> +__device__ inline float cuda_cast(__nv_bfloat16 val) { + return __bfloat162float(val); +} + +template +__device__ void topk_with_k2(T* output, T const* input, + cg::thread_block_tile<32> const& tile, + int32_t const lane_id, + int const num_experts_per_group) { + // Get the top2 per thread + T largest = -INFINITY; + T second_largest = -INFINITY; + + if (num_experts_per_group > WARP_SIZE) { + for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) { + T value = input[i]; + if (value > largest) { + second_largest = largest; + largest = value; + } else if (value > second_largest) { + second_largest = value; + } + } + } else { + for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) { + largest = input[i]; + } + } + + __syncwarp(); // Ensure all threads have valid data before reduction + // Get the top2 warpwise + T max1 = cg::reduce(tile, largest, cg::greater()); + + T max2 = max1; + bool equal_to_max1 = (max1 == largest); + + int count_max1 = __popc(__ballot_sync(FULL_WARP_MASK, equal_to_max1)); + + if (count_max1 == 1) { + largest = (largest == max1) ? second_largest : largest; + max2 = cg::reduce(tile, largest, cg::greater()); + } + + if (lane_id == 0) { + *output = max1 + max2; + } +} + +template +__global__ void topk_with_k2_kernel(T* output, T* input, + int64_t const num_tokens, + int64_t const num_cases, + int64_t const n_group, + int64_t const num_experts_per_group) { + int32_t warp_id = threadIdx.x / WARP_SIZE; + int32_t lane_id = threadIdx.x % WARP_SIZE; + + int32_t case_id = blockIdx.x * NUM_WARPS_PER_BLOCK + warp_id; + if (case_id < num_cases) { + input += case_id * num_experts_per_group; + output += case_id; + + cg::thread_block block = cg::this_thread_block(); + cg::thread_block_tile<32> tile = cg::tiled_partition<32>(block); + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + asm volatile("griddepcontrol.wait;"); +#endif + topk_with_k2(output, input, tile, lane_id, num_experts_per_group); + } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + asm volatile("griddepcontrol.launch_dependents;"); +#endif +} + +template +__global__ void group_idx_and_topk_idx_kernel( + T* scores, T const* group_scores, T* topk_values, IdxT* topk_indices, + T* scores_with_bias, int64_t const num_tokens, int64_t const n_group, + int64_t const topk_group, int64_t const topk, int64_t const num_experts, + int64_t const num_experts_per_group, bool renormalize, + double routed_scaling_factor) { + int32_t warp_id = threadIdx.x / WARP_SIZE; + int32_t lane_id = threadIdx.x % WARP_SIZE; + int32_t case_id = + blockIdx.x * NUM_WARPS_PER_BLOCK + warp_id; // one per token + scores_with_bias += case_id * num_experts; + scores += case_id * num_experts; + group_scores += case_id * n_group; + topk_values += case_id * topk; + topk_indices += case_id * topk; + + int32_t align_num_experts_per_group = + warp_topk::round_up_to_multiple_of(num_experts_per_group); + + cg::thread_block block = cg::this_thread_block(); + cg::thread_block_tile<32> tile = cg::tiled_partition<32>(block); + + extern __shared__ char smem_buf[]; // NOTE: reuse the shared memory here to + // store the target topk idx + int32_t* s_topk_idx = reinterpret_cast(smem_buf); + T* s_topk_value = + reinterpret_cast(s_topk_idx + NUM_WARPS_PER_BLOCK * topk) + + warp_id * topk; + s_topk_idx += warp_id * topk; + + T value = cuda::std::numeric_limits::min(); + T topk_group_value = cuda::std::numeric_limits::min(); + int32_t num_equalto_topkth_group; + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + asm volatile("griddepcontrol.wait;"); // I think all prolog can be put before + // acqbulk because it's ptr arithmetic +#endif + + if (case_id < num_tokens) { + // calculate group_idx + int32_t target_num_min = WARP_SIZE - n_group + topk_group; + if (lane_id < n_group && + (isfinite(cuda_cast( + group_scores[lane_id])))) // The check is necessary to avoid + // abnormal input + { + value = group_scores[lane_id]; + } + + int count_equal_to_top_value = WARP_SIZE - n_group; + int pre_count_equal_to_top_value = 0; + // Use loop to find the largset top_group + while (count_equal_to_top_value < target_num_min) { + __syncwarp(); // Ensure all threads have valid data before reduction + topk_group_value = cg::reduce(tile, value, cg::greater()); + if (value == topk_group_value) { + value = cuda::std::numeric_limits::min(); + } + pre_count_equal_to_top_value = count_equal_to_top_value; + count_equal_to_top_value = __popc(__ballot_sync( + FULL_WARP_MASK, (value == cuda::std::numeric_limits::min()))); + } + num_equalto_topkth_group = target_num_min - pre_count_equal_to_top_value; + } + __syncthreads(); + + warp_topk::WarpSelect + queue((int32_t)topk, -INFINITY); + + int count_equalto_topkth_group = 0; + bool if_proceed_next_topk = + (topk_group_value != cuda::std::numeric_limits::min()); + if (case_id < num_tokens && if_proceed_next_topk) { + for (int i_group = 0; i_group < n_group; i_group++) { + if ((group_scores[i_group] > topk_group_value) || + ((group_scores[i_group] == topk_group_value) && + (count_equalto_topkth_group < num_equalto_topkth_group))) { + int32_t offset = i_group * num_experts_per_group; + for (int32_t i = lane_id; i < align_num_experts_per_group; + i += WARP_SIZE) { + T candidates = + (i < num_experts_per_group) && isfinite(cuda_cast( + scores_with_bias[offset + i])) + ? scores_with_bias[offset + i] + : cuda::std::numeric_limits::min(); + queue.add(candidates, offset + i); + } + if (group_scores[i_group] == topk_group_value) { + count_equalto_topkth_group++; + } + } + } + queue.done(); + __syncwarp(); + // Get the topk_idx + queue.dumpIdx(s_topk_idx); + __syncwarp(); + } + + // Load the valid score value + // Calculate the summation + float topk_sum = 1e-20; + if (case_id < num_tokens && if_proceed_next_topk) { + for (int i = lane_id; + i < warp_topk::round_up_to_multiple_of(topk); + i += WARP_SIZE) { + T value = + i < topk + ? scores[s_topk_idx[i]] + : cuda_cast(0.0f); // Load the valid value of expert + if (i < topk) { + s_topk_value[i] = value; + } + topk_sum += reduce(tile, cuda_cast(value), cg::plus()); + } + } + + __syncthreads(); + + if (case_id < num_tokens) { + if (if_proceed_next_topk) { + for (int i = lane_id; i < topk; i += WARP_SIZE) { + float value; + if (renormalize) { + value = cuda_cast(s_topk_value[i]) / topk_sum * + routed_scaling_factor; + } else { + value = cuda_cast(s_topk_value[i]) * routed_scaling_factor; + } + topk_indices[i] = s_topk_idx[i]; + topk_values[i] = cuda_cast(value); + } + } else { + for (int i = lane_id; i < topk; i += WARP_SIZE) { + topk_indices[i] = i; + topk_values[i] = cuda_cast(1.0f / topk); + } + } + // Note: when if_proceed_next_topk==false, choose the first 8 experts as the + // default result. + } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + asm volatile("griddepcontrol.launch_dependents;"); +#endif +} + +template +void invokeNoAuxTc(T* scores, T* group_scores, T* topk_values, + IdxT* topk_indices, T* scores_with_bias, + int64_t const num_tokens, int64_t const num_experts, + int64_t const n_group, int64_t const topk_group, + int64_t const topk, bool const renormalize, + double const routed_scaling_factor, bool enable_pdl = false, + cudaStream_t const stream = 0) { + int64_t num_cases = num_tokens * n_group; + int64_t topk_with_k2_num_blocks = (num_cases - 1) / NUM_WARPS_PER_BLOCK + 1; + auto* kernel_instance1 = &topk_with_k2_kernel; + cudaLaunchConfig_t config; + config.gridDim = topk_with_k2_num_blocks; + config.blockDim = BLOCK_SIZE; + config.dynamicSmemBytes = 0; + config.stream = stream; + cudaLaunchAttribute attrs[1]; + attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; + attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl; + config.numAttrs = 1; + config.attrs = attrs; + cudaLaunchKernelEx(&config, kernel_instance1, group_scores, scores_with_bias, + num_tokens, num_cases, n_group, num_experts / n_group); + + int64_t topk_with_k_group_num_blocks = + (num_tokens - 1) / NUM_WARPS_PER_BLOCK + 1; + size_t dynamic_smem_in_bytes = + warp_topk::calc_smem_size_for_block_wide(NUM_WARPS_PER_BLOCK, + topk); + auto* kernel_instance2 = &group_idx_and_topk_idx_kernel; + config.gridDim = topk_with_k_group_num_blocks; + config.blockDim = BLOCK_SIZE; + config.dynamicSmemBytes = dynamic_smem_in_bytes; + config.stream = stream; + attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; + attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl; + config.numAttrs = 1; + config.attrs = attrs; + cudaLaunchKernelEx(&config, kernel_instance2, scores, group_scores, + topk_values, topk_indices, scores_with_bias, num_tokens, + n_group, topk_group, topk, num_experts, + num_experts / n_group, renormalize, routed_scaling_factor); +} + +#define INSTANTIATE_NOAUX_TC(T, IdxT) \ + template void invokeNoAuxTc( \ + T * scores, T * group_scores, T * topk_values, IdxT * topk_indices, \ + T * scores_with_bias, int64_t const num_tokens, \ + int64_t const num_experts, int64_t const n_group, \ + int64_t const topk_group, int64_t const topk, bool const renormalize, \ + double const routed_scaling_factor, bool enable_pdl, \ + cudaStream_t const stream); + +INSTANTIATE_NOAUX_TC(float, int32_t); +INSTANTIATE_NOAUX_TC(half, int32_t); +INSTANTIATE_NOAUX_TC(__nv_bfloat16, int32_t); +} // end namespace moe +} // namespace vllm + +std::tuple grouped_topk( + torch::Tensor const& scores, torch::Tensor const& scores_with_bias, + int64_t n_group, int64_t topk_group, int64_t topk, bool renormalize, + double routed_scaling_factor) { + auto data_type = scores_with_bias.scalar_type(); + auto input_size = scores_with_bias.sizes(); + int64_t num_tokens = input_size[0]; + int64_t num_experts = input_size[1]; + TORCH_CHECK(input_size.size() == 2, "scores_with_bias must be a 2D Tensor"); + TORCH_CHECK(num_experts % n_group == 0, + "num_experts should be divisible by n_group"); + TORCH_CHECK(n_group <= 32, + "n_group should be smaller than or equal to 32 for now"); + TORCH_CHECK(topk <= 32, "topk should be smaller than or equal to 32 for now"); + + torch::Tensor group_scores = torch::empty( + {num_tokens, n_group}, torch::dtype(data_type).device(torch::kCUDA)); + torch::Tensor topk_values = torch::empty( + {num_tokens, topk}, torch::dtype(data_type).device(torch::kCUDA)); + torch::Tensor topk_indices = torch::empty( + {num_tokens, topk}, torch::dtype(torch::kInt32).device(torch::kCUDA)); + + auto stream = c10::cuda::getCurrentCUDAStream(scores_with_bias.get_device()); + + switch (data_type) { + case torch::kFloat16: + // Handle Float16 + vllm::moe::invokeNoAuxTc( + reinterpret_cast(scores.mutable_data_ptr()), + reinterpret_cast(group_scores.mutable_data_ptr()), + reinterpret_cast(topk_values.mutable_data_ptr()), + reinterpret_cast(topk_indices.mutable_data_ptr()), + reinterpret_cast(scores_with_bias.data_ptr()), num_tokens, + num_experts, n_group, topk_group, topk, renormalize, + routed_scaling_factor, false, stream); + break; + case torch::kFloat32: + // Handle Float32 + vllm::moe::invokeNoAuxTc( + reinterpret_cast(scores.mutable_data_ptr()), + reinterpret_cast(group_scores.mutable_data_ptr()), + reinterpret_cast(topk_values.mutable_data_ptr()), + reinterpret_cast(topk_indices.mutable_data_ptr()), + reinterpret_cast(scores_with_bias.data_ptr()), num_tokens, + num_experts, n_group, topk_group, topk, renormalize, + routed_scaling_factor, false, stream); + break; + case torch::kBFloat16: + // Handle BFloat16 + vllm::moe::invokeNoAuxTc<__nv_bfloat16, int32_t>( + reinterpret_cast<__nv_bfloat16*>(scores.mutable_data_ptr()), + reinterpret_cast<__nv_bfloat16*>(group_scores.mutable_data_ptr()), + reinterpret_cast<__nv_bfloat16*>(topk_values.mutable_data_ptr()), + reinterpret_cast(topk_indices.mutable_data_ptr()), + reinterpret_cast<__nv_bfloat16*>(scores_with_bias.data_ptr()), + num_tokens, num_experts, n_group, topk_group, topk, renormalize, + routed_scaling_factor, false, stream); + break; + default: + // Handle other data types + throw std::invalid_argument( + "Invalid dtype, only supports float16, float32, and bfloat16"); + break; + } + return {topk_values, topk_indices}; +} diff --git a/csrc/moe/moe_ops.h b/csrc/moe/moe_ops.h index 661730c96867..92fc280b362b 100644 --- a/csrc/moe/moe_ops.h +++ b/csrc/moe/moe_ops.h @@ -22,6 +22,11 @@ torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output, torch::Tensor num_tokens_post_pad, int64_t top_k, int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N, int64_t BLOCK_SIZE_K, int64_t bit); + +std::tuple grouped_topk( + torch::Tensor const& scores, torch::Tensor const& scores_with_bias, + int64_t n_group, int64_t topk_group, int64_t topk, bool renormalize, + double routed_scaling_factor); #endif bool moe_permute_unpermute_supported(); diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index 7e49f68f6243..8f33d6cd666f 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -78,6 +78,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { "output_tensor) -> ()"); m.impl("shuffle_rows", torch::kCUDA, &shuffle_rows); + // Apply grouped topk routing to select experts. + m.def( + "grouped_topk(Tensor scores, Tensor scores_with_bias, int n_group, int " + "topk_group, int topk, bool renormalize, float " + "routed_scaling_factor) -> (Tensor, Tensor)"); + m.impl("grouped_topk", torch::kCUDA, &grouped_topk); #endif } diff --git a/tests/kernels/moe/test_grouped_topk.py b/tests/kernels/moe/test_grouped_topk.py new file mode 100644 index 000000000000..646e763194fd --- /dev/null +++ b/tests/kernels/moe/test_grouped_topk.py @@ -0,0 +1,76 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Tests for the MoE grouped topk kernel + +Run `pytest tests/kernels/moe/test_grouped_topk.py`. +""" +import pytest +import torch + +from vllm.model_executor.layers.fused_moe.fused_moe import (fused_grouped_topk, + grouped_topk) +from vllm.platforms import current_platform + + +@pytest.mark.skipif(not current_platform.is_cuda(), + reason="This test is skipped on non-CUDA platform.") +@pytest.mark.parametrize("n_token", [1, 33, 64]) +@pytest.mark.parametrize("n_hidden", [1024, 2048]) +@pytest.mark.parametrize("n_expert", [16]) +@pytest.mark.parametrize("topk", [2]) +@pytest.mark.parametrize("renormalize", [True, False]) +@pytest.mark.parametrize("num_expert_group", [8]) +@pytest.mark.parametrize("topk_group", [2]) +@pytest.mark.parametrize("scoring_func", ["softmax", "sigmoid"]) +@pytest.mark.parametrize("routed_scaling_factor", [1.0, 2.5]) +@pytest.mark.parametrize("dtype", + [torch.float16, torch.bfloat16, torch.float32]) +def test_grouped_topk(monkeypatch: pytest.MonkeyPatch, n_token: int, + n_hidden: int, n_expert: int, topk: int, + renormalize: bool, num_expert_group: int, + topk_group: int, scoring_func: str, + routed_scaling_factor: float, dtype: torch.dtype): + current_platform.seed_everything(0) + hidden_states = torch.randn((n_token, n_hidden), + dtype=dtype, + device="cuda") + gating_output = torch.randn((n_token, n_expert), + dtype=dtype, + device="cuda") + e_score_correction_bias = torch.randn((n_expert, ), + dtype=torch.float32, + device="cuda") + + with monkeypatch.context() as m: + m.setenv("VLLM_USE_FUSED_MOE_GROUPED_TOPK", "0") + baseline_topk_weights, baseline_topk_ids = grouped_topk( + hidden_states=hidden_states, + gating_output=gating_output, + topk=topk, + renormalize=renormalize, + num_expert_group=num_expert_group, + topk_group=topk_group, + scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, + e_score_correction_bias=e_score_correction_bias) + + test_topk_weights, test_topk_ids = fused_grouped_topk( + hidden_states=hidden_states, + gating_output=gating_output, + topk=topk, + renormalize=renormalize, + num_expert_group=num_expert_group, + topk_group=topk_group, + scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, + e_score_correction_bias=e_score_correction_bias) + + if renormalize: + torch.testing.assert_close(baseline_topk_weights, + test_topk_weights, + atol=2e-2, + rtol=0) + torch.testing.assert_close(baseline_topk_ids, + test_topk_ids, + atol=0, + rtol=0) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 3e3b43ce2abe..054dc9d985a4 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1502,6 +1502,17 @@ def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor, gating_output) +def grouped_topk(scores: torch.Tensor, scores_with_bias: torch.Tensor, + num_expert_group: int, topk_group: int, topk: int, + renormalize: bool, routed_scaling_factor: float): + if not current_platform.is_cuda(): + raise NotImplementedError("The fused grouped_topk kernel is only " + "available on CUDA platforms") + return torch.ops._moe_C.grouped_topk(scores, scores_with_bias, + num_expert_group, topk_group, topk, + renormalize, routed_scaling_factor) + + def moe_wna16_marlin_gemm(input: torch.Tensor, output: Optional[torch.Tensor], b_qweight: torch.Tensor, b_bias: Optional[torch.Tensor], diff --git a/vllm/envs.py b/vllm/envs.py index 5d0e972f43ad..1c9c4cdde800 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -131,6 +131,7 @@ if TYPE_CHECKING: VLLM_USE_DEEP_GEMM: bool = False VLLM_USE_DEEP_GEMM_E8M0: bool = True VLLM_SKIP_DEEP_GEMM_WARMUP: bool = False + VLLM_USE_FUSED_MOE_GROUPED_TOPK: bool = True VLLM_USE_FLASHINFER_MOE_FP8: bool = False VLLM_USE_FLASHINFER_MOE_FP4: bool = False VLLM_FLASHINFER_MOE_BACKEND: str = "throughput" @@ -963,6 +964,10 @@ environment_variables: dict[str, Callable[[], Any]] = { "VLLM_SKIP_DEEP_GEMM_WARMUP": lambda: bool(int(os.getenv("VLLM_SKIP_DEEP_GEMM_WARMUP", "0"))), + # Whether to use fused grouped_topk used for MoE expert selection. + "VLLM_USE_FUSED_MOE_GROUPED_TOPK": + lambda: bool(int(os.getenv("VLLM_USE_FUSED_MOE_GROUPED_TOPK", "1"))), + # Allow use of FlashInfer MoE kernels for fused moe ops. "VLLM_USE_FLASHINFER_MOE_FP8": lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_MOE_FP8", "0"))), @@ -1229,6 +1234,7 @@ def compute_hash() -> str: "VLLM_DISABLED_KERNELS", "VLLM_USE_DEEP_GEMM", "VLLM_USE_TRTLLM_FP4_GEMM", + "VLLM_USE_FUSED_MOE_GROUPED_TOPK", "VLLM_USE_FLASHINFER_MOE_FP8", "VLLM_USE_FLASHINFER_MOE_FP4", "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8", diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 02b7b65f4a02..84dafcf00d82 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -949,8 +949,23 @@ def grouped_topk( num_expert_group: int = 0, topk_group: int = 0, scoring_func: str = "softmax", - e_score_correction_bias: Optional[torch.Tensor] = None + routed_scaling_factor: float = 1.0, + e_score_correction_bias: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, torch.Tensor]: + if envs.VLLM_USE_FUSED_MOE_GROUPED_TOPK and \ + current_platform.is_cuda() and \ + num_expert_group <= 32 and topk <= 32 and \ + e_score_correction_bias is not None: + return fused_grouped_topk( + hidden_states=hidden_states, + gating_output=gating_output, + topk=topk, + renormalize=renormalize, + e_score_correction_bias=e_score_correction_bias, + num_expert_group=num_expert_group, + topk_group=topk_group, + scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor) assert hidden_states.size(0) == gating_output.size(0), ( "Number of tokens mismatch") @@ -996,9 +1011,38 @@ def grouped_topk( if renormalize: topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + topk_weights = topk_weights * routed_scaling_factor return topk_weights.to(torch.float32), topk_ids.to(torch.int32) +def fused_grouped_topk( + hidden_states: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + e_score_correction_bias: torch.Tensor, + num_expert_group: int = 0, + topk_group: int = 0, + scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, +) -> tuple[torch.Tensor, torch.Tensor]: + assert hidden_states.size(0) == gating_output.size(0), ( + "Number of tokens mismatch") + + if scoring_func == "softmax": + scores = torch.softmax(gating_output, dim=-1) + elif scoring_func == "sigmoid": + scores = gating_output.sigmoid() + else: + raise ValueError(f"Unsupported scoring function: {scoring_func}") + + scores_with_bias = scores + e_score_correction_bias.unsqueeze(0) + topk_values, topk_indices = ops.grouped_topk( + scores, scores_with_bias.to(scores.dtype), num_expert_group, + topk_group, topk, renormalize, routed_scaling_factor) + return topk_values.to(torch.float32), topk_indices.to(torch.int32) + + def get_config_dtype_str( dtype: torch.dtype, use_int4_w4a16: Optional[bool] = False,