diff --git a/csrc/quantization/cutlass_w8a8/moe/moe_data.cu b/csrc/quantization/cutlass_w8a8/moe/moe_data.cu index 993c30c48c84a..857cca1e82df7 100644 --- a/csrc/quantization/cutlass_w8a8/moe/moe_data.cu +++ b/csrc/quantization/cutlass_w8a8/moe/moe_data.cu @@ -47,13 +47,12 @@ __global__ void compute_problem_sizes(const int32_t* __restrict__ topk_ids, __global__ void compute_expert_offsets( const int32_t* __restrict__ problem_sizes1, int32_t* expert_offsets, - int32_t* atomic_buffer, const int num_experts, const int topk_length) { + int32_t* atomic_buffer, const int num_experts, const bool swap_ab) { int32_t tot_offset = 0; expert_offsets[0] = 0; for (int i = 0; i < num_experts; ++i) { atomic_buffer[i] = tot_offset; - tot_offset += topk_length > SWAP_AB_THRESHOLD ? problem_sizes1[i * 3] - : problem_sizes1[i * 3 + 1]; + tot_offset += swap_ab ? problem_sizes1[i * 3 + 1] : problem_sizes1[i * 3]; expert_offsets[i + 1] = tot_offset; } } @@ -61,15 +60,14 @@ __global__ void compute_expert_offsets( __global__ void compute_expert_blockscale_offsets( const int32_t* __restrict__ problem_sizes1, int32_t* expert_offsets, int32_t* blockscale_offsets, int32_t* atomic_buffer, const int num_experts, - const int topk_length) { + const bool swap_ab) { int32_t tot_offset = 0; int32_t tot_offset_round = 0; expert_offsets[0] = 0; blockscale_offsets[0] = 0; for (int i = 0; i < num_experts; ++i) { - int32_t cur_offset = topk_length > SWAP_AB_THRESHOLD - ? problem_sizes1[i * 3] - : problem_sizes1[i * 3 + 1]; + int32_t cur_offset = + swap_ab ? problem_sizes1[i * 3 + 1] : problem_sizes1[i * 3]; atomic_buffer[i] = tot_offset; tot_offset += cur_offset; expert_offsets[i + 1] = tot_offset; @@ -119,15 +117,19 @@ void get_cutlass_moe_mm_data_caller( int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel()); - if (topk_ids.numel() > SWAP_AB_THRESHOLD) { - compute_problem_sizes<<>>( + // Swap-AB should be disabled for FP4 path + bool may_swap_ab = (!blockscale_offsets.has_value()) && + (topk_ids.numel() <= SWAP_AB_THRESHOLD); + + if (may_swap_ab) { + compute_problem_sizes<<>>( static_cast(topk_ids.data_ptr()), static_cast(problem_sizes1.data_ptr()), static_cast(problem_sizes2.data_ptr()), static_cast(atomic_buffer.data_ptr()), topk_ids.numel(), n, k); } else { - compute_problem_sizes<<>>( + compute_problem_sizes<<>>( static_cast(topk_ids.data_ptr()), static_cast(problem_sizes1.data_ptr()), static_cast(problem_sizes2.data_ptr()), @@ -136,18 +138,19 @@ void get_cutlass_moe_mm_data_caller( } if (blockscale_offsets.has_value()) { + // fp4 path compute_expert_blockscale_offsets<<<1, 1, 0, stream>>>( static_cast(problem_sizes1.data_ptr()), static_cast(expert_offsets.data_ptr()), static_cast(blockscale_offsets.value().data_ptr()), static_cast(atomic_buffer.data_ptr()), num_experts, - topk_ids.numel()); + may_swap_ab); } else { compute_expert_offsets<<<1, 1, 0, stream>>>( static_cast(problem_sizes1.data_ptr()), static_cast(expert_offsets.data_ptr()), static_cast(atomic_buffer.data_ptr()), num_experts, - topk_ids.numel()); + may_swap_ab); } compute_arg_sorts<<>>( static_cast(topk_ids.data_ptr()),