[Kernel] CUTLASS MoE FP8: Integrate cuda moe permute/unpermute (#23045)
Signed-off-by: Shixian Cui <shixian@amazon.com>
This commit is contained in:
parent 4449235843
commit b17109beea
@@ -80,6 +80,11 @@ def bench_run(
         a, score, topk, renormalize=False
     )
 
+    ab_strides1 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64)
+    ab_strides2 = torch.full((num_experts,), n, device="cuda", dtype=torch.int64)
+    c_strides1 = torch.full((num_experts,), 2 * n, device="cuda", dtype=torch.int64)
+    c_strides2 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64)
+
     def run_triton_moe(
         a: torch.Tensor,
         w1: torch.Tensor,
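A note on the four tensors introduced above: for row-major expert weights every expert shares the same leading dimension, so the per-expert stride vectors that the CUTLASS grouped GEMM expects are constant. A minimal sketch of the relationships, with made-up sizes and CPU tensors so it runs anywhere (my illustration, not code from the commit):

import torch

num_experts, k, n = 4, 7168, 2048  # hypothetical hidden / intermediate sizes
# gemm1: (tokens_e, k) @ (k, 2n) -> (tokens_e, 2n) per expert e
ab_strides1 = torch.full((num_experts,), k, dtype=torch.int64)
c_strides1 = torch.full((num_experts,), 2 * n, dtype=torch.int64)
# gemm2: (tokens_e, n) @ (n, k) -> (tokens_e, k) per expert e
ab_strides2 = torch.full((num_experts,), n, dtype=torch.int64)
c_strides2 = torch.full((num_experts,), k, dtype=torch.int64)

assert c_strides1.equal(2 * ab_strides2)  # gemm1 writes gate+up, width 2n
assert c_strides2.equal(ab_strides1)      # gemm2 writes back hidden width k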
@@ -111,6 +116,10 @@ def bench_run(
         w2: torch.Tensor,
         w1_scale: torch.Tensor,
         w2_scale: torch.Tensor,
+        ab_strides1: torch.Tensor,
+        ab_strides2: torch.Tensor,
+        c_strides1: torch.Tensor,
+        c_strides2: torch.Tensor,
         topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
         per_act_token: bool,
@@ -125,6 +134,10 @@ def bench_run(
             topk_ids,
             w1_scale,
             w2_scale,
+            ab_strides1,
+            ab_strides2,
+            c_strides1,
+            c_strides2,
             per_act_token,
             a1_scale=None,
         )
@@ -136,6 +149,10 @@ def bench_run(
         w2_q: torch.Tensor,
         w1_scale: torch.Tensor,
         w2_scale: torch.Tensor,
+        ab_strides1: torch.Tensor,
+        ab_strides2: torch.Tensor,
+        c_strides1: torch.Tensor,
+        c_strides2: torch.Tensor,
         topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
     ):
@@ -150,6 +167,10 @@ def bench_run(
             topk_ids,
             w1_scale,
             w2_scale,
+            ab_strides1,
+            ab_strides2,
+            c_strides1,
+            c_strides2,
             per_act_token,
             a1_scale=None,
         )
@@ -194,6 +215,10 @@ def bench_run(
             w2_q,
             w1_scale,
             w2_scale,
+            ab_strides1,
+            ab_strides2,
+            c_strides1,
+            c_strides2,
             topk_weights,
             topk_ids,
         )
@@ -231,6 +256,10 @@ def bench_run(
         "w1_scale": w1_scale,
         "w2_scale": w2_scale,
         "per_act_token": per_act_token,
+        "ab_strides1": ab_strides1,
+        "ab_strides2": ab_strides2,
+        "c_strides1": c_strides1,
+        "c_strides2": c_strides2,
         # cuda graph params
         "cutlass_graph": cutlass_graph,
         "triton_graph": triton_graph,
@@ -289,6 +318,10 @@ def bench_run(
             w2_q,
             w1_scale,
             w2_scale,
+            ab_strides1,
+            ab_strides2,
+            c_strides1,
+            c_strides2,
             topk_weights,
             topk_ids,
             per_act_token,
@@ -297,7 +330,7 @@ def bench_run(
 
         results.append(
             benchmark.Timer(
-                stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, per_act_token, num_runs)",  # noqa: E501
+                stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, ab_strides1, ab_strides2, c_strides1, c_strides2, topk_weights, topk_ids, per_act_token, num_runs)",  # noqa: E501
                 globals=globals,
                 label=label,
                 sub_label=sub_label,
@@ -45,8 +45,6 @@ void moe_permute(
   auto copy_topk_ids = topk_ids.clone();  // copy topk_ids for preprocess
   auto permuted_experts_id = torch::empty_like(topk_ids);
   auto sorted_row_idx = torch::empty_like(inv_permuted_idx);
-  auto align_expert_first_token_offset =
-      torch::zeros_like(expert_first_token_offset);
 
   CubKeyValueSorter sorter{};
   int64_t* valid_num_ptr = nullptr;
@@ -85,12 +83,14 @@ void moe_permute(
       });
 
   // get m_indices and update expert_first_token_offset with align block
-  getMIndices(get_ptr<int64_t>(expert_first_token_offset),
-              get_ptr<int64_t>(align_expert_first_token_offset),
-              get_ptr<int>(m_indices), n_local_expert, align_block_size_value,
-              stream);
+  // this is only required for DeepGemm and not required for CUTLASS group gemm
   if (align_block_size.has_value()) {
-    // update align_expert_first_token_offset
+    auto align_expert_first_token_offset =
+        torch::zeros_like(expert_first_token_offset);
+    getMIndices(get_ptr<int64_t>(expert_first_token_offset),
+                get_ptr<int64_t>(align_expert_first_token_offset),
+                get_ptr<int>(m_indices), n_local_expert, align_block_size_value,
+                stream);
     expert_first_token_offset.copy_(align_expert_first_token_offset);
   }
 }
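For readers unfamiliar with m_indices (which the hunk above now computes only when align_block_size is set, since only the DeepGemm path consumes it): each permuted row carries the id of the expert that owns it, derivable from expert_first_token_offset. A reference sketch of that mapping in PyTorch, my own illustration rather than commit code:

import torch

def m_indices_from_offsets(expert_first_token_offset: torch.Tensor) -> torch.Tensor:
    # rows [off[e], off[e+1]) belong to expert e
    counts = expert_first_token_offset[1:] - expert_first_token_offset[:-1]
    return torch.repeat_interleave(
        torch.arange(counts.numel(), dtype=torch.int32), counts)

off = torch.tensor([0, 3, 3, 7])    # 3 experts owning 3, 0, 4 rows
print(m_indices_from_offsets(off))  # tensor([0, 0, 0, 2, 2, 2, 2], dtype=torch.int32)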
@@ -195,19 +195,14 @@ void moe_permute(const torch::Tensor& input, const torch::Tensor& topk_weights,
                  torch::Tensor& expert_first_token_offset,
                  torch::Tensor& src_row_id2dst_row_id_map,
                  torch::Tensor& m_indices) {
-  TORCH_CHECK(false, "moe_unpermute is not supported on CUDA < 12.0");
+  TORCH_CHECK(false, "moe_permute is not supported on CUDA < 12.0");
 }
 
-void moe_unpermute(const torch::Tensor& input,
-                   const torch::Tensor& topk_weights, torch::Tensor& topk_ids,
-                   const torch::Tensor& token_expert_indices,
-                   const std::optional<torch::Tensor>& expert_map,
-                   int64_t n_expert, int64_t n_local_expert, int64_t topk,
-                   const std::optional<int64_t>& align_block_size,
-                   torch::Tensor& permuted_input,
-                   torch::Tensor& expert_first_token_offset,
-                   torch::Tensor& src_row_id2dst_row_id_map,
-                   torch::Tensor& m_indices) {
+void moe_unpermute(
+    const torch::Tensor& permuted_hidden_states,
+    const torch::Tensor& topk_weights, const torch::Tensor& inv_permuted_idx,
+    const std::optional<torch::Tensor>& expert_first_token_offset, int64_t topk,
+    torch::Tensor& hidden_states) {
   TORCH_CHECK(false, "moe_unpermute is not supported on CUDA < 12.0");
 }
 
@@ -224,4 +219,4 @@ bool moe_permute_unpermute_supported() {
 TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
   m.impl("moe_permute", &moe_permute);
   m.impl("moe_unpermute", &moe_unpermute);
 }
@@ -229,6 +229,11 @@ void get_cutlass_moe_mm_data(
     const int64_t num_experts, const int64_t n, const int64_t k,
     const std::optional<torch::Tensor>& blockscale_offsets);
 
+void get_cutlass_moe_mm_problem_sizes(
+    const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1,
+    torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n,
+    const int64_t k, const std::optional<torch::Tensor>& blockscale_offsets);
+
 void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets,
                                   torch::Tensor& problem_sizes1,
                                   torch::Tensor& problem_sizes2,
@@ -10,7 +10,7 @@
 
 template <typename ElementAB, typename ElementC, typename ElementAccumulator>
 __global__ void get_group_gemm_starts(
-    int32_t* expert_offsets, ElementAB** a_offsets, ElementAB** b_offsets,
+    int64_t* expert_offsets, ElementAB** a_offsets, ElementAB** b_offsets,
     ElementC** out_offsets, ElementAccumulator** a_scales_offsets,
     ElementAccumulator** b_scales_offsets, ElementAB* a_base_as_int,
     ElementAB* b_base_as_int, ElementC* out_base_as_int,
@@ -34,7 +34,7 @@ __global__ void get_group_gemm_starts(
   else if (out_tensors.dtype() == TENSOR_C_TYPE) {                        \
     get_group_gemm_starts<cutlass::float_e4m3_t, C_TYPE, float>           \
         <<<1, num_experts, 0, stream>>>(                                  \
-            static_cast<int32_t*>(expert_offsets.data_ptr()),             \
+            static_cast<int64_t*>(expert_offsets.data_ptr()),             \
             static_cast<cutlass::float_e4m3_t**>(a_ptrs.data_ptr()),      \
             static_cast<cutlass::float_e4m3_t**>(b_ptrs.data_ptr()),      \
             static_cast<C_TYPE**>(out_ptrs.data_ptr()),                   \
@@ -61,6 +61,8 @@ void run_get_group_gemm_starts(
   TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn);
   TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
   TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
+  // expect int64_t to avoid overflow during offset calculations
+  TORCH_CHECK(expert_offsets.dtype() == torch::kInt64);
 
   int num_experts = static_cast<int>(expert_offsets.size(0));
   bool per_act_token = a_scales.numel() != 1;
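The new check encodes the overflow argument from the comment: the kernel multiplies a per-expert token offset by the row width to build base pointers, and that product can exceed the int32 range well before the offset itself does. A back-of-the-envelope check with illustrative numbers only:

k = 7168                       # hypothetical hidden size
offset = 400_000               # permuted rows ahead of some expert
print(offset * k > 2**31 - 1)  # True: the int32 product would wrap around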
@@ -104,6 +104,53 @@ __global__ void compute_arg_sorts(const int32_t* __restrict__ topk_ids,
   }
 }
 
+namespace {
+inline void launch_compute_problem_sizes(const torch::Tensor& topk_ids,
+                                         torch::Tensor& problem_sizes1,
+                                         torch::Tensor& problem_sizes2,
+                                         torch::Tensor& atomic_buffer,
+                                         int64_t num_experts, int64_t n,
+                                         int64_t k, cudaStream_t stream,
+                                         const bool swap_ab) {
+  int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel());
+
+  const int32_t* topk_ptr = static_cast<const int32_t*>(topk_ids.data_ptr());
+  int32_t* ps1_ptr = static_cast<int32_t*>(problem_sizes1.data_ptr());
+  int32_t* ps2_ptr = static_cast<int32_t*>(problem_sizes2.data_ptr());
+  int32_t* atomic_ptr = static_cast<int32_t*>(atomic_buffer.data_ptr());
+
+  if (swap_ab) {
+    compute_problem_sizes<true><<<num_experts, num_threads, 0, stream>>>(
+        topk_ptr, ps1_ptr, ps2_ptr, atomic_ptr,
+        static_cast<int>(topk_ids.numel()), static_cast<int>(n),
+        static_cast<int>(k));
+  } else {
+    compute_problem_sizes<false><<<num_experts, num_threads, 0, stream>>>(
+        topk_ptr, ps1_ptr, ps2_ptr, atomic_ptr,
+        static_cast<int>(topk_ids.numel()), static_cast<int>(n),
+        static_cast<int>(k));
+  }
+}
+}  // namespace
+
+void get_cutlass_moe_mm_problem_sizes_caller(
+    const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1,
+    torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n,
+    const int64_t k, const std::optional<torch::Tensor>& blockscale_offsets) {
+  auto stream = at::cuda::getCurrentCUDAStream(topk_ids.device().index());
+  auto options_int32 =
+      torch::TensorOptions().dtype(torch::kInt32).device(topk_ids.device());
+  torch::Tensor atomic_buffer = torch::zeros(num_experts, options_int32);
+
+  // Swap-AB should be disabled for FP4 path
+  bool may_swap_ab = (!blockscale_offsets.has_value()) &&
+                     (topk_ids.numel() <= SWAP_AB_THRESHOLD);
+
+  launch_compute_problem_sizes(topk_ids, problem_sizes1, problem_sizes2,
+                               atomic_buffer, num_experts, n, k, stream,
+                               may_swap_ab);
+}
+
 void get_cutlass_moe_mm_data_caller(
     const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
     torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
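What launch_compute_problem_sizes ultimately fills in can be written in a few lines of PyTorch. The reference below is my own sketch: it ignores the on-device atomic counting and the swap_ab transposition, and the (M, N, K) layout follows the non-swapped branch as I read it:

import torch

def problem_sizes_ref(topk_ids: torch.Tensor, num_experts: int, n: int, k: int):
    counts = torch.bincount(topk_ids.flatten().long(), minlength=num_experts)
    ones = torch.ones_like(counts)
    ps1 = torch.stack([counts, 2 * n * ones, k * ones], dim=1).int()  # gemm1
    ps2 = torch.stack([counts, k * ones, n * ones], dim=1).int()      # gemm2
    return ps1, ps2

topk_ids = torch.tensor([[0, 2], [1, 2], [2, 3]])
print(problem_sizes_ref(topk_ids, num_experts=4, n=8, k=16)[0])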
@@ -121,21 +168,9 @@ void get_cutlass_moe_mm_data_caller(
   bool may_swap_ab = (!blockscale_offsets.has_value()) &&
                      (topk_ids.numel() <= SWAP_AB_THRESHOLD);
 
-  if (may_swap_ab) {
-    compute_problem_sizes<true><<<num_experts, num_threads, 0, stream>>>(
-        static_cast<const int32_t*>(topk_ids.data_ptr()),
-        static_cast<int32_t*>(problem_sizes1.data_ptr()),
-        static_cast<int32_t*>(problem_sizes2.data_ptr()),
-        static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(), n,
-        k);
-  } else {
-    compute_problem_sizes<false><<<num_experts, num_threads, 0, stream>>>(
-        static_cast<const int32_t*>(topk_ids.data_ptr()),
-        static_cast<int32_t*>(problem_sizes1.data_ptr()),
-        static_cast<int32_t*>(problem_sizes2.data_ptr()),
-        static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(), n,
-        k);
-  }
+  launch_compute_problem_sizes(topk_ids, problem_sizes1, problem_sizes2,
+                               atomic_buffer, num_experts, n, k, stream,
+                               may_swap_ab);
 
   if (blockscale_offsets.has_value()) {
     // fp4 path
@@ -76,6 +76,11 @@ void get_cutlass_moe_mm_data_caller(
     const int64_t num_experts, const int64_t n, const int64_t k,
     const std::optional<torch::Tensor>& blockscale_offsets);
 
+void get_cutlass_moe_mm_problem_sizes_caller(
+    const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1,
+    torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n,
+    const int64_t k, const std::optional<torch::Tensor>& blockscale_offsets);
+
 void get_cutlass_pplx_moe_mm_data_caller(torch::Tensor& expert_offsets,
                                          torch::Tensor& problem_sizes1,
                                          torch::Tensor& problem_sizes2,
@@ -293,6 +298,25 @@ void get_cutlass_moe_mm_data(
       version_num, ". Required capability: 90 or 100");
 }
 
+void get_cutlass_moe_mm_problem_sizes(
+    const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1,
+    torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n,
+    const int64_t k, const std::optional<torch::Tensor>& blockscale_offsets) {
+  int32_t version_num = get_sm_version_num();
+#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
+    (defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100)
+  get_cutlass_moe_mm_problem_sizes_caller(topk_ids, problem_sizes1,
+                                          problem_sizes2, num_experts, n, k,
+                                          blockscale_offsets);
+  return;
+#endif
+  TORCH_CHECK_NOT_IMPLEMENTED(
+      false,
+      "No compiled get_cutlass_moe_mm_problem_sizes: no cutlass_scaled_mm "
+      "kernel for CUDA device capability: ",
+      version_num, ". Required capability: 90 or 100");
+}
+
 void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets,
                                   torch::Tensor& problem_sizes1,
                                   torch::Tensor& problem_sizes2,
@@ -440,6 +440,19 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       {stride_tag});
   ops.impl("get_cutlass_moe_mm_data", torch::kCUDA, &get_cutlass_moe_mm_data);
 
+  // A function that computes problem sizes for each expert's multiplication
+  // used by the two mms called from fused MoE operation. It takes topk_ids as
+  // an input, and computes problem_sizes1 and problem_sizes2 only.
+  ops.def(
+      "get_cutlass_moe_mm_problem_sizes(Tensor topk_ids, "
+      "                                 Tensor! problem_sizes1, "
+      "                                 Tensor! problem_sizes2, "
+      "                                 int num_experts, int n, int k, "
+      "                                 Tensor? blockscale_offsets) -> ()",
+      {stride_tag});
+  ops.impl("get_cutlass_moe_mm_problem_sizes", torch::kCUDA,
+           &get_cutlass_moe_mm_problem_sizes);
+
   // A function that computes data required to run fused MoE with w8a8 grouped
   // GEMM and PPLX. It takes expert_num_tokens and non_zero_expert_idxs
   // as an input, and computes expert_offsets (token start indices of each
@@ -207,6 +207,10 @@ def run_8_bit(moe_tensors: MOETensors8Bit,
         'topk_ids': topk_ids,
         'w1_scale': moe_tensors.w1_scale,
         'w2_scale': moe_tensors.w2_scale,
+        'ab_strides1': moe_tensors.ab_strides1,
+        'ab_strides2': moe_tensors.ab_strides2,
+        'c_strides1': moe_tensors.c_strides1,
+        'c_strides2': moe_tensors.c_strides2,
         'per_act_token': per_act_token,
         'a1_scale': None  #moe_tensors.a_scale
     }
@@ -424,8 +428,8 @@ def test_run_cutlass_moe_fp8(
     topk_ids[0][1] = 1
 
     workspace13_shape = (m * topk, max(2 * n, k))
-    workspace2_shape = (m * topk, n)
-    output_shape = (m * topk, k)
+    workspace2_shape = (m * topk, max(n, k))
+    output_shape = (m, k)
 
     workspace13 = torch.empty(prod(workspace13_shape),
                               device="cuda",
@@ -440,6 +444,11 @@ def test_run_cutlass_moe_fp8(
         expert_map[start:end] = list(range(num_local_experts))
     expert_map = torch.tensor(expert_map, dtype=torch.int32, device="cuda")
 
+    ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64)
+    ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64)
+    c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64)
+    c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64)
+
     activation = lambda o, i: torch.ops._C.silu_and_mul(o, i)
     a1q, a1q_scale = moe_kernel_quantize_input(mt.a, mt.a_scale,
                                                torch.float8_e4m3fn,
@@ -448,8 +457,9 @@ def test_run_cutlass_moe_fp8(
     func = lambda output: run_cutlass_moe_fp8(
         output, a1q, mt.w1_q, mt.w2_q, topk_ids, activation,
         global_num_experts, expert_map, mt.w1_scale, mt.w2_scale,
-        a1q_scale, None, workspace13, workspace2, None, mt.a.dtype,
-        per_act_token, per_out_channel, False)
+        a1q_scale, None, ab_strides1, ab_strides2, c_strides1, c_strides2,
+        workspace13, workspace2, None, mt.a.dtype, per_act_token,
+        per_out_channel, False, topk_weights)
 
     workspace13.random_()
     output_random_workspace = torch.empty(output_shape,
@@ -238,7 +238,11 @@ def test_moe_permute_unpermute(n_token: int, n_hidden: int, topk: int,
                                atol=0,
                                rtol=0)
     # check mindice
-    torch.testing.assert_close(gold_m_indices, m_indices, atol=0, rtol=0)
+    # current kernel usage assumes deepgemm requires align_block_size
+    # when it's not provided then we don't compute m_indices (for cutlass)
+    if align_block_size is not None:
+        torch.testing.assert_close(gold_m_indices, m_indices, atol=0, rtol=0)
+
     # check permuted_hidden_states, only valid token
     torch.testing.assert_close(gold_permuted_hidden_states[valid_row_idx],
                                permuted_hidden_states[valid_row_idx],
@@ -76,6 +76,7 @@ def pplx_cutlass_moe(
     assert torch.cuda.current_device() == pgi.local_rank
 
     num_tokens, hidden_dim = a.shape
+    intermediate_dim = w2.shape[2]
     num_experts = w1.shape[0]
     block_size = hidden_dim  # TODO support more cases
     device = pgi.device
@@ -124,8 +125,27 @@ def pplx_cutlass_moe(
         num_local_experts=num_local_experts,
         num_dispatchers=num_dispatchers)
 
+    ab_strides1 = torch.full((num_local_experts, ),
+                             hidden_dim,
+                             device="cuda",
+                             dtype=torch.int64)
+    ab_strides2 = torch.full((num_local_experts, ),
+                             intermediate_dim,
+                             device="cuda",
+                             dtype=torch.int64)
+    c_strides1 = torch.full((num_local_experts, ),
+                            2 * intermediate_dim,
+                            device="cuda",
+                            dtype=torch.int64)
+    c_strides2 = torch.full((num_local_experts, ),
+                            hidden_dim,
+                            device="cuda",
+                            dtype=torch.int64)
+
     experts = CutlassBatchedExpertsFp8(num_local_experts, num_dispatchers,
-                                       out_dtype, per_act_token, per_out_ch)
+                                       out_dtype, per_act_token, per_out_ch,
+                                       ab_strides1, ab_strides2, c_strides1,
+                                       c_strides2)
 
     fused_cutlass_experts = FusedMoEModularKernel(
         prepare_finalize,
@@ -535,7 +535,7 @@ def test_cutlass_fp8_group_gemm(num_experts: int, per_act_token: bool,
 
     expert_offsets = torch.zeros((num_experts + 1),
                                  device=device,
-                                 dtype=torch.int32)
+                                 dtype=torch.int64)
 
     problem_sizes = torch.zeros((num_experts, 3),
                                 device=device,
@@ -844,6 +844,28 @@ def get_cutlass_moe_mm_data(topk_ids: torch.Tensor,
                                                 blockscale_offsets)
 
 
+def get_cutlass_moe_mm_problem_sizes(
+        topk_ids: torch.Tensor,
+        problem_sizes1: torch.Tensor,
+        problem_sizes2: torch.Tensor,
+        num_experts: int,
+        n: int,
+        k: int,
+        blockscale_offsets: Optional[torch.Tensor] = None):
+    """
+    Compute only the per-expert problem sizes needed by the two grouped matrix
+    multiplications used in CUTLASS-based fused MoE.
+
+    The function takes in topk_ids (token→expert mapping) and computes:
+    - problem_sizes1, problem_sizes2: M×N×K sizes of each expert's
+      multiplication for the two grouped MMs
+      used in the fused MoE operation.
+    """
+    return torch.ops._C.get_cutlass_moe_mm_problem_sizes(
+        topk_ids, problem_sizes1, problem_sizes2, num_experts, n, k,
+        blockscale_offsets)
+
+
 def shuffle_rows(input_tensor: torch.Tensor, dst2src_map: torch.Tensor):
     """
     Shuffle and expand the input tensor according to the dst2src_map and store the result in output_tensor.
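A hypothetical call pattern for the new wrapper (requires a CUDA build of vLLM with the op compiled; shapes and dtypes are inferred from the surrounding diff, so treat this as a sketch):

import torch
from vllm import _custom_ops as ops

num_experts, n, k = 8, 2048, 7168
topk_ids = torch.randint(0, num_experts, (256, 4),
                         dtype=torch.int32, device="cuda")
problem_sizes1 = torch.empty((num_experts, 3), dtype=torch.int32, device="cuda")
problem_sizes2 = torch.empty((num_experts, 3), dtype=torch.int32, device="cuda")
ops.get_cutlass_moe_mm_problem_sizes(topk_ids, problem_sizes1, problem_sizes2,
                                     num_experts, n, k)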
@@ -9,12 +9,13 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
+from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
+    moe_permute, moe_unpermute)
 from vllm.model_executor.layers.fused_moe.prepare_finalize import (
     MoEPrepareAndFinalizeNoEP)
 from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
     TopKWeightAndReduceDelegate, TopKWeightAndReduceNoOP)
-from vllm.model_executor.layers.fused_moe.utils import (_fp8_perm,
-                                                        _fp8_quantize,
+from vllm.model_executor.layers.fused_moe.utils import (_fp8_quantize,
                                                         _resize_cache)
 from vllm.scalar_type import scalar_types
 
@@ -34,6 +35,10 @@ def run_cutlass_moe_fp8(
     w2_scale: Optional[torch.Tensor],
     a1q_scale: Optional[torch.Tensor],
     a2_scale: Optional[torch.Tensor],
+    ab_strides1: torch.Tensor,
+    ab_strides2: torch.Tensor,
+    c_strides1: torch.Tensor,
+    c_strides2: torch.Tensor,
     workspace13: torch.Tensor,
     workspace2: torch.Tensor,
     expert_num_tokens: Optional[torch.Tensor],
@@ -41,6 +46,7 @@ def run_cutlass_moe_fp8(
     per_act_token: bool,
     per_out_ch: bool,
     use_batched_format: bool,
+    topk_weights: Optional[torch.Tensor],
 ):
     a1q = hidden_states
 
@@ -99,6 +105,22 @@ def run_cutlass_moe_fp8(
     topk = local_topk_ids.size(1)
     local_E = w1.size(0)
 
+    if use_batched_format:
+        mm1_out = _resize_cache(workspace13, (local_E * padded_M, N * 2))
+        act_out = _resize_cache(workspace2, (local_E * padded_M, N))
+        quant_out = _resize_cache(workspace13.view(dtype=torch.float8_e4m3fn),
+                                  (local_E * padded_M, N))
+        mm2_out = _resize_cache(workspace2, (local_E * padded_M, K))
+    else:
+        a1q_perm = _resize_cache(workspace2.view(dtype=torch.float8_e4m3fn),
+                                 (M * topk, K))
+        mm1_out = _resize_cache(workspace13, (M * topk, N * 2))
+        act_out = _resize_cache(workspace2, (M * topk, N))
+        # original workspace are based on input hidden_states dtype (bf16)
+        quant_out = _resize_cache(workspace13.view(dtype=torch.float8_e4m3fn),
+                                  (M * topk, N))
+        mm2_out = _resize_cache(workspace2, (M * topk, K))
+
     if use_batched_format:
         assert expert_num_tokens is not None
 
@@ -120,11 +142,10 @@ def run_cutlass_moe_fp8(
         w2_scale = w2_scale.reshape(w2_scale.size(0), -1)
         a1q = a1q.reshape(-1, a1q.size(2))
         a1q_scale = a1q_scale.reshape(-1, a1q_scale.size(2)).contiguous()
+        # c3x get_group_gemm_starts expects int64 to avoid overflow
+        # during offset calculations
+        expert_offsets = expert_offsets.to(torch.int64)
     else:
-        expert_offsets = torch.empty((global_num_experts + 1),
-                                     dtype=torch.int32,
-                                     device=device)
         problem_sizes1 = torch.empty((global_num_experts, 3),
                                      dtype=torch.int32,
                                      device=device)
         problem_sizes2 = torch.empty((global_num_experts, 3),
@@ -132,84 +153,57 @@ def run_cutlass_moe_fp8(
                                      dtype=torch.int32,
                                      device=device)
 
-        # With expert_map each Rank processes only a subset of experts. As
-        # a result not all of a_map and c2 tensors are filled. We fill it
-        # zeros for correctness.
-        if expert_map is not None:
-            a_map = torch.zeros((local_topk_ids.numel()),
-                                dtype=torch.int32,
-                                device=device)
-        else:
-            a_map = torch.empty((local_topk_ids.numel()),
-                                dtype=torch.int32,
-                                device=device)
-
-        c_map = torch.empty((local_topk_ids.numel()),
-                            dtype=torch.int32,
-                            device=device)
-
-        ops.get_cutlass_moe_mm_data(local_topk_ids, expert_offsets,
-                                    problem_sizes1, problem_sizes2, a_map,
-                                    c_map, global_num_experts, N, K)
-
-        a1q = _fp8_perm(a1q, a_map)
-        a1q_scale = a1q_scale[a_map] if per_act_token else a1q_scale
+        num_expert = global_num_experts if expert_map is None \
+            else expert_map.size(0)
+        # permuted a1q reuses workspace2
+        a1q, a1q_scale, expert_offsets, inv_perm, _ = moe_permute(
+            a1q,
+            a1q_scale,
+            topk_ids,
+            num_expert,
+            local_E,
+            expert_map,
+            permuted_hidden_states=a1q_perm)
         expert_offsets = expert_offsets[:-1]
 
-        ab_strides1 = torch.full((w1.size(0), ),
-                                 K,
-                                 device=device,
-                                 dtype=torch.int64)
-        c_strides1 = torch.full((w1.size(0), ),
-                                2 * N,
-                                device=device,
-                                dtype=torch.int64)
-        ab_strides2 = torch.full((w1.size(0), ),
-                                 N,
-                                 device=device,
-                                 dtype=torch.int64)
-        c_strides2 = torch.full((w1.size(0), ),
-                                K,
-                                device=device,
-                                dtype=torch.int64)
-
-    if use_batched_format:
-        c1 = _resize_cache(workspace13, (local_E * padded_M, N * 2))
-        c2 = _resize_cache(workspace2, (local_E * padded_M, N))
-        c3 = _resize_cache(workspace13, (local_E * padded_M, K))
-    else:
-        c1 = _resize_cache(workspace13, (M * topk, N * 2))
-        c2 = _resize_cache(workspace2, (M * topk, N))
-        c3 = _resize_cache(workspace13, (M * topk, K))
+        ops.get_cutlass_moe_mm_problem_sizes(local_topk_ids, problem_sizes1,
+                                             problem_sizes2,
+                                             global_num_experts, N, K)
 
     if not per_act_token and (expert_map is not None or use_batched_format):
         # this is necessary to avoid imprecise scale calculation caused by
         # random data in the unused workspace. The workspace is unused when
        # this rank handles only partial tokens, or when it is batched .
-        c1.fill_(0)
+        mm1_out.fill_(0)
 
-    ops.cutlass_moe_mm(c1, a1q, w1, a1q_scale, w1_scale, expert_offsets,
+    ops.cutlass_moe_mm(mm1_out, a1q, w1, a1q_scale, w1_scale, expert_offsets,
                        problem_sizes1, ab_strides1, ab_strides1, c_strides1,
                        per_act_token, per_out_ch)
 
-    activation_callable(c2, c1)
+    activation_callable(act_out, mm1_out)
 
     a2q, a2q_scale = ops.scaled_fp8_quant(
-        c2, a2_scale, use_per_token_if_dynamic=per_act_token)
+        act_out,
+        a2_scale,
+        use_per_token_if_dynamic=per_act_token,
+        output=quant_out)
 
     if expert_map is not None:
-        c3.fill_(0)
+        mm2_out.fill_(0)
 
-    ops.cutlass_moe_mm(c3, a2q, w2, a2q_scale, w2_scale, expert_offsets,
+    ops.cutlass_moe_mm(mm2_out, a2q, w2, a2q_scale, w2_scale, expert_offsets,
                        problem_sizes2, ab_strides2, ab_strides2, c_strides2,
                        per_act_token, per_out_ch)
 
     if use_batched_format:
-        output.copy_(c3.reshape(local_E, padded_M, K), non_blocking=True)
+        output.copy_(mm2_out.reshape(local_E, padded_M, K), non_blocking=True)
     else:
-        # We can't do this inplace because output may point to the same tensor
-        # as c3.
-        output.copy_(c3[c_map].view(M * topk, K), non_blocking=True)
+        # for non-chunking mode the output is resized from workspace13
+        # so we need to make sure mm2_out uses workspace2.
+        moe_unpermute(out=output,
+                      permuted_hidden_states=mm2_out,
+                      topk_weights=topk_weights,
+                      inv_permuted_idx=inv_perm)
 
 
 class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
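The moe_unpermute call that replaces the old c3[c_map] gather also folds in the topk-weighted reduction, which is why the output shrank from (M * topk, K) to (M, K). A tiny PyTorch reference of that contract, assuming inv_permuted_idx maps each (token, k) slot to its row in the permuted buffer; this is my sketch of the semantics, not the kernel:

import torch

def unpermute_ref(permuted: torch.Tensor, topk_weights: torch.Tensor,
                  inv_permuted_idx: torch.Tensor, topk: int) -> torch.Tensor:
    m = topk_weights.size(0)
    gathered = permuted[inv_permuted_idx].view(m, topk, -1)    # regroup per token
    return (gathered * topk_weights.unsqueeze(-1)).sum(dim=1)  # fused reduction

m, topk, hidden = 3, 2, 4
out = unpermute_ref(torch.randn(m * topk, hidden), torch.rand(m, topk),
                    torch.randperm(m * topk), topk)
assert out.shape == (m, hidden)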
@@ -219,6 +213,10 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
         out_dtype: Optional[torch.dtype],
         per_act_token_quant: bool,
         per_out_ch_quant: bool,
+        ab_strides1: torch.Tensor,
+        ab_strides2: torch.Tensor,
+        c_strides1: torch.Tensor,
+        c_strides2: torch.Tensor,
         block_shape: Optional[list[int]] = None,
     ):
         super().__init__(
@@ -229,6 +227,10 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
                 block_shape=block_shape,
             ))
         self.out_dtype = out_dtype
+        self.ab_strides1 = ab_strides1
+        self.ab_strides2 = ab_strides2
+        self.c_strides1 = c_strides1
+        self.c_strides2 = c_strides2
 
     def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
         # Let PrepareAndFinalize::finalize() decide the impl.
@@ -272,10 +274,11 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
         run_cutlass_moe_fp8(
             output, hidden_states, w1, w2, topk_ids, activation_callable,
             global_num_experts, expert_map, w1_scale, w2_scale, a1q_scale,
-            a2_scale, workspace13, workspace2, expert_num_tokens,
+            a2_scale, self.ab_strides1, self.ab_strides2, self.c_strides1,
+            self.c_strides2, workspace13, workspace2, expert_num_tokens,
             self.out_dtype if self.out_dtype is not None else in_dtype,
             self.per_act_token_quant, self.per_out_ch_quant,
-            use_batched_format)
+            use_batched_format, topk_weights)
 
 
 class CutlassExpertsFp8(CutlassExpertsFp8Base):
@@ -285,12 +288,20 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base):
         out_dtype: Optional[torch.dtype],
         per_act_token_quant: bool,
         per_out_ch_quant: bool,
+        ab_strides1: torch.Tensor,
+        ab_strides2: torch.Tensor,
+        c_strides1: torch.Tensor,
+        c_strides2: torch.Tensor,
         block_shape: Optional[list[int]] = None,
     ):
         super().__init__(
             out_dtype,
             per_act_token_quant,
             per_out_ch_quant,
+            ab_strides1,
+            ab_strides2,
+            c_strides1,
+            c_strides2,
             block_shape,
         )
 
@@ -307,6 +318,10 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base):
     def supports_expert_map(self) -> bool:
         return True
 
+    def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
+        # topk weights and reduction are fused in moe_unpermute cuda kernel
+        return TopKWeightAndReduceNoOP()
+
     def workspace_shapes(
         self,
         a: torch.Tensor,
@@ -320,8 +335,8 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base):
         expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
     ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
         workspace1 = (M * topk, max(N, K))
-        workspace2 = (M * topk, N // 2)
-        output = (M * topk, K)
+        workspace2 = (M * topk, max(N // 2, K))
+        output = (M, K)
         return (workspace1, workspace2, output,
                 self.out_dtype if self.out_dtype is not None else a.dtype)
 
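The workspace2 growth follows from the buffer carving earlier in the diff: workspace2 now has to back both the activation output (width N // 2 in this method's convention) and the second matmul output (width K), so its width is the max of the two. A quick size sanity check with toy numbers (my illustration):

M, topk, N, K = 16, 4, 4096, 7168
ws2 = M * topk * max(N // 2, K)
assert ws2 >= M * topk * (N // 2)  # fits the activation output
assert ws2 >= M * topk * K         # and the second matmul output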
@@ -335,12 +350,20 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
         out_dtype: Optional[torch.dtype],
         per_act_token_quant: bool,
         per_out_ch_quant: bool,
+        ab_strides1: torch.Tensor,
+        ab_strides2: torch.Tensor,
+        c_strides1: torch.Tensor,
+        c_strides2: torch.Tensor,
         block_shape: Optional[list[int]] = None,
     ):
         super().__init__(
             out_dtype,
             per_act_token_quant,
             per_out_ch_quant,
+            ab_strides1,
+            ab_strides2,
+            c_strides1,
+            c_strides2,
             block_shape,
         )
         assert max_experts_per_worker > 0
@@ -378,7 +401,8 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
         assert num_dp is not None
         workspace1 = (self.max_experts_per_worker, padded_M * num_dp,
                       max(N, K))
-        workspace2 = (self.max_experts_per_worker, padded_M * num_dp, (N // 2))
+        workspace2 = (self.max_experts_per_worker, padded_M * num_dp,
+                      max(N // 2, K))
         output = (self.max_experts_per_worker, padded_M, K)
         return (workspace1, workspace2, output,
                 self.out_dtype if self.out_dtype is not None else a.dtype)
@@ -392,6 +416,10 @@ def cutlass_moe_fp8(
     topk_ids: torch.Tensor,
     w1_scale: torch.Tensor,
     w2_scale: torch.Tensor,
+    ab_strides1: torch.Tensor,
+    ab_strides2: torch.Tensor,
+    c_strides1: torch.Tensor,
+    c_strides2: torch.Tensor,
     per_act_token: Optional[bool] = None,
     activation: str = "silu",
     a1_scale: Optional[torch.Tensor] = None,
@@ -419,6 +447,17 @@ def cutlass_moe_fp8(
         Shape: [num_experts] or [num_experts, 2N]
     - w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
         Shape: [num_experts] or [num_experts, K]
+    - ab_strides1 (torch.Tensor): The input/weight strides for the first gemm.
+        Shape: [num_experts]
+    - ab_strides2 (torch.Tensor): The input/weight strides for the second gemm.
+        Shape: [num_experts]
+    - c_strides1 (torch.Tensor): The output strides for the first gemm.
+        Shape: [num_experts]
+    - c_strides2 (torch.Tensor): The output strides for the second gemm.
+        Shape: [num_experts]
+    - per_act_token (Optional[bool]): Whether the scale is per-token or
+        per-tensor.
+    - activation (str): The activation function to use.
     - a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a.
         Shape: scalar or [M]
     - a2_scale (Optional[torch.Tensor]): The optional fp32 scale to
@@ -450,6 +489,10 @@ def cutlass_moe_fp8(
             out_dtype=a.dtype,
             per_act_token_quant=per_act_token,
             per_out_ch_quant=per_out_ch,
+            ab_strides1=ab_strides1,
+            ab_strides2=ab_strides2,
+            c_strides1=c_strides1,
+            c_strides2=c_strides2,
         ),
     )
 
@@ -82,7 +82,8 @@ def moe_permute(
     n_local_expert: int = -1,
     expert_map: Optional[torch.Tensor] = None,
     align_block_size: Optional[int] = None,
-    fill_invalid_expert: int = -1
+    fill_invalid_expert: int = -1,
+    permuted_hidden_states: Optional[torch.Tensor] = None,
 ) -> tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor,
            torch.Tensor]:
     """
@@ -95,14 +96,17 @@ def moe_permute(
     - n_expert (int): The number of expert.
     - n_local_expert (int): The number of expert in current EP rank.
     - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
       from the global expert space to the local expert space of the expert
       parallel shard.
     - align_block_size (Optional[int]): align group gemm block size for deepgemm
     - fill_invalid_expert(int): fill expert id in m_indices for invalid expert
       to workaround DeepGemm unsupported -1 in m_indices
+    - permuted_hidden_states (Optional[torch.Tensor]): Optional output tensor.
+      If None, the output tensor will be created in this function.
     Returns:
     - permuted_hidden_states (torch.Tensor): permuted activation.
-    - a1q_scale (Optional[torch.Tensor]): quant scale for hidden_states
+    - a1q_scale (Optional[torch.Tensor]): permuted quant scale for hidden_states
+      if original scale not per-tensor scaling
     - expert_first_token_offset (torch.Tensor): offset of the first token
       of each expert for standard grouped gemm. if enable 'align_block_size'
       expert_first_token_offset will align up to 'align_block_size'.
@@ -122,11 +126,16 @@ def moe_permute(
                              1) // align_block_size * align_block_size
     if n_local_expert == -1:
         n_local_expert = n_expert
-    permuted_hidden_states = torch.empty(
-        (permuted_row_size, n_hidden),
-        dtype=hidden_states.dtype,
-        device=hidden_states.device,
-    )
+    if permuted_hidden_states is None:
+        permuted_hidden_states = torch.empty(
+            (permuted_row_size, n_hidden),
+            dtype=hidden_states.dtype,
+            device=hidden_states.device,
+        )
+    assert permuted_hidden_states.size() == (permuted_row_size, n_hidden), (
+        f"Expected permuted hidden states to be {(permuted_row_size, n_hidden)}"
+        f" but got {permuted_hidden_states.size()}")
+
     token_expert_indices = torch.arange(0,
                                         n_token * topk,
                                         dtype=torch.int32,
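The optional argument follows the usual out-buffer pattern: let the caller hand in scratch memory, validate its shape, and only allocate when nothing was provided. A self-contained miniature of the pattern (names here are illustrative, not vLLM APIs):

from typing import Optional

import torch

def permute_rows(x: torch.Tensor, idx: torch.Tensor,
                 out: Optional[torch.Tensor] = None) -> torch.Tensor:
    if out is None:
        out = torch.empty(idx.numel(), x.size(1), dtype=x.dtype)
    assert out.shape == (idx.numel(), x.size(1))
    torch.index_select(x, 0, idx, out=out)
    return out

x = torch.randn(4, 8)
scratch = torch.empty(4, 8)
y = permute_rows(x, torch.tensor([2, 0, 3, 1]), out=scratch)
assert y.data_ptr() == scratch.data_ptr()  # reused the caller's buffer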
@@ -153,7 +162,8 @@ def moe_permute(
                                align_block_size, permuted_hidden_states,
                                expert_first_token_offset, inv_permuted_idx,
                                permuted_idx, m_indices)
-    if a1q_scale is not None:
+
+    if a1q_scale is not None and a1q_scale.dim() > 1:
         a1q_scale = a1q_scale[permuted_idx.clamp(max=n_token * topk - 1) //
                               topk]
     return (permuted_hidden_states, a1q_scale, expert_first_token_offset,
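What the dim() > 1 guard protects: a per-tensor scale is a single element and must not be row-indexed, while a per-token scale has one row per token and needs to be gathered so it follows the permuted activations. A tiny runnable illustration with toy numbers (my own, not commit code):

import torch

n_token, topk = 3, 2
a1q_scale = torch.rand(n_token, 1)               # per-token scale, one row per token
permuted_idx = torch.tensor([4, 1, 0, 3, 2, 5])  # permuted (token, k) slots
# slot // topk recovers the source token for each permuted row
perm_scale = a1q_scale[permuted_idx.clamp(max=n_token * topk - 1) // topk]
print(perm_scale.shape)  # torch.Size([6, 1])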
@@ -185,6 +195,7 @@ def moe_unpermute(
     n_hidden = permuted_hidden_states.size(-1)
     assert (n_hidden * permuted_hidden_states.element_size()
             ) % 16 == 0, "unpermue kernel need hidden dim align to 16B"
+
     torch.ops._moe_C.moe_unpermute(permuted_hidden_states, topk_weights,
                                    inv_permuted_idx, expert_first_token_offset,
                                    topk, out)
@@ -669,6 +669,25 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
             from vllm.model_executor.layers.fused_moe import fused_experts
             self.fused_experts_func = fused_experts
 
+        if self.use_cutlass:
+            device = layer.w13_weight.device
+            # ab_strides1 and c_strides2 are the same
+            self.ab_strides1_c_strides2 = torch.full(
+                (layer.local_num_experts, ),
+                layer.hidden_size,
+                device=device,
+                dtype=torch.int64)
+            self.ab_strides2 = torch.full(
+                (layer.local_num_experts, ),
+                layer.intermediate_size_per_partition,
+                device=device,
+                dtype=torch.int64)
+            self.c_strides1 = torch.full(
+                (layer.local_num_experts, ),
+                2 * layer.intermediate_size_per_partition,
+                device=device,
+                dtype=torch.int64)
+
     def select_gemm_impl(
         self,
         prepare_finalize: FusedMoEPrepareAndFinalize,
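The aliasing in ab_strides1_c_strides2 works because gemm1 reads activation rows of width hidden_size and gemm2 writes output rows of that same width, so both stride vectors hold the same constant. Toy illustration with hypothetical sizes (my sketch, not commit code):

import torch

hidden_size, intermediate, local_e = 4096, 14336, 8
ab1_c2 = torch.full((local_e,), hidden_size, dtype=torch.int64)   # shared tensor
ab2 = torch.full((local_e,), intermediate, dtype=torch.int64)
c1 = torch.full((local_e,), 2 * intermediate, dtype=torch.int64)  # gate+up width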
@@ -693,6 +712,10 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
                 moe.in_dtype,
                 self.input_quant.strategy == QuantizationStrategy.TOKEN,
                 self.weight_quant.strategy == QuantizationStrategy.CHANNEL,
+                ab_strides1=self.ab_strides1_c_strides2,
+                ab_strides2=self.ab_strides2,
+                c_strides1=self.c_strides1,
+                c_strides2=self.ab_strides1_c_strides2,
             )
         else:
             logger.debug("CutlassExpertsFp8(%s)", self.__class__.__name__)
@@ -700,6 +723,10 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
                 moe.in_dtype,
                 self.input_quant.strategy == QuantizationStrategy.TOKEN,
                 self.weight_quant.strategy == QuantizationStrategy.CHANNEL,
+                ab_strides1=self.ab_strides1_c_strides2,
+                ab_strides2=self.ab_strides2,
+                c_strides1=self.c_strides1,
+                c_strides2=self.ab_strides1_c_strides2,
             )
 
         self.disable_expert_map = (num_dispatchers > 1
@@ -822,6 +849,10 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
             expert_map=None if self.disable_expert_map else expert_map,
             w1_scale=layer.w13_weight_scale,
             w2_scale=layer.w2_weight_scale,
+            ab_strides1=self.ab_strides1_c_strides2,
+            ab_strides2=self.ab_strides2,
+            c_strides1=self.c_strides1,
+            c_strides2=self.ab_strides1_c_strides2,
             a1_scale=layer.w13_input_scale,
             a2_scale=layer.w2_input_scale,
         )