#include <cudaTypedefs.h>

#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>

#include "cutlass/cutlass.h"
#include "grouped_mm_c3x.cuh"

using namespace cute;

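// The structs in this anonymous namespace are shape-specific kernel
// configurations for the SM90 FP8 grouped (MoE) GEMM: each one fixes a tile
// shape, cluster shape, and kernel/epilogue schedule, and the dispatch code
// below picks one based on the problem dimensions (M, N, K).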
namespace {

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_default {
  // M in (16, inf)
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using KernelSchedule =
      cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum;
  using EpilogueSchedule =
      cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
  using TileShape = cute::Shape<cute::_64, cute::_256, cute::_128>;
  using ClusterShape = cute::Shape<cute::_1, cute::_2, cute::_1>;
  using ArchTag = cutlass::arch::Sm90;

  using Cutlass3xGemm =
      cutlass_3x_group_gemm<InType, OutType, ArchTag, Epilogue, TileShape,
                            ClusterShape, KernelSchedule, EpilogueSchedule>;
};

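// The two small-M configs below pass `true` as the final template argument of
// cutlass_3x_group_gemm; per the dispatch comment in run_cutlass_moe_mm_sm90,
// this presumably selects the swap_ab variant used for M <= 64 to reduce
// padding.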
template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_M4 {
  // M in [1, 4]
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using KernelSchedule =
      cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum;
  using EpilogueSchedule =
      cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
  using TileShape = cute::Shape<cute::_128, cute::_16, cute::_128>;
  using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
  using ArchTag = cutlass::arch::Sm90;

  using Cutlass3xGemm =
      cutlass_3x_group_gemm<InType, OutType, ArchTag, Epilogue, TileShape,
                            ClusterShape, KernelSchedule, EpilogueSchedule,
                            true>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_M64 {
  // M in (4, 64]
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using KernelSchedule =
      cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum;
  using EpilogueSchedule =
      cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
  using TileShape = cute::Shape<cute::_128, cute::_16, cute::_256>;
  using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
  using ArchTag = cutlass::arch::Sm90;

  using Cutlass3xGemm =
      cutlass_3x_group_gemm<InType, OutType, ArchTag, Epilogue, TileShape,
                            ClusterShape, KernelSchedule, EpilogueSchedule,
                            true>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_K8192 {
  // K in [8192, inf)
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using KernelSchedule =
      cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum;
  using EpilogueSchedule =
      cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
  using TileShape = cute::Shape<cute::_128, cute::_128, cute::_128>;
  using ClusterShape = cute::Shape<cute::_1, cute::_8, cute::_1>;
  using ArchTag = cutlass::arch::Sm90;

  using Cutlass3xGemm =
      cutlass_3x_group_gemm<InType, OutType, ArchTag, Epilogue, TileShape,
                            ClusterShape, KernelSchedule, EpilogueSchedule>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_N8192 {
  // N in [8192, inf)
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using KernelSchedule =
      cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum;
  using EpilogueSchedule =
      cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
  using TileShape = cute::Shape<cute::_64, cute::_128, cute::_256>;
  using ClusterShape = cute::Shape<cute::_1, cute::_8, cute::_1>;
  using ArchTag = cutlass::arch::Sm90;

  using Cutlass3xGemm =
      cutlass_3x_group_gemm<InType, OutType, ArchTag, Epilogue, TileShape,
                            ClusterShape, KernelSchedule, EpilogueSchedule>;
};

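// Validates the fp8 inputs, then dispatches to the configuration that best
// matches the problem shape: small-M configs first, then large-N or large-K
// configs, and the default config otherwise.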
template <typename InType, typename OutType>
void run_cutlass_moe_mm_sm90(
    torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
    torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
    torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
    torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
    torch::Tensor const& b_strides, torch::Tensor const& c_strides,
    bool per_act_token, bool per_out_ch) {
  TORCH_CHECK(a_tensors.size(0) > 0, "No input A tensors provided.");
  TORCH_CHECK(b_tensors.size(0) > 0, "No input B tensors provided.");
  TORCH_CHECK(out_tensors.size(0) > 0, "No output tensors provided.");

  TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn,
              "A tensors must be of type float8_e4m3fn.");
  TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn,
              "B tensors must be of type float8_e4m3fn.");

  using Cutlass3xGemmN8192 = typename sm90_fp8_config_N8192<
      InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
  using Cutlass3xGemmK8192 = typename sm90_fp8_config_K8192<
      InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
  using Cutlass3xGemmM4 = typename sm90_fp8_config_M4<
      InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
  using Cutlass3xGemmM64 = typename sm90_fp8_config_M64<
      InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
  using Cutlass3xGemmDefault = typename sm90_fp8_config_default<
      InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;

  uint32_t const m = a_tensors.size(0);
  uint32_t const n = out_tensors.size(1);
  uint32_t const k = a_tensors.size(1);

  // Use swap_ab for M <= 64 by default to reduce padding
  if (m <= 4) {
    cutlass_group_gemm_caller<Cutlass3xGemmM4>(
        out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
        problem_sizes, a_strides, b_strides, c_strides, per_act_token,
        per_out_ch);
  } else if (m <= 64) {
    cutlass_group_gemm_caller<Cutlass3xGemmM64>(
        out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
        problem_sizes, a_strides, b_strides, c_strides, per_act_token,
        per_out_ch);
  } else if (n >= 8192) {
    cutlass_group_gemm_caller<Cutlass3xGemmN8192>(
        out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
        problem_sizes, a_strides, b_strides, c_strides, per_act_token,
        per_out_ch);
  } else if (k >= 8192) {
    cutlass_group_gemm_caller<Cutlass3xGemmK8192>(
        out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
        problem_sizes, a_strides, b_strides, c_strides, per_act_token,
        per_out_ch);
  } else {
    cutlass_group_gemm_caller<Cutlass3xGemmDefault>(
        out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
        problem_sizes, a_strides, b_strides, c_strides, per_act_token,
        per_out_ch);
  }
}

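// Selects the output element type (bf16 or fp16, depending on out_tensors'
// dtype) and instantiates the matching fp8 kernel.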
void dispatch_moe_mm_sm90(
    torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
    torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
    torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
    torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
    torch::Tensor const& b_strides, torch::Tensor const& c_strides,
    bool per_act_token, bool per_out_ch) {
  if (out_tensors.dtype() == torch::kBFloat16) {
    run_cutlass_moe_mm_sm90<cutlass::float_e4m3_t, cutlass::bfloat16_t>(
        out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
        problem_sizes, a_strides, b_strides, c_strides, per_act_token,
        per_out_ch);
  } else {
    run_cutlass_moe_mm_sm90<cutlass::float_e4m3_t, cutlass::half_t>(
        out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
        problem_sizes, a_strides, b_strides, c_strides, per_act_token,
        per_out_ch);
  }
}

}  // namespace

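// Public entry point; forwards to the dtype dispatcher in the anonymous
// namespace above.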
void cutlass_moe_mm_sm90(
    torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
    torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
    torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
    torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
    torch::Tensor const& b_strides, torch::Tensor const& c_strides,
    bool per_act_token, bool per_out_ch) {
  dispatch_moe_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
                       expert_offsets, problem_sizes, a_strides, b_strides,
                       c_strides, per_act_token, per_out_ch);
}