From 7f89ed248fef01098c7ce4bebb197b462eb15bc3 Mon Sep 17 00:00:00 2001 From: shixianc <49539556+shixianc@users.noreply.github.com> Date: Fri, 15 Aug 2025 14:02:12 -0700 Subject: [PATCH] [Fix] enable swap_ab for pplx problem size computation (#22991) Signed-off-by: Shixian Cui Co-authored-by: Shixian Cui --- .../quantization/cutlass_w8a8/moe/moe_data.cu | 45 +++++++++++++------ 1 file changed, 32 insertions(+), 13 deletions(-) diff --git a/csrc/quantization/cutlass_w8a8/moe/moe_data.cu b/csrc/quantization/cutlass_w8a8/moe/moe_data.cu index 857cca1e82df7..100f485084444 100644 --- a/csrc/quantization/cutlass_w8a8/moe/moe_data.cu +++ b/csrc/quantization/cutlass_w8a8/moe/moe_data.cu @@ -161,6 +161,7 @@ void get_cutlass_moe_mm_data_caller( topk_ids.size(1)); } +template <bool SWAP_AB> __global__ void compute_pplx_data(int32_t* expert_offsets, int32_t* problem_sizes1, int32_t* problem_sizes2, @@ -168,14 +169,23 @@ __global__ void compute_pplx_data(int32_t* expert_offsets, const int padded_m, const int n, const int k) { int expert_idx = threadIdx.x; - expert_offsets[expert_idx] = expert_idx * padded_m; - problem_sizes1[expert_idx * 3] = expert_num_tokens[expert_idx]; - problem_sizes1[expert_idx * 3 + 1] = 2 * n; - problem_sizes1[expert_idx * 3 + 2] = k; - problem_sizes2[expert_idx * 3] = expert_num_tokens[expert_idx]; - problem_sizes2[expert_idx * 3 + 1] = k; - problem_sizes2[expert_idx * 3 + 2] = n; + + if constexpr (!SWAP_AB) { + problem_sizes1[expert_idx * 3] = expert_num_tokens[expert_idx]; + problem_sizes1[expert_idx * 3 + 1] = 2 * n; + problem_sizes1[expert_idx * 3 + 2] = k; + problem_sizes2[expert_idx * 3] = expert_num_tokens[expert_idx]; + problem_sizes2[expert_idx * 3 + 1] = k; + problem_sizes2[expert_idx * 3 + 2] = n; + } else { + problem_sizes1[expert_idx * 3] = 2 * n; + problem_sizes1[expert_idx * 3 + 1] = expert_num_tokens[expert_idx]; + problem_sizes1[expert_idx * 3 + 2] = k; + problem_sizes2[expert_idx * 3] = k; + problem_sizes2[expert_idx * 3 + 1] = 
expert_num_tokens[expert_idx]; + problem_sizes2[expert_idx * 3 + 2] = n; + } } void get_cutlass_pplx_moe_mm_data_caller(torch::Tensor& expert_offsets, @@ -187,10 +197,19 @@ void get_cutlass_pplx_moe_mm_data_caller(torch::Tensor& expert_offsets, const int64_t n, const int64_t k) { auto stream = at::cuda::getCurrentCUDAStream(expert_offsets.device().index()); - compute_pplx_data<<<1, num_local_experts, 0, stream>>>( - static_cast<int32_t*>(expert_offsets.data_ptr()), - static_cast<int32_t*>(problem_sizes1.data_ptr()), - static_cast<int32_t*>(problem_sizes2.data_ptr()), - static_cast<const int32_t*>(expert_num_tokens.data_ptr()), padded_m, n, - k); + if (num_local_experts * padded_m > SWAP_AB_THRESHOLD) { + compute_pplx_data<false><<<1, num_local_experts, 0, stream>>>( + static_cast<int32_t*>(expert_offsets.data_ptr()), + static_cast<int32_t*>(problem_sizes1.data_ptr()), + static_cast<int32_t*>(problem_sizes2.data_ptr()), + static_cast<const int32_t*>(expert_num_tokens.data_ptr()), padded_m, n, + k); + } else { + compute_pplx_data<true><<<1, num_local_experts, 0, stream>>>( + static_cast<int32_t*>(expert_offsets.data_ptr()), + static_cast<int32_t*>(problem_sizes1.data_ptr()), + static_cast<int32_t*>(problem_sizes2.data_ptr()), + static_cast<const int32_t*>(expert_num_tokens.data_ptr()), padded_m, n, + k); + } } \ No newline at end of file