mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-12 06:07:11 +08:00
[Perf] Add swap_ab to SM90 FP8 non-block CUTLASS moe grouped gemm (#20911)
Signed-off-by: Shixian Cui <shixian@amazon.com> Co-authored-by: Shixian Cui <shixian@amazon.com>
This commit is contained in:
parent
c7d8724e78
commit
5780121c95
@ -29,19 +29,36 @@ struct sm90_fp8_config_default {
|
|||||||
|
|
||||||
template <typename InType, typename OutType,
|
template <typename InType, typename OutType,
|
||||||
template <typename, typename, typename> typename Epilogue>
|
template <typename, typename, typename> typename Epilogue>
|
||||||
struct sm90_fp8_config_M16 {
|
struct sm90_fp8_config_M4 {
|
||||||
// M in [1, 16]
|
// M in [1, 4]
|
||||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||||
using KernelSchedule =
|
using KernelSchedule =
|
||||||
cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum;
|
cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum;
|
||||||
using EpilogueSchedule =
|
using EpilogueSchedule =
|
||||||
cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
|
cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
|
||||||
using TileShape = cute::Shape<cute::_64, cute::_64, cute::_128>;
|
using TileShape = cute::Shape<cute::_128, cute::_16, cute::_128>;
|
||||||
using ClusterShape = cute::Shape<cute::_1, cute::_4, cute::_1>;
|
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
|
||||||
|
|
||||||
using Cutlass3xGemm =
|
using Cutlass3xGemm =
|
||||||
cutlass_3x_group_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
|
cutlass_3x_group_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||||
KernelSchedule, EpilogueSchedule>;
|
KernelSchedule, EpilogueSchedule, true>;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename InType, typename OutType,
|
||||||
|
template <typename, typename, typename> typename Epilogue>
|
||||||
|
struct sm90_fp8_config_M64 {
|
||||||
|
// M in (4, 64]
|
||||||
|
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||||
|
using KernelSchedule =
|
||||||
|
cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum;
|
||||||
|
using EpilogueSchedule =
|
||||||
|
cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
|
||||||
|
using TileShape = cute::Shape<cute::_128, cute::_16, cute::_256>;
|
||||||
|
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
|
||||||
|
|
||||||
|
using Cutlass3xGemm =
|
||||||
|
cutlass_3x_group_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||||
|
KernelSchedule, EpilogueSchedule, true>;
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename InType, typename OutType,
|
template <typename InType, typename OutType,
|
||||||
@ -102,7 +119,9 @@ void run_cutlass_moe_mm_sm90(
|
|||||||
InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
|
InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
|
||||||
using Cutlass3xGemmK8192 = typename sm90_fp8_config_K8192<
|
using Cutlass3xGemmK8192 = typename sm90_fp8_config_K8192<
|
||||||
InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
|
InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
|
||||||
using Cutlass3xGemmM16 = typename sm90_fp8_config_M16<
|
using Cutlass3xGemmM4 = typename sm90_fp8_config_M4<
|
||||||
|
InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
|
||||||
|
using Cutlass3xGemmM64 = typename sm90_fp8_config_M64<
|
||||||
InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
|
InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
|
||||||
using Cutlass3xGemmDefault = typename sm90_fp8_config_default<
|
using Cutlass3xGemmDefault = typename sm90_fp8_config_default<
|
||||||
InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
|
InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
|
||||||
@ -111,7 +130,18 @@ void run_cutlass_moe_mm_sm90(
|
|||||||
uint32_t const n = out_tensors.size(1);
|
uint32_t const n = out_tensors.size(1);
|
||||||
uint32_t const k = a_tensors.size(1);
|
uint32_t const k = a_tensors.size(1);
|
||||||
|
|
||||||
if (n >= 8192) {
|
// Use swap_ab for M <= 64 by default to reduce padding
|
||||||
|
if (m <= 4) {
|
||||||
|
cutlass_group_gemm_caller<Cutlass3xGemmM4>(
|
||||||
|
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
|
||||||
|
problem_sizes, a_strides, b_strides, c_strides, per_act_token,
|
||||||
|
per_out_ch);
|
||||||
|
} else if (m <= 64) {
|
||||||
|
cutlass_group_gemm_caller<Cutlass3xGemmM64>(
|
||||||
|
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
|
||||||
|
problem_sizes, a_strides, b_strides, c_strides, per_act_token,
|
||||||
|
per_out_ch);
|
||||||
|
} else if (n >= 8192) {
|
||||||
cutlass_group_gemm_caller<Cutlass3xGemmN8192>(
|
cutlass_group_gemm_caller<Cutlass3xGemmN8192>(
|
||||||
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
|
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
|
||||||
problem_sizes, a_strides, b_strides, c_strides, per_act_token,
|
problem_sizes, a_strides, b_strides, c_strides, per_act_token,
|
||||||
@ -121,11 +151,6 @@ void run_cutlass_moe_mm_sm90(
|
|||||||
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
|
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
|
||||||
problem_sizes, a_strides, b_strides, c_strides, per_act_token,
|
problem_sizes, a_strides, b_strides, c_strides, per_act_token,
|
||||||
per_out_ch);
|
per_out_ch);
|
||||||
} else if (m <= 16) {
|
|
||||||
cutlass_group_gemm_caller<Cutlass3xGemmM16>(
|
|
||||||
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
|
|
||||||
problem_sizes, a_strides, b_strides, c_strides, per_act_token,
|
|
||||||
per_out_ch);
|
|
||||||
} else {
|
} else {
|
||||||
cutlass_group_gemm_caller<Cutlass3xGemmDefault>(
|
cutlass_group_gemm_caller<Cutlass3xGemmDefault>(
|
||||||
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
|
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
|
||||||
|
|||||||
@ -22,14 +22,23 @@ using ArchTag = cutlass::arch::Sm90;
|
|||||||
using OperatorClass = cutlass::arch::OpClassTensorOp;
|
using OperatorClass = cutlass::arch::OpClassTensorOp;
|
||||||
|
|
||||||
using LayoutA = cutlass::layout::RowMajor;
|
using LayoutA = cutlass::layout::RowMajor;
|
||||||
|
using LayoutA_Transpose =
|
||||||
|
typename cutlass::layout::LayoutTranspose<LayoutA>::type;
|
||||||
using LayoutB = cutlass::layout::ColumnMajor;
|
using LayoutB = cutlass::layout::ColumnMajor;
|
||||||
using LayoutC = cutlass::layout::RowMajor;
|
using LayoutB_Transpose =
|
||||||
|
typename cutlass::layout::LayoutTranspose<LayoutB>::type;
|
||||||
|
using LayoutD = cutlass::layout::RowMajor;
|
||||||
|
using LayoutD_Transpose =
|
||||||
|
typename cutlass::layout::LayoutTranspose<LayoutD>::type;
|
||||||
|
using LayoutC = LayoutD;
|
||||||
|
using LayoutC_Transpose = LayoutD_Transpose;
|
||||||
|
|
||||||
template <typename ElementAB_, typename ElementC_,
|
template <typename ElementAB_, typename ElementC_,
|
||||||
template <typename, typename, typename> typename Epilogue_,
|
template <typename, typename, typename> typename Epilogue_,
|
||||||
typename TileShape, typename ClusterShape, typename KernelSchedule,
|
typename TileShape, typename ClusterShape, typename KernelSchedule,
|
||||||
typename EpilogueSchedule>
|
typename EpilogueSchedule, bool swap_ab_ = false>
|
||||||
struct cutlass_3x_group_gemm {
|
struct cutlass_3x_group_gemm {
|
||||||
|
static constexpr bool swap_ab = swap_ab_;
|
||||||
using ElementAB = ElementAB_;
|
using ElementAB = ElementAB_;
|
||||||
using ElementC = void;
|
using ElementC = void;
|
||||||
using ElementD = ElementC_;
|
using ElementD = ElementC_;
|
||||||
@ -37,9 +46,6 @@ struct cutlass_3x_group_gemm {
|
|||||||
|
|
||||||
using Epilogue = Epilogue_<ElementAccumulator, ElementD, TileShape>;
|
using Epilogue = Epilogue_<ElementAccumulator, ElementD, TileShape>;
|
||||||
|
|
||||||
using StrideC =
|
|
||||||
cute::remove_pointer_t<cute::Stride<int64_t, cute::Int<1>, cute::Int<0>>>;
|
|
||||||
|
|
||||||
static constexpr int AlignmentAB =
|
static constexpr int AlignmentAB =
|
||||||
128 / cutlass::sizeof_bits<ElementAB>::value;
|
128 / cutlass::sizeof_bits<ElementAB>::value;
|
||||||
static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementD>::value;
|
static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||||
@ -50,19 +56,26 @@ struct cutlass_3x_group_gemm {
|
|||||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||||
ArchTag, OperatorClass, TileShape, ClusterShape,
|
ArchTag, OperatorClass, TileShape, ClusterShape,
|
||||||
cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator,
|
cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator,
|
||||||
ElementAccumulator, ElementC, LayoutC*, AlignmentC, ElementD,
|
ElementAccumulator, ElementC,
|
||||||
LayoutC*, AlignmentC, EpilogueSchedule, EVTCompute>::CollectiveOp;
|
conditional_t<swap_ab, LayoutC_Transpose*, LayoutC*>, AlignmentC,
|
||||||
|
ElementD, conditional_t<swap_ab, LayoutD_Transpose*, LayoutD*>,
|
||||||
|
AlignmentC, EpilogueSchedule, EVTCompute>::CollectiveOp;
|
||||||
|
|
||||||
static constexpr size_t CEStorageSize =
|
static constexpr size_t CEStorageSize =
|
||||||
sizeof(typename CollectiveEpilogue::SharedStorage);
|
sizeof(typename CollectiveEpilogue::SharedStorage);
|
||||||
using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout<
|
using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout<
|
||||||
static_cast<int>(CEStorageSize)>;
|
static_cast<int>(CEStorageSize)>;
|
||||||
|
|
||||||
using CollectiveMainloop =
|
using CollectiveMainloop = conditional_t<
|
||||||
|
swap_ab,
|
||||||
|
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||||
|
ArchTag, OperatorClass, ElementAB, LayoutB_Transpose*, AlignmentAB,
|
||||||
|
ElementAB, LayoutA_Transpose*, AlignmentAB, ElementAccumulator,
|
||||||
|
TileShape, ClusterShape, Stages, KernelSchedule>::CollectiveOp,
|
||||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||||
ArchTag, OperatorClass, ElementAB, LayoutA*, AlignmentAB, ElementAB,
|
ArchTag, OperatorClass, ElementAB, LayoutA*, AlignmentAB, ElementAB,
|
||||||
LayoutB*, AlignmentAB, ElementAccumulator, TileShape, ClusterShape,
|
LayoutB*, AlignmentAB, ElementAccumulator, TileShape, ClusterShape,
|
||||||
Stages, KernelSchedule>::CollectiveOp;
|
Stages, KernelSchedule>::CollectiveOp>;
|
||||||
|
|
||||||
using KernelType = enable_sm90_only<cutlass::gemm::kernel::GemmUniversal<
|
using KernelType = enable_sm90_only<cutlass::gemm::kernel::GemmUniversal<
|
||||||
ProblemShape, CollectiveMainloop, CollectiveEpilogue>>;
|
ProblemShape, CollectiveMainloop, CollectiveEpilogue>>;
|
||||||
@ -78,12 +91,12 @@ void cutlass_group_gemm_caller(
|
|||||||
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
|
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
|
||||||
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
|
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
|
||||||
bool per_act_token, bool per_out_ch) {
|
bool per_act_token, bool per_out_ch) {
|
||||||
|
static constexpr bool swap_ab = Gemm::swap_ab;
|
||||||
|
|
||||||
using ElementAB = typename Gemm::ElementAB;
|
using ElementAB = typename Gemm::ElementAB;
|
||||||
using ElementD = typename Gemm::ElementD;
|
using ElementD = typename Gemm::ElementD;
|
||||||
|
|
||||||
int num_experts = static_cast<int>(expert_offsets.size(0));
|
int num_experts = static_cast<int>(expert_offsets.size(0));
|
||||||
int k_size = a_tensors.size(1);
|
|
||||||
int n_size = out_tensors.size(1);
|
|
||||||
|
|
||||||
auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index());
|
auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index());
|
||||||
|
|
||||||
@ -110,19 +123,35 @@ void cutlass_group_gemm_caller(
|
|||||||
problem_sizes.data_ptr());
|
problem_sizes.data_ptr());
|
||||||
ProblemShape prob_shape{num_experts, problem_sizes_as_shapes, nullptr};
|
ProblemShape prob_shape{num_experts, problem_sizes_as_shapes, nullptr};
|
||||||
|
|
||||||
typename GemmKernel::MainloopArguments mainloop_args{
|
typename GemmKernel::MainloopArguments mainloop_args;
|
||||||
static_cast<const ElementAB**>(a_ptrs.data_ptr()),
|
if constexpr (swap_ab) {
|
||||||
static_cast<StrideA*>(a_strides.data_ptr()),
|
mainloop_args = typename GemmKernel::MainloopArguments{
|
||||||
static_cast<const ElementAB**>(b_ptrs.data_ptr()),
|
static_cast<const ElementAB**>(b_ptrs.data_ptr()),
|
||||||
static_cast<StrideB*>(b_strides.data_ptr())};
|
static_cast<StrideB*>(b_strides.data_ptr()),
|
||||||
|
static_cast<const ElementAB**>(a_ptrs.data_ptr()),
|
||||||
|
static_cast<StrideA*>(a_strides.data_ptr())};
|
||||||
|
} else {
|
||||||
|
mainloop_args = typename GemmKernel::MainloopArguments{
|
||||||
|
static_cast<const ElementAB**>(a_ptrs.data_ptr()),
|
||||||
|
static_cast<StrideA*>(a_strides.data_ptr()),
|
||||||
|
static_cast<const ElementAB**>(b_ptrs.data_ptr()),
|
||||||
|
static_cast<StrideB*>(b_strides.data_ptr())};
|
||||||
|
}
|
||||||
|
|
||||||
// Currently, we are only able to do broadcast on either all or none a_scales
|
// Currently, we are only able to do broadcast on either all or none a_scales
|
||||||
// and on either all or none b_scales
|
// and on either all or none b_scales
|
||||||
typename GemmKernel::EpilogueArguments epilogue_args{
|
typename GemmKernel::EpilogueArguments epilogue_args{
|
||||||
Gemm::Epilogue::prepare_args(
|
Gemm::Epilogue::prepare_args(
|
||||||
static_cast<const ElementAccumulator**>(a_scales_ptrs.data_ptr()),
|
swap_ab ? static_cast<const ElementAccumulator**>(
|
||||||
static_cast<const ElementAccumulator**>(b_scales_ptrs.data_ptr()),
|
b_scales_ptrs.data_ptr())
|
||||||
per_act_token, per_out_ch),
|
: static_cast<const ElementAccumulator**>(
|
||||||
|
a_scales_ptrs.data_ptr()),
|
||||||
|
swap_ab ? static_cast<const ElementAccumulator**>(
|
||||||
|
a_scales_ptrs.data_ptr())
|
||||||
|
: static_cast<const ElementAccumulator**>(
|
||||||
|
b_scales_ptrs.data_ptr()),
|
||||||
|
swap_ab ? per_out_ch : per_act_token,
|
||||||
|
swap_ab ? per_act_token : per_out_ch),
|
||||||
nullptr, static_cast<StrideC*>(c_strides.data_ptr()),
|
nullptr, static_cast<StrideC*>(c_strides.data_ptr()),
|
||||||
static_cast<ElementD**>(out_ptrs.data_ptr()),
|
static_cast<ElementD**>(out_ptrs.data_ptr()),
|
||||||
static_cast<StrideC*>(c_strides.data_ptr())};
|
static_cast<StrideC*>(c_strides.data_ptr())};
|
||||||
|
|||||||
@ -6,7 +6,10 @@
|
|||||||
#include <iostream>
|
#include <iostream>
|
||||||
|
|
||||||
constexpr uint64_t THREADS_PER_EXPERT = 512;
|
constexpr uint64_t THREADS_PER_EXPERT = 512;
|
||||||
|
// threshold must match the dispatch logic in run_cutlass_moe_mm_sm90()
|
||||||
|
constexpr int SWAP_AB_THRESHOLD = 64;
|
||||||
|
|
||||||
|
template <bool SWAP_AB>
|
||||||
__global__ void compute_problem_sizes(const int32_t* __restrict__ topk_ids,
|
__global__ void compute_problem_sizes(const int32_t* __restrict__ topk_ids,
|
||||||
int32_t* problem_sizes1,
|
int32_t* problem_sizes1,
|
||||||
int32_t* problem_sizes2,
|
int32_t* problem_sizes2,
|
||||||
@ -24,40 +27,53 @@ __global__ void compute_problem_sizes(const int32_t* __restrict__ topk_ids,
|
|||||||
|
|
||||||
if (threadIdx.x == 0) {
|
if (threadIdx.x == 0) {
|
||||||
int final_occurrences = atomic_buffer[expert_id];
|
int final_occurrences = atomic_buffer[expert_id];
|
||||||
problem_sizes1[expert_id * 3] = final_occurrences;
|
if constexpr (!SWAP_AB) {
|
||||||
problem_sizes1[expert_id * 3 + 1] = 2 * n;
|
problem_sizes1[expert_id * 3] = final_occurrences;
|
||||||
problem_sizes1[expert_id * 3 + 2] = k;
|
problem_sizes1[expert_id * 3 + 1] = 2 * n;
|
||||||
problem_sizes2[expert_id * 3] = final_occurrences;
|
problem_sizes1[expert_id * 3 + 2] = k;
|
||||||
problem_sizes2[expert_id * 3 + 1] = k;
|
problem_sizes2[expert_id * 3] = final_occurrences;
|
||||||
problem_sizes2[expert_id * 3 + 2] = n;
|
problem_sizes2[expert_id * 3 + 1] = k;
|
||||||
|
problem_sizes2[expert_id * 3 + 2] = n;
|
||||||
|
} else {
|
||||||
|
problem_sizes1[expert_id * 3] = 2 * n;
|
||||||
|
problem_sizes1[expert_id * 3 + 1] = final_occurrences;
|
||||||
|
problem_sizes1[expert_id * 3 + 2] = k;
|
||||||
|
problem_sizes2[expert_id * 3] = k;
|
||||||
|
problem_sizes2[expert_id * 3 + 1] = final_occurrences;
|
||||||
|
problem_sizes2[expert_id * 3 + 2] = n;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
__global__ void compute_expert_offsets(
|
__global__ void compute_expert_offsets(
|
||||||
const int32_t* __restrict__ problem_sizes1, int32_t* expert_offsets,
|
const int32_t* __restrict__ problem_sizes1, int32_t* expert_offsets,
|
||||||
int32_t* atomic_buffer, const int num_experts) {
|
int32_t* atomic_buffer, const int num_experts, const int topk_length) {
|
||||||
int32_t tot_offset = 0;
|
int32_t tot_offset = 0;
|
||||||
expert_offsets[0] = 0;
|
expert_offsets[0] = 0;
|
||||||
for (int i = 0; i < num_experts; ++i) {
|
for (int i = 0; i < num_experts; ++i) {
|
||||||
atomic_buffer[i] = tot_offset;
|
atomic_buffer[i] = tot_offset;
|
||||||
tot_offset += problem_sizes1[i * 3];
|
tot_offset += topk_length > SWAP_AB_THRESHOLD ? problem_sizes1[i * 3]
|
||||||
|
: problem_sizes1[i * 3 + 1];
|
||||||
expert_offsets[i + 1] = tot_offset;
|
expert_offsets[i + 1] = tot_offset;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
__global__ void compute_expert_blockscale_offsets(
|
__global__ void compute_expert_blockscale_offsets(
|
||||||
const int32_t* __restrict__ problem_sizes1, int32_t* expert_offsets,
|
const int32_t* __restrict__ problem_sizes1, int32_t* expert_offsets,
|
||||||
int32_t* blockscale_offsets, int32_t* atomic_buffer,
|
int32_t* blockscale_offsets, int32_t* atomic_buffer, const int num_experts,
|
||||||
const int num_experts) {
|
const int topk_length) {
|
||||||
int32_t tot_offset = 0;
|
int32_t tot_offset = 0;
|
||||||
int32_t tot_offset_round = 0;
|
int32_t tot_offset_round = 0;
|
||||||
expert_offsets[0] = 0;
|
expert_offsets[0] = 0;
|
||||||
blockscale_offsets[0] = 0;
|
blockscale_offsets[0] = 0;
|
||||||
for (int i = 0; i < num_experts; ++i) {
|
for (int i = 0; i < num_experts; ++i) {
|
||||||
|
int32_t cur_offset = topk_length > SWAP_AB_THRESHOLD
|
||||||
|
? problem_sizes1[i * 3]
|
||||||
|
: problem_sizes1[i * 3 + 1];
|
||||||
atomic_buffer[i] = tot_offset;
|
atomic_buffer[i] = tot_offset;
|
||||||
tot_offset += problem_sizes1[i * 3];
|
tot_offset += cur_offset;
|
||||||
expert_offsets[i + 1] = tot_offset;
|
expert_offsets[i + 1] = tot_offset;
|
||||||
tot_offset_round += (problem_sizes1[i * 3] + (128 - 1)) / 128 * 128;
|
tot_offset_round += (cur_offset + (128 - 1)) / 128 * 128;
|
||||||
blockscale_offsets[i + 1] = tot_offset_round;
|
blockscale_offsets[i + 1] = tot_offset_round;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -102,22 +118,36 @@ void get_cutlass_moe_mm_data_caller(
|
|||||||
torch::Tensor atomic_buffer = torch::zeros(num_experts, options_int32);
|
torch::Tensor atomic_buffer = torch::zeros(num_experts, options_int32);
|
||||||
|
|
||||||
int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel());
|
int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel());
|
||||||
compute_problem_sizes<<<num_experts, num_threads, 0, stream>>>(
|
|
||||||
static_cast<const int32_t*>(topk_ids.data_ptr()),
|
if (topk_ids.numel() > SWAP_AB_THRESHOLD) {
|
||||||
static_cast<int32_t*>(problem_sizes1.data_ptr()),
|
compute_problem_sizes<false><<<num_experts, num_threads, 0, stream>>>(
|
||||||
static_cast<int32_t*>(problem_sizes2.data_ptr()),
|
static_cast<const int32_t*>(topk_ids.data_ptr()),
|
||||||
static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(), n, k);
|
static_cast<int32_t*>(problem_sizes1.data_ptr()),
|
||||||
|
static_cast<int32_t*>(problem_sizes2.data_ptr()),
|
||||||
|
static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(), n,
|
||||||
|
k);
|
||||||
|
} else {
|
||||||
|
compute_problem_sizes<true><<<num_experts, num_threads, 0, stream>>>(
|
||||||
|
static_cast<const int32_t*>(topk_ids.data_ptr()),
|
||||||
|
static_cast<int32_t*>(problem_sizes1.data_ptr()),
|
||||||
|
static_cast<int32_t*>(problem_sizes2.data_ptr()),
|
||||||
|
static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(), n,
|
||||||
|
k);
|
||||||
|
}
|
||||||
|
|
||||||
if (blockscale_offsets.has_value()) {
|
if (blockscale_offsets.has_value()) {
|
||||||
compute_expert_blockscale_offsets<<<1, 1, 0, stream>>>(
|
compute_expert_blockscale_offsets<<<1, 1, 0, stream>>>(
|
||||||
static_cast<const int32_t*>(problem_sizes1.data_ptr()),
|
static_cast<const int32_t*>(problem_sizes1.data_ptr()),
|
||||||
static_cast<int32_t*>(expert_offsets.data_ptr()),
|
static_cast<int32_t*>(expert_offsets.data_ptr()),
|
||||||
static_cast<int32_t*>(blockscale_offsets.value().data_ptr()),
|
static_cast<int32_t*>(blockscale_offsets.value().data_ptr()),
|
||||||
static_cast<int32_t*>(atomic_buffer.data_ptr()), num_experts);
|
static_cast<int32_t*>(atomic_buffer.data_ptr()), num_experts,
|
||||||
|
topk_ids.numel());
|
||||||
} else {
|
} else {
|
||||||
compute_expert_offsets<<<1, 1, 0, stream>>>(
|
compute_expert_offsets<<<1, 1, 0, stream>>>(
|
||||||
static_cast<const int32_t*>(problem_sizes1.data_ptr()),
|
static_cast<const int32_t*>(problem_sizes1.data_ptr()),
|
||||||
static_cast<int32_t*>(expert_offsets.data_ptr()),
|
static_cast<int32_t*>(expert_offsets.data_ptr()),
|
||||||
static_cast<int32_t*>(atomic_buffer.data_ptr()), num_experts);
|
static_cast<int32_t*>(atomic_buffer.data_ptr()), num_experts,
|
||||||
|
topk_ids.numel());
|
||||||
}
|
}
|
||||||
compute_arg_sorts<<<num_experts, num_threads, 0, stream>>>(
|
compute_arg_sorts<<<num_experts, num_threads, 0, stream>>>(
|
||||||
static_cast<const int32_t*>(topk_ids.data_ptr()),
|
static_cast<const int32_t*>(topk_ids.data_ptr()),
|
||||||
|
|||||||
@ -25,6 +25,7 @@ MNK_FACTORS = [
|
|||||||
(2, 1024, 1536),
|
(2, 1024, 1536),
|
||||||
(2, 3072, 1024),
|
(2, 3072, 1024),
|
||||||
(2, 3072, 1536),
|
(2, 3072, 1536),
|
||||||
|
(7, 3072, 1536),
|
||||||
(64, 1024, 1024),
|
(64, 1024, 1024),
|
||||||
(64, 1024, 1536),
|
(64, 1024, 1536),
|
||||||
(64, 3072, 1024),
|
(64, 3072, 1024),
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user