From 61059bee40511b6f6c044053cf921da81cf89985 Mon Sep 17 00:00:00 2001
From: Chiyue Wei <92623189+dubcyfor3@users.noreply.github.com>
Date: Thu, 5 Jun 2025 09:48:26 -0700
Subject: [PATCH] [Hardware][NVIDIA] FP4 MoE kernel optimization (#19110)

Signed-off-by: Chiyue Wei
Co-authored-by: Chiyue Wei
---
 benchmarks/kernels/benchmark_cutlass_fp4_moe.py |  2 +-
 csrc/moe/moe_ops.h                              |  6 +-
 csrc/moe/moe_permute_unpermute_op.cu            | 56 +++++++++++++++++++
 csrc/moe/permute_unpermute_kernels/dispatch.h   | 18 ++++--
 csrc/moe/torch_bindings.cpp                     |  6 ++
 csrc/ops.h                                      |  3 +-
 csrc/quantization/cutlass_w8a8/moe/moe_data.cu  | 36 ++++++++++--
 csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu | 9 ++-
 csrc/torch_bindings.cpp                         |  2 +-
 tests/kernels/moe/test_nvfp4_moe.py             |  5 +-
 vllm/_custom_ops.py                             | 45 +++++++++++----
 vllm/model_executor/layers/fused_moe/cutlass_moe.py | 15 ++---
 12 files changed, 165 insertions(+), 38 deletions(-)

diff --git a/benchmarks/kernels/benchmark_cutlass_fp4_moe.py b/benchmarks/kernels/benchmark_cutlass_fp4_moe.py
index 3383fb78872a..35c20ee41b9a 100644
--- a/benchmarks/kernels/benchmark_cutlass_fp4_moe.py
+++ b/benchmarks/kernels/benchmark_cutlass_fp4_moe.py
@@ -91,7 +91,7 @@ def bench_run(
 
     score = torch.randn((m, num_experts), device=device, dtype=dtype)
 
-    topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False)
+    topk_weights, topk_ids, _ = fused_topk(a, score, topk, renormalize=False)
 
     quant_blocksize = 16
     w1_blockscale = torch.empty(
diff --git a/csrc/moe/moe_ops.h b/csrc/moe/moe_ops.h
index 8fda434d452f..c4faef731060 100644
--- a/csrc/moe/moe_ops.h
+++ b/csrc/moe/moe_ops.h
@@ -30,4 +30,8 @@ torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
                              int64_t BLOCK_SIZE_K, int64_t bit);
 #endif
 
-bool moe_permute_unpermute_supported();
\ No newline at end of file
+bool moe_permute_unpermute_supported();
+
+void shuffle_rows(const torch::Tensor& input_tensor,
+                  const torch::Tensor& dst2src_map,
+                  torch::Tensor& output_tensor);
\ No newline at end of file
diff --git a/csrc/moe/moe_permute_unpermute_op.cu b/csrc/moe/moe_permute_unpermute_op.cu
index 9a7465261abf..68f429fac18a 100644
--- a/csrc/moe/moe_permute_unpermute_op.cu
+++ b/csrc/moe/moe_permute_unpermute_op.cu
@@ -130,6 +130,62 @@ void moe_unpermute(
   });
 }
 
+template <typename T>
+__global__ void shuffleInputRowsKernel(const T* input,
+                                       const int32_t* dst2src_map, T* output,
+                                       int64_t num_src_rows,
+                                       int64_t num_dst_rows, int64_t num_cols) {
+  int64_t dest_row_idx = blockIdx.x;
+  int64_t const source_row_idx = dst2src_map[dest_row_idx];
+
+  if (blockIdx.x < num_dst_rows) {
+    // Load 128-bits per thread
+    constexpr int64_t ELEM_PER_THREAD = 128 / sizeof(T) / 8;
+    using DataElem = cutlass::Array<T, ELEM_PER_THREAD>;
+
+    // Duplicate and permute rows
+    auto const* source_row_ptr =
+        reinterpret_cast<DataElem const*>(input + source_row_idx * num_cols);
+    auto* dest_row_ptr =
+        reinterpret_cast<DataElem*>(output + dest_row_idx * num_cols);
+
+    int64_t const start_offset = threadIdx.x;
+    int64_t const stride = blockDim.x;
+    int64_t const num_elems_in_col = num_cols / ELEM_PER_THREAD;
+
+    for (int elem_index = start_offset; elem_index < num_elems_in_col;
+         elem_index += stride) {
+      dest_row_ptr[elem_index] = source_row_ptr[elem_index];
+    }
+  }
+}
+
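The kernel above implements a row gather: each CUDA block owns one destination row and its threads copy that row in 128-bit chunks, so tokens routed to several experts are duplicated with coalesced loads and stores. Note that dst2src_map is read before the blockIdx.x bound check, which is safe only because the wrapper below always launches exactly num_dst_rows blocks. A minimal PyTorch sketch of the same gather semantics, for reference only (shuffle_rows_reference is an illustrative name, not part of the patch):

    import torch

    def shuffle_rows_reference(input_tensor: torch.Tensor,
                               dst2src_map: torch.Tensor) -> torch.Tensor:
        # dst2src_map[d] holds the source row feeding destination row d;
        # indices may repeat, so the output can have more rows than the input.
        return input_tensor[dst2src_map.long()]
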
+void shuffle_rows(const torch::Tensor& input_tensor,
+                  const torch::Tensor& dst2src_map,
+                  torch::Tensor& output_tensor) {
+  TORCH_CHECK(input_tensor.scalar_type() == output_tensor.scalar_type(),
+              "Input and output tensors must have the same data type");
+
+  auto stream = at::cuda::getCurrentCUDAStream().stream();
+  int64_t const blocks = output_tensor.size(0);
+  int64_t const threads = 256;
+  int64_t const num_dest_rows = output_tensor.size(0);
+  int64_t const num_src_rows = input_tensor.size(0);
+  int64_t const num_cols = input_tensor.size(1);
+
+  TORCH_CHECK(!(num_cols % (128 / sizeof(input_tensor.scalar_type()) / 8)),
+              "num_cols must be divisible by 128 / "
+              "sizeof(input_tensor.scalar_type()) / 8");
+
+  MOE_DISPATCH(input_tensor.scalar_type(), [&] {
+    shuffleInputRowsKernel<scalar_t><<<blocks, threads, 0, stream>>>(
+        reinterpret_cast<scalar_t*>(input_tensor.data_ptr()),
+        dst2src_map.data_ptr<int32_t>(),
+        reinterpret_cast<scalar_t*>(output_tensor.data_ptr()), num_src_rows,
+        num_dest_rows, num_cols);
+  });
+}
+
 #else
 void moe_permute(const torch::Tensor& input,
                  const torch::Tensor& topk_weights,
diff --git a/csrc/moe/permute_unpermute_kernels/dispatch.h b/csrc/moe/permute_unpermute_kernels/dispatch.h
index 41932cdd85bc..d0f1ea4aded3 100644
--- a/csrc/moe/permute_unpermute_kernels/dispatch.h
+++ b/csrc/moe/permute_unpermute_kernels/dispatch.h
@@ -14,12 +14,13 @@
     __VA_ARGS__(); \
     break;         \
   }
 
-#define MOE_DISPATCH_FLOAT_CASE(...)                          \
-  MOE_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)       \
-  MOE_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)        \
-  MOE_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)    \
-  MOE_DISPATCH_CASE(at::ScalarType::Float8_e5m2, __VA_ARGS__) \
-  MOE_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__)
+#define MOE_DISPATCH_FLOAT_CASE(...)                            \
+  MOE_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)         \
+  MOE_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)          \
+  MOE_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)      \
+  MOE_DISPATCH_CASE(at::ScalarType::Float8_e5m2, __VA_ARGS__)   \
+  MOE_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) \
+  MOE_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)
 
 #define MOE_DISPATCH(TYPE, ...) \
   MOE_SWITCH(TYPE, MOE_DISPATCH_FLOAT_CASE(__VA_ARGS__))
@@ -39,6 +40,11 @@ template <>
 struct ScalarType2CudaType<c10::ScalarType::BFloat16> {
   using type = __nv_bfloat16;
 };
+// uint8 for packed fp4
+template <>
+struct ScalarType2CudaType<c10::ScalarType::Byte> {
+  using type = uint8_t;
+};
 
 // #if __CUDA_ARCH__ >= 890
 // fp8
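The new at::ScalarType::Byte case, with its uint8_t specialization above, is what lets the shuffle kernel run directly on packed NVFP4 activations: two 4-bit values share one byte, so a row of k FP4 elements is k / 2 uint8 columns wide. A small illustration of that packed layout, assuming a CUDA device (shapes and values are arbitrary):

    import torch

    k = 128  # logical fp4 elements per row
    packed = torch.randint(0, 256, (8, k // 2),
                           dtype=torch.uint8, device="cuda")
    # Rows of this uint8 tensor are what MOE_DISPATCH now routes into
    # shuffleInputRowsKernel; 64 bytes per row is divisible by 16, so the
    # kernel's 128-bit vectorization check passes.
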
" + "output_tensor) -> ()"); + m.impl("shuffle_rows", torch::kCUDA, &shuffle_rows); + #endif } diff --git a/csrc/ops.h b/csrc/ops.h index 297f32b4a2a0..6905ef6e5911 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -248,7 +248,8 @@ void get_cutlass_moe_mm_data( const torch::Tensor& topk_ids, torch::Tensor& expert_offsets, torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, torch::Tensor& input_permutation, torch::Tensor& output_permutation, - const int64_t num_experts, const int64_t n, const int64_t k); + const int64_t num_experts, const int64_t n, const int64_t k, + const std::optional& blockscale_offsets); void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& b, diff --git a/csrc/quantization/cutlass_w8a8/moe/moe_data.cu b/csrc/quantization/cutlass_w8a8/moe/moe_data.cu index 894727383a63..ac414e1bc0c0 100644 --- a/csrc/quantization/cutlass_w8a8/moe/moe_data.cu +++ b/csrc/quantization/cutlass_w8a8/moe/moe_data.cu @@ -45,6 +45,23 @@ __global__ void compute_expert_offsets( } } +__global__ void compute_expert_blockscale_offsets( + const int32_t* __restrict__ problem_sizes1, int32_t* expert_offsets, + int32_t* blockscale_offsets, int32_t* atomic_buffer, + const int num_experts) { + int32_t tot_offset = 0; + int32_t tot_offset_round = 0; + expert_offsets[0] = 0; + blockscale_offsets[0] = 0; + for (int i = 0; i < num_experts; ++i) { + atomic_buffer[i] = tot_offset; + tot_offset += problem_sizes1[i * 3]; + expert_offsets[i + 1] = tot_offset; + tot_offset_round += (problem_sizes1[i * 3] + (128 - 1)) / 128 * 128; + blockscale_offsets[i + 1] = tot_offset_round; + } +} + __global__ void compute_arg_sorts(const int* __restrict__ topk_ids, const int32_t* __restrict__ expert_offsets, int32_t* input_permutation, @@ -77,7 +94,8 @@ void get_cutlass_moe_mm_data_caller( const torch::Tensor& topk_ids, torch::Tensor& expert_offsets, torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, torch::Tensor& input_permutation, torch::Tensor& output_permutation, - const int64_t num_experts, const int64_t n, const int64_t k) { + const int64_t num_experts, const int64_t n, const int64_t k, + const std::optional& blockscale_offsets) { auto stream = at::cuda::getCurrentCUDAStream(topk_ids.device().index()); auto options_int32 = torch::TensorOptions().dtype(torch::kInt32).device(topk_ids.device()); @@ -89,10 +107,18 @@ void get_cutlass_moe_mm_data_caller( static_cast(problem_sizes1.data_ptr()), static_cast(problem_sizes2.data_ptr()), static_cast(atomic_buffer.data_ptr()), topk_ids.numel(), n, k); - compute_expert_offsets<<<1, 1, 0, stream>>>( - static_cast(problem_sizes1.data_ptr()), - static_cast(expert_offsets.data_ptr()), - static_cast(atomic_buffer.data_ptr()), num_experts); + if (blockscale_offsets.has_value()) { + compute_expert_blockscale_offsets<<<1, 1, 0, stream>>>( + static_cast(problem_sizes1.data_ptr()), + static_cast(expert_offsets.data_ptr()), + static_cast(blockscale_offsets.value().data_ptr()), + static_cast(atomic_buffer.data_ptr()), num_experts); + } else { + compute_expert_offsets<<<1, 1, 0, stream>>>( + static_cast(problem_sizes1.data_ptr()), + static_cast(expert_offsets.data_ptr()), + static_cast(atomic_buffer.data_ptr()), num_experts); + } compute_arg_sorts<<>>( static_cast(topk_ids.data_ptr()), static_cast(expert_offsets.data_ptr()), diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu index e9b408fbf2ee..ee93440b5754 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu +++ 
diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
index e9b408fbf2ee..ee93440b5754 100644
--- a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
+++ b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
@@ -54,7 +54,8 @@ void get_cutlass_moe_mm_data_caller(
     const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
     torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
     torch::Tensor& input_permutation, torch::Tensor& output_permutation,
-    const int64_t num_experts, const int64_t n, const int64_t k);
+    const int64_t num_experts, const int64_t n, const int64_t k,
+    const std::optional<torch::Tensor>& blockscale_offsets);
 #endif
 
 void cutlass_scaled_mm_azp_sm75(torch::Tensor& c, torch::Tensor const& a,
@@ -224,7 +225,8 @@ void get_cutlass_moe_mm_data(
     const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
     torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
     torch::Tensor& input_permutation, torch::Tensor& output_permutation,
-    const int64_t num_experts, const int64_t n, const int64_t k) {
+    const int64_t num_experts, const int64_t n, const int64_t k,
+    const std::optional<torch::Tensor>& blockscale_offsets) {
   // This function currently gets compiled only if we have a valid cutlass moe
   // mm to run it for.
   int32_t version_num = get_sm_version_num();
@@ -232,7 +234,8 @@ void get_cutlass_moe_mm_data(
     (defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM90)
   get_cutlass_moe_mm_data_caller(topk_ids, expert_offsets, problem_sizes1,
                                  problem_sizes2, input_permutation,
-                                 output_permutation, num_experts, n, k);
+                                 output_permutation, num_experts, n, k,
+                                 blockscale_offsets);
   return;
 #endif
   TORCH_CHECK_NOT_IMPLEMENTED(
diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp
index 3fffaf290ad3..93916b7f94be 100644
--- a/csrc/torch_bindings.cpp
+++ b/csrc/torch_bindings.cpp
@@ -450,7 +450,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
      "                           Tensor! problem_sizes1, Tensor! problem_sizes2, "
      "                           Tensor! input_permutation, "
      "                           Tensor! output_permutation, int num_experts, "
-     "                           int n, int k) -> ()",
+     "                           int n, int k, Tensor? blockscale_offsets) -> ()",
      {stride_tag});
 
   ops.impl("get_cutlass_moe_mm_data", torch::kCUDA, &get_cutlass_moe_mm_data);
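The trailing "Tensor? blockscale_offsets" in the schema surfaces in Python as the optional parameter added below in vllm/_custom_ops.py: fp8/int8 callers pass None and keep the old behavior, while the fp4 path passes a preallocated (num_experts + 1) int32 tensor that the op fills on device. A sketch of the buffer the fp4 caller is expected to provide, assuming a CUDA device:

    import torch

    num_experts = 8
    # Filled by get_cutlass_moe_mm_data when passed; left untouched when the
    # caller passes None instead.
    blockscale_offsets = torch.empty(num_experts + 1,
                                     dtype=torch.int32,
                                     device="cuda")
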
diff --git a/tests/kernels/moe/test_nvfp4_moe.py b/tests/kernels/moe/test_nvfp4_moe.py
index be33200cc206..22482d9ca85a 100644
--- a/tests/kernels/moe/test_nvfp4_moe.py
+++ b/tests/kernels/moe/test_nvfp4_moe.py
@@ -80,7 +80,10 @@ def test_cutlass_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int,
                     w2[expert], w2_gs[expert])
 
     score = torch.randn((m, e), device="cuda", dtype=dtype)
-    topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False)
+    topk_weights, topk_ids, _ = fused_topk(a,
+                                           score,
+                                           topk,
+                                           renormalize=False)
 
     a1_gs = torch.ones((e, ), device="cuda", dtype=torch.float32)
     a2_gs = torch.ones((e, ), device="cuda", dtype=torch.float32)
diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index 3282edf410b6..14404cd735ba 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -845,11 +845,16 @@ def cutlass_scaled_sparse_mm(
     return out
 
 
-def get_cutlass_moe_mm_data(
-        topk_ids: torch.Tensor, expert_offsets: torch.Tensor,
-        problem_sizes1: torch.Tensor, problem_sizes2: torch.Tensor,
-        input_permutation: torch.Tensor, output_permutation: torch.Tensor,
-        num_experts: int, n: int, k: int):
+def get_cutlass_moe_mm_data(topk_ids: torch.Tensor,
+                            expert_offsets: torch.Tensor,
+                            problem_sizes1: torch.Tensor,
+                            problem_sizes2: torch.Tensor,
+                            input_permutation: torch.Tensor,
+                            output_permutation: torch.Tensor,
+                            num_experts: int,
+                            n: int,
+                            k: int,
+                            blockscale_offsets: Optional[torch.Tensor] = None):
     """
     Prepare data necessary to perform CUTLASS grouped matrix multiplications
     used in CUTLASS-based fused MoE.
@@ -867,12 +872,31 @@ def get_cutlass_moe_mm_data(
                          before executing the MMs.
     - output_permutation: Permutation that must be used to shuffle the output
                           after executing the MMs.
+    - blockscale_offsets: Optional argument passed for FP4 MoE. Indices that
+                          mark at which block scale index each expert begins
+                          its computation. The number of block scale rows
+                          computed for expert E is blockscale_offsets[E + 1] -
+                          blockscale_offsets[E].
     """
     return torch.ops._C.get_cutlass_moe_mm_data(topk_ids, expert_offsets,
                                                 problem_sizes1, problem_sizes2,
                                                 input_permutation,
                                                 output_permutation,
-                                                num_experts, n, k)
+                                                num_experts, n, k,
+                                                blockscale_offsets)
+
+
+def shuffle_rows(input_tensor: torch.Tensor, dst2src_map: torch.Tensor):
+    """
+    Shuffle and expand the rows of input_tensor according to dst2src_map,
+    returning a new tensor; used in MoE to permute rows before grouped GEMMs.
+    """
+    num_tokens_permuted = dst2src_map.shape[0]
+    output_tensor = torch.empty((num_tokens_permuted, input_tensor.shape[1]),
+                                device=input_tensor.device,
+                                dtype=input_tensor.dtype)
+    torch.ops._moe_C.shuffle_rows(input_tensor, dst2src_map, output_tensor)
+    return output_tensor
 
 
 def cutlass_moe_mm(out_tensors: torch.Tensor, a_tensors: torch.Tensor,
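The wrapper above allocates the destination rows and calls the CUDA op registered earlier in csrc/moe/torch_bindings.cpp. Assuming a CUDA build of vLLM that includes this patch, its result matches plain fancy indexing:

    import torch
    from vllm import _custom_ops as ops

    a = torch.randn(16, 256, dtype=torch.half, device="cuda")
    a_map = torch.randint(0, 16, (32, ), dtype=torch.int32, device="cuda")
    permuted = ops.shuffle_rows(a, a_map)
    assert permuted.shape == (32, 256)
    assert torch.equal(permuted, a[a_map.long()])
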
@@ -1124,14 +1148,12 @@ def scaled_fp4_experts_quant(
     expert_offsets: torch.Tensor,
     blockscale_offsets: torch.Tensor,
     topk: int,
-    expert_map: Optional[torch.Tensor] = None,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     """
     Quantize input tensor to FP4 and return quantized tensor and scale, for
     packed MoE inputs.
     Args:
-    input: The input tensor to be quantized to FP4
-    expert_map: The expert map tensor
+    input_tensor: The input tensor to be quantized to FP4
     input_global_scale: A scalar scaling factor for the entire tensor.
     expert_offsets: The expert offsets tensor
     blockscale_offsets: The blockscale offsets tensor
     """
     assert input_tensor.ndim == 2, (
         f'input.ndim needs to be == 2, but got {input_tensor.ndim}.')
 
-    input_tensor = input_tensor[
-        expert_map] if expert_map is not None else input_tensor
-    m_numtopk, k = input_tensor.shape
     # Control the maximum number of tokens per expert supported by the
     # NVFP4 MoE Expert Quantization. This is used to prevent the kernel
     # from running out of memory. This value can also be increased to support
     # larger models.
     MAX_TOKENS_PER_EXPERT = envs.VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE
+    m_numtopk, k = input_tensor.shape
+
     assert (m_numtopk <= MAX_TOKENS_PER_EXPERT * topk), (
         f"m_numtopk must be less than MAX_TOKENS_PER_EXPERT("
         f"{MAX_TOKENS_PER_EXPERT})"
diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py
index d827869d0538..e9446bc5fd2e 100644
--- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py
+++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py
@@ -333,6 +333,7 @@ def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor,
     num_topk = topk_ids.shape[1]
 
     expert_offsets = torch.empty((e + 1), dtype=torch.int32, device=device)
+    blockscale_offsets = torch.empty((e + 1), dtype=torch.int32, device=device)
     # Problem size: (num_experts, (m,2n,k))
     problem_sizes1 = torch.empty((e, 3), dtype=torch.int32, device=device)
     # Problem size: (num_experts, (m,n,k))
@@ -344,12 +345,10 @@ def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor,
     # problem shapes should have [m, n, k]
     # Note that problem sizes are based on logical number of elements.
     ops.get_cutlass_moe_mm_data(topk_ids, expert_offsets, problem_sizes1,
-                                problem_sizes2, a_map, c_map, e, n, k)
+                                problem_sizes2, a_map, c_map, e, n, k,
+                                blockscale_offsets)
 
-    tokens_per_expert = problem_sizes1[:, 0]
-    rounded_tokens_per_expert = (tokens_per_expert + (128 - 1)) // 128 * 128
-    blockscale_offsets = torch.zeros(e + 1, dtype=torch.int32, device=device)
-    blockscale_offsets[1:] = torch.cumsum(rounded_tokens_per_expert, dim=0)
+    a = ops.shuffle_rows(a, a_map)
 
     rep_a_fp4, rep_a_blockscale = ops.scaled_fp4_experts_quant(
         a,
@@ -357,7 +356,7 @@ def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor,
         expert_offsets,
         blockscale_offsets,
         num_topk,
-        expert_map=a_map)
+    )
 
     c1 = ops.cutlass_fp4_moe_mm(rep_a_fp4, w1_fp4, rep_a_blockscale,
                                 w1_blockscale, w1_alphas, problem_sizes1,
@@ -378,6 +377,8 @@ def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor,
                                 w2_alphas, problem_sizes2, expert_offsets[:-1],
                                 blockscale_offsets[:-1], out_dtype, device)
     del int_fp4, int_blockscale
-    out = (c2[c_map].view(m, num_topk, k) *
+
+    c2 = ops.shuffle_rows(c2, c_map)
+    out = (c2.view(m, num_topk, k) *
            topk_weights.view(m, num_topk, 1).half()).sum(dim=1)
     return out.to(dtype=out_dtype)
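Net effect of the cutlass_moe.py changes: the output un-permutation switches from fancy indexing (c2[c_map]) to the dedicated gather kernel, the activation permutation happens once up front via shuffle_rows instead of inside the quant op via expert_map, and the padded block-scale cumsum moves off the host onto the device, removing a CPU-GPU round trip. A small equivalence check for the final reduction, assuming a CUDA build of vLLM with this patch (shapes are arbitrary):

    import torch
    from vllm import _custom_ops as ops

    m, num_topk, k = 4, 2, 128
    c2 = torch.randn(m * num_topk, k, dtype=torch.half, device="cuda")
    c_map = torch.randperm(m * num_topk, device="cuda").to(torch.int32)
    topk_weights = torch.rand(m, num_topk, device="cuda")

    # New path: dedicated gather kernel, then the weighted top-k reduction.
    out_new = (ops.shuffle_rows(c2, c_map).view(m, num_topk, k) *
               topk_weights.view(m, num_topk, 1).half()).sum(dim=1)
    # Old path: fancy indexing; the two should match bit for bit since the
    # gather copies rows verbatim and the arithmetic is identical.
    out_old = (c2[c_map.long()].view(m, num_topk, k) *
               topk_weights.view(m, num_topk, 1).half()).sum(dim=1)
    assert torch.equal(out_new, out_old)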