[feat]: add SM100 support for cutlass FP8 groupGEMM (#20447)
Signed-off-by: Duncan Moss <djm.moss@gmail.com>
Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
Co-authored-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
parent 4fb56914c5
commit 2c8db17cfd
@@ -577,7 +577,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
   # if it's possible to compile MoE kernels that use its output.
   cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a" "${CUDA_ARCHS}")
   if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS)
-    set(SRCS "csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu")
+    set(SRCS "csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x_sm90.cu")
     set_gencode_flags_for_srcs(
       SRCS "${SRCS}"
       CUDA_ARCHS "${SCALED_MM_ARCHS}")
@@ -595,6 +595,26 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
     endif()
   endif()

+  cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a" "${CUDA_ARCHS}")
+  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
+    set(SRCS "csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x_sm100.cu")
+    set_gencode_flags_for_srcs(
+      SRCS "${SRCS}"
+      CUDA_ARCHS "${SCALED_MM_ARCHS}")
+    list(APPEND VLLM_EXT_SRC "${SRCS}")
+    list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM100=1")
+    message(STATUS "Building grouped_mm_c3x for archs: ${SCALED_MM_ARCHS}")
+  else()
+    if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
+      message(STATUS "Not building grouped_mm_c3x kernels as CUDA Compiler version is "
+                     "not >= 12.8, we recommend upgrading to CUDA 12.8 or later "
+                     "if you intend on running FP8 quantized MoE models on Blackwell.")
+    else()
+      message(STATUS "Not building grouped_mm_c3x as no compatible archs found "
+                     "in CUDA target architectures.")
+    endif()
+  endif()
+
   # moe_data.cu is used by all CUTLASS MoE kernels.
   cuda_archs_loose_intersection(CUTLASS_MOE_DATA_ARCHS "9.0a;10.0a" "${CUDA_ARCHS}")
   if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND CUTLASS_MOE_DATA_ARCHS)
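The -DENABLE_CUTLASS_MOE_SM100=1 define added above is what the C++ sources later in this diff use to guard the SM100 declarations and dispatch. A hypothetical sketch, not part of this commit, of how a translation unit could assert the same precondition the CMake gate enforces (CUDA 12.8 or newer whenever the flag is set); CUDART_VERSION is the CUDA runtime version macro, where 12080 corresponds to CUDA 12.8:

// Hypothetical sketch only: mirror the CMake gate at compile time.
#include <cuda_runtime_api.h>

#if defined(ENABLE_CUTLASS_MOE_SM100) && ENABLE_CUTLASS_MOE_SM100
  #if defined(CUDART_VERSION) && CUDART_VERSION < 12080
    #error "CUTLASS SM100 grouped GEMM requires CUDA 12.8 or newer"
  #endif
#endif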
@@ -18,7 +18,6 @@ using ProblemShape =
     cutlass::gemm::GroupProblemShape<cute::Shape<int, int, int>>;

 using ElementAccumulator = float;
-using ArchTag = cutlass::arch::Sm90;
 using OperatorClass = cutlass::arch::OpClassTensorOp;

 using LayoutA = cutlass::layout::RowMajor;
@@ -33,7 +32,7 @@ using LayoutD_Transpose =
 using LayoutC = LayoutD;
 using LayoutC_Transpose = LayoutD_Transpose;

-template <typename ElementAB_, typename ElementC_,
+template <typename ElementAB_, typename ElementC_, typename ArchTag_,
           template <typename, typename, typename> typename Epilogue_,
           typename TileShape, typename ClusterShape, typename KernelSchedule,
           typename EpilogueSchedule, bool swap_ab_ = false>
@@ -43,6 +42,7 @@ struct cutlass_3x_group_gemm {
   using ElementC = void;
   using ElementD = ElementC_;
   using ElementAccumulator = float;
+  using ArchTag = ArchTag_;

   using Epilogue = Epilogue_<ElementAccumulator, ElementD, TileShape>;

@@ -77,7 +77,7 @@ struct cutlass_3x_group_gemm {
           LayoutB*, AlignmentAB, ElementAccumulator, TileShape, ClusterShape,
           Stages, KernelSchedule>::CollectiveOp>;

-  using KernelType = enable_sm90_only<cutlass::gemm::kernel::GemmUniversal<
+  using KernelType = enable_sm90_or_later<cutlass::gemm::kernel::GemmUniversal<
       ProblemShape, CollectiveMainloop, CollectiveEpilogue>>;

   struct GemmKernel : public KernelType {};
@@ -156,9 +156,14 @@ void cutlass_group_gemm_caller(
       static_cast<ElementD**>(out_ptrs.data_ptr()),
       static_cast<StrideC*>(c_strides.data_ptr())};

+  int device_id = a_tensors.device().index();
+  static const cutlass::KernelHardwareInfo hw_info{
+      device_id, cutlass::KernelHardwareInfo::query_device_multiprocessor_count(
+                     device_id)};
+
   typename GemmKernel::Arguments args{
       cutlass::gemm::GemmUniversalMode::kGrouped, prob_shape, mainloop_args,
-      epilogue_args};
+      epilogue_args, hw_info};

   using GemmOp = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
   GemmOp gemm_op;
csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x_sm100.cu (new file, 140 lines)
@@ -0,0 +1,140 @@
#include <cudaTypedefs.h>

#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>

#include "cutlass/cutlass.h"
#include "grouped_mm_c3x.cuh"

using namespace cute;

namespace {

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm100_fp8_config_default {
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using KernelSchedule =
      cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100;
  using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm;
  using TileShape = cute::Shape<cute::_128, cute::_256, cute::_128>;
  using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
  using ArchTag = cutlass::arch::Sm100;

  using Cutlass3xGemm =
      cutlass_3x_group_gemm<InType, OutType, ArchTag, Epilogue, TileShape,
                            ClusterShape, KernelSchedule, EpilogueSchedule>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm100_fp8_config_M64 {
  // M in [1,64]
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using KernelSchedule =
      cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100;
  using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm;
  using TileShape = cute::Shape<cute::_128, cute::_16, cute::_128>;
  using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
  using ArchTag = cutlass::arch::Sm100;

  using Cutlass3xGemm =
      cutlass_3x_group_gemm<InType, OutType, ArchTag, Epilogue, TileShape,
                            ClusterShape, KernelSchedule, EpilogueSchedule,
                            true>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm100_fp8_config_N8192 {
  // N in [8192, inf)
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using KernelSchedule =
      cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmSm100;
  using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm;
  using TileShape = cute::Shape<cute::_128, cute::_256, cute::_128>;
  using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
  using ArchTag = cutlass::arch::Sm100;

  using Cutlass3xGemm =
      cutlass_3x_group_gemm<InType, OutType, ArchTag, Epilogue, TileShape,
                            ClusterShape, KernelSchedule, EpilogueSchedule>;
};

template <typename InType, typename OutType>
void run_cutlass_moe_mm_sm100(
    torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
    torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
    torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
    torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
    torch::Tensor const& b_strides, torch::Tensor const& c_strides,
    bool per_act_token, bool per_out_ch) {
  TORCH_CHECK(a_tensors.size(0) > 0, "No input A tensors provided.");
  TORCH_CHECK(b_tensors.size(0) > 0, "No input B tensors provided.");
  TORCH_CHECK(out_tensors.size(0) > 0, "No output tensors provided.");

  TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn,
              "A tensors must be of type float8_e4m3fn.");
  TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn,
              "B tensors must be of type float8_e4m3fn.");

  using Cutlass3xGemmDefault = typename sm100_fp8_config_default<
      InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
  using Cutlass3xGemmN8192 = typename sm100_fp8_config_N8192<
      InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
  using Cutlass3xGemmM64 = typename sm100_fp8_config_M64<
      InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;

  uint32_t const m = a_tensors.size(0);
  uint32_t const n = out_tensors.size(1);

  if (m <= 64) {
    cutlass_group_gemm_caller<Cutlass3xGemmM64>(
        out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
        problem_sizes, a_strides, b_strides, c_strides, per_act_token,
        per_out_ch);
  } else if (n >= 8192) {
    cutlass_group_gemm_caller<Cutlass3xGemmN8192>(
        out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
        problem_sizes, a_strides, b_strides, c_strides, per_act_token,
        per_out_ch);
  } else {
    cutlass_group_gemm_caller<Cutlass3xGemmDefault>(
        out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
        problem_sizes, a_strides, b_strides, c_strides, per_act_token,
        per_out_ch);
  }
}
}  // namespace

void dispatch_moe_mm_sm100(
    torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
    torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
    torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
    torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
    torch::Tensor const& b_strides, torch::Tensor const& c_strides,
    bool per_act_token, bool per_out_ch) {
  if (out_tensors.dtype() == torch::kBFloat16) {
    run_cutlass_moe_mm_sm100<cutlass::float_e4m3_t, cutlass::bfloat16_t>(
        out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
        problem_sizes, a_strides, b_strides, c_strides, per_act_token,
        per_out_ch);
  } else {
    run_cutlass_moe_mm_sm100<cutlass::float_e4m3_t, cutlass::half_t>(
        out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
        problem_sizes, a_strides, b_strides, c_strides, per_act_token,
        per_out_ch);
  }
}

void cutlass_moe_mm_sm100(
    torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
    torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
    torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
    torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
    torch::Tensor const& b_strides, torch::Tensor const& c_strides,
    bool per_act_token, bool per_out_ch) {
  dispatch_moe_mm_sm100(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
                        expert_offsets, problem_sizes, a_strides, b_strides,
                        c_strides, per_act_token, per_out_ch);
}
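The three SM100 kernel configurations in the new file above are selected purely from the aggregate problem shape (total row count m and output width n). A minimal, illustrative sketch of that heuristic, assuming hypothetical names (Sm100Config, select_sm100_config) that are not part of vLLM:

// Illustrative sketch only: the shape-based selection used by
// run_cutlass_moe_mm_sm100 above, written as a plain function.
enum class Sm100Config {
  M64,     // m <= 64: 128x16x128 tile, swap_ab enabled for the tall-skinny case
  N8192,   // n >= 8192: 2-SM kernel/epilogue schedule, 2x1x1 cluster
  Default  // otherwise: 1-SM schedule, 128x256x128 tile
};

inline Sm100Config select_sm100_config(int64_t m, int64_t n) {
  if (m <= 64) return Sm100Config::M64;
  if (n >= 8192) return Sm100Config::N8192;
  return Sm100Config::Default;
}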
@@ -21,10 +21,11 @@ struct sm90_fp8_config_default {
       cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
   using TileShape = cute::Shape<cute::_64, cute::_256, cute::_128>;
   using ClusterShape = cute::Shape<cute::_1, cute::_2, cute::_1>;
+  using ArchTag = cutlass::arch::Sm90;

   using Cutlass3xGemm =
-      cutlass_3x_group_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
-                            KernelSchedule, EpilogueSchedule>;
+      cutlass_3x_group_gemm<InType, OutType, ArchTag, Epilogue, TileShape,
+                            ClusterShape, KernelSchedule, EpilogueSchedule>;
 };

 template <typename InType, typename OutType,
@@ -38,10 +39,12 @@ struct sm90_fp8_config_M4 {
       cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
   using TileShape = cute::Shape<cute::_128, cute::_16, cute::_128>;
   using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
+  using ArchTag = cutlass::arch::Sm90;

   using Cutlass3xGemm =
-      cutlass_3x_group_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
-                            KernelSchedule, EpilogueSchedule, true>;
+      cutlass_3x_group_gemm<InType, OutType, ArchTag, Epilogue, TileShape,
+                            ClusterShape, KernelSchedule, EpilogueSchedule,
+                            true>;
 };

 template <typename InType, typename OutType,
@@ -55,10 +58,12 @@ struct sm90_fp8_config_M64 {
       cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
   using TileShape = cute::Shape<cute::_128, cute::_16, cute::_256>;
   using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
+  using ArchTag = cutlass::arch::Sm90;

   using Cutlass3xGemm =
-      cutlass_3x_group_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
-                            KernelSchedule, EpilogueSchedule, true>;
+      cutlass_3x_group_gemm<InType, OutType, ArchTag, Epilogue, TileShape,
+                            ClusterShape, KernelSchedule, EpilogueSchedule,
+                            true>;
 };

 template <typename InType, typename OutType,
@@ -72,10 +77,11 @@ struct sm90_fp8_config_K8192 {
       cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
   using TileShape = cute::Shape<cute::_128, cute::_128, cute::_128>;
   using ClusterShape = cute::Shape<cute::_1, cute::_8, cute::_1>;
+  using ArchTag = cutlass::arch::Sm90;

   using Cutlass3xGemm =
-      cutlass_3x_group_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
-                            KernelSchedule, EpilogueSchedule>;
+      cutlass_3x_group_gemm<InType, OutType, ArchTag, Epilogue, TileShape,
+                            ClusterShape, KernelSchedule, EpilogueSchedule>;
 };

 template <typename InType, typename OutType,
@@ -89,10 +95,11 @@ struct sm90_fp8_config_N8192 {
       cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
   using TileShape = cute::Shape<cute::_64, cute::_128, cute::_256>;
   using ClusterShape = cute::Shape<cute::_1, cute::_8, cute::_1>;
+  using ArchTag = cutlass::arch::Sm90;

   using Cutlass3xGemm =
-      cutlass_3x_group_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
-                            KernelSchedule, EpilogueSchedule>;
+      cutlass_3x_group_gemm<InType, OutType, ArchTag, Epilogue, TileShape,
+                            ClusterShape, KernelSchedule, EpilogueSchedule>;
 };

 template <typename InType, typename OutType>
@@ -112,9 +119,6 @@ void run_cutlass_moe_mm_sm90(
   TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn,
               "B tensors must be of type float8_e4m3fn.");

-  TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn);
-  TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn);
-
   using Cutlass3xGemmN8192 = typename sm90_fp8_config_N8192<
       InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
   using Cutlass3xGemmK8192 = typename sm90_fp8_config_K8192<
@@ -190,4 +190,4 @@ void get_cutlass_pplx_moe_mm_data_caller(torch::Tensor& expert_offsets,
       static_cast<int32_t*>(problem_sizes2.data_ptr()),
       static_cast<const int32_t*>(expert_num_tokens.data_ptr()), padded_m, n,
       k);
-}
+}
@@ -41,6 +41,16 @@ void cutlass_moe_mm_sm90(

 #endif

+#if defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100
+void cutlass_moe_mm_sm100(
+    torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
+    torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
+    torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
+    torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
+    torch::Tensor const& b_strides, torch::Tensor const& c_strides,
+    bool per_act_token, bool per_out_ch);
+#endif
+
 #if defined ENABLE_SCALED_MM_SM120 && ENABLE_SCALED_MM_SM120
 void cutlass_scaled_mm_sm120(torch::Tensor& c, torch::Tensor const& a,
                              torch::Tensor const& b,
@@ -130,10 +140,10 @@ bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability) {
   // and at least SM90 (Hopper)

 #if defined CUDA_VERSION
-  if (cuda_device_capability >= 90 && cuda_device_capability < 100) {
-    return CUDA_VERSION >= 12000;
-  } else if (cuda_device_capability >= 100) {
+  if (cuda_device_capability >= 100) {
     return CUDA_VERSION >= 12080;
+  } else if (cuda_device_capability >= 90) {
+    return CUDA_VERSION >= 12000;
   }
 #endif

@@ -141,11 +151,14 @@ bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability) {
 }

 bool cutlass_group_gemm_supported(int64_t cuda_device_capability) {
-  // CUTLASS grouped FP8 kernels need at least CUDA 12.3
-  // and SM90 (Hopper)
+  // CUTLASS grouped FP8 kernels need at least CUDA 12.3 and SM90 (Hopper)
+  // or CUDA 12.8 and SM100 (Blackwell)

 #if defined CUDA_VERSION
-  if (cuda_device_capability == 90) {
+  if (cuda_device_capability >= 100) {
+    return CUDA_VERSION >= 12080;
+  }
+  if (cuda_device_capability >= 90) {
     return CUDA_VERSION >= 12030;
   }
 #endif
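Restating the capability policy encoded above as a small, self-contained sketch with example values; group_gemm_supported is an illustrative name rather than the vLLM function, and the real check additionally requires the CUDA_VERSION macro to be defined at build time:

// Illustrative restatement of the rule: SM100 (Blackwell) needs CUDA >= 12.8,
// SM90 (Hopper) needs CUDA >= 12.3, anything older is unsupported.
constexpr bool group_gemm_supported(int capability, int cuda_version) {
  if (capability >= 100) return cuda_version >= 12080;  // Blackwell
  if (capability >= 90) return cuda_version >= 12030;   // Hopper
  return false;
}

static_assert(group_gemm_supported(100, 12080));   // SM100 + CUDA 12.8 -> ok
static_assert(!group_gemm_supported(100, 12040));  // SM100 + CUDA 12.4 -> no
static_assert(group_gemm_supported(90, 12030));    // SM90 + CUDA 12.3 -> ok
static_assert(!group_gemm_supported(89, 12080));   // Ada and older -> no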
@@ -234,16 +247,26 @@ void cutlass_moe_mm(
     torch::Tensor const& b_strides, torch::Tensor const& c_strides,
     bool per_act_token, bool per_out_ch) {
   int32_t version_num = get_sm_version_num();
+#if defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100
+  if (version_num >= 100) {
+    cutlass_moe_mm_sm100(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
+                         expert_offsets, problem_sizes, a_strides, b_strides,
+                         c_strides, per_act_token, per_out_ch);
+    return;
+  }
+#endif
 #if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90
-  cutlass_moe_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
-                      expert_offsets, problem_sizes, a_strides, b_strides,
-                      c_strides, per_act_token, per_out_ch);
-  return;
+  if (version_num >= 90) {
+    cutlass_moe_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
+                        expert_offsets, problem_sizes, a_strides, b_strides,
+                        c_strides, per_act_token, per_out_ch);
+    return;
+  }
 #endif
   TORCH_CHECK_NOT_IMPLEMENTED(
       false,
       "No compiled cutlass_scaled_mm for CUDA device capability: ", version_num,
-      ". Required capability: 90");
+      ". Required capability: 90 or 100");
 }

 void get_cutlass_moe_mm_data(
@@ -332,6 +332,12 @@ class CompressedTensorsConfig(QuantizationConfig):
         return (self._check_scheme_supported(90, error=False, match_exact=True)
                 and self._is_fp8_w8a8(weight_quant, input_quant))

+    def _is_fp8_w8a8_sm100(self, weight_quant: BaseModel,
+                           input_quant: BaseModel) -> bool:
+        return (self._check_scheme_supported(
+            100, error=False, match_exact=True)
+                and self._is_fp8_w8a8(weight_quant, input_quant))
+
     def _is_fp8_w8a16(self, weight_quant: BaseModel,
                       input_quant: BaseModel) -> bool:
         # Confirm weights quantized.
@@ -83,7 +83,8 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
             return CompressedTensorsWNA16MarlinMoEMethod(quant_config)
         elif quant_config._is_fp4a4_nvfp4(weight_quant, input_quant):
             return CompressedTensorsW4A4MoeMethod()
-        elif quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant):
+        elif (quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant)
+              or quant_config._is_fp8_w8a8_sm100(weight_quant, input_quant)):
             return CompressedTensorsW8A8Fp8MoECutlassMethod(quant_config)
         elif quant_config._is_fp8_w8a8(weight_quant, input_quant):
             return CompressedTensorsW8A8Fp8MoEMethod(quant_config)
@@ -740,6 +741,8 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
         self.topk_indices_dtype = None
         self.fused_experts = None  # type: ignore
         self.disable_expert_map = False
+        self.is_fp8_w8a8_sm100 = self.quant_config._is_fp8_w8a8_sm100(
+            self.weight_quant, self.input_quant)

     def create_weights(self, layer: torch.nn.Module, num_experts: int,
                        hidden_size: int, intermediate_size_per_partition: int,
@@ -931,7 +934,29 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):

         per_act_token = (
             self.input_quant.strategy == QuantizationStrategy.TOKEN)

+        per_channel_quant = (
+            self.weight_quant.strategy == QuantizationStrategy.CHANNEL)
+        # Triton fused_experts is faster in small batch sizes on SM100.
+        # Fall back to fused_experts in small batch sizes.
+        if self.is_fp8_w8a8_sm100 and topk_ids.shape[0] <= 8:
+            from vllm.model_executor.layers.fused_moe import fused_experts
+            return fused_experts(
+                x,
+                layer.w13_weight,
+                layer.w2_weight,
+                topk_weights,
+                topk_ids,
+                inplace=True,
+                activation=activation,
+                apply_router_weight_on_input=apply_router_weight_on_input,
+                use_fp8_w8a8=True,
+                per_channel_quant=per_channel_quant,
+                global_num_experts=global_num_experts,
+                expert_map=None if self.disable_expert_map else expert_map,
+                w1_scale=layer.w13_weight_scale,
+                w2_scale=layer.w2_weight_scale,
+                a1_scale=layer.w13_input_scale,
+                a2_scale=layer.w2_input_scale)
         if self.fused_experts is None:
             # If no modular kernel is provided, use cutlass_moe_fp8
             from vllm.model_executor.layers.fused_moe.cutlass_moe import (