[Kernel] CUTLASS MoE FP8: Integrate cuda moe permute/unpermute (#23045)

Signed-off-by: Shixian Cui <shixian@amazon.com>
shixianc 2025-08-20 07:35:26 -07:00 committed by GitHub
parent 4449235843
commit b17109beea
15 changed files with 369 additions and 121 deletions

View File

@@ -80,6 +80,11 @@ def bench_run(
a, score, topk, renormalize=False
)
+ ab_strides1 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64)
+ ab_strides2 = torch.full((num_experts,), n, device="cuda", dtype=torch.int64)
+ c_strides1 = torch.full((num_experts,), 2 * n, device="cuda", dtype=torch.int64)
+ c_strides2 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64)
def run_triton_moe(
a: torch.Tensor,
w1: torch.Tensor,
@@ -111,6 +116,10 @@ def bench_run(
w2: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
+ ab_strides1: torch.Tensor,
+ ab_strides2: torch.Tensor,
+ c_strides1: torch.Tensor,
+ c_strides2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
per_act_token: bool,
@@ -125,6 +134,10 @@ def bench_run(
topk_ids,
w1_scale,
w2_scale,
+ ab_strides1,
+ ab_strides2,
+ c_strides1,
+ c_strides2,
per_act_token,
a1_scale=None,
)
@@ -136,6 +149,10 @@ def bench_run(
w2_q: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
+ ab_strides1: torch.Tensor,
+ ab_strides2: torch.Tensor,
+ c_strides1: torch.Tensor,
+ c_strides2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
):
@@ -150,6 +167,10 @@ def bench_run(
topk_ids,
w1_scale,
w2_scale,
+ ab_strides1,
+ ab_strides2,
+ c_strides1,
+ c_strides2,
per_act_token,
a1_scale=None,
)
@@ -194,6 +215,10 @@ def bench_run(
w2_q,
w1_scale,
w2_scale,
+ ab_strides1,
+ ab_strides2,
+ c_strides1,
+ c_strides2,
topk_weights,
topk_ids,
)
@@ -231,6 +256,10 @@ def bench_run(
"w1_scale": w1_scale,
"w2_scale": w2_scale,
"per_act_token": per_act_token,
+ "ab_strides1": ab_strides1,
+ "ab_strides2": ab_strides2,
+ "c_strides1": c_strides1,
+ "c_strides2": c_strides2,
# cuda graph params
"cutlass_graph": cutlass_graph,
"triton_graph": triton_graph,
@@ -289,6 +318,10 @@ def bench_run(
w2_q,
w1_scale,
w2_scale,
+ ab_strides1,
+ ab_strides2,
+ c_strides1,
+ c_strides2,
topk_weights,
topk_ids,
per_act_token,
@@ -297,7 +330,7 @@ def bench_run(
results.append(
benchmark.Timer(
- stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, per_act_token, num_runs)", # noqa: E501
+ stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, ab_strides1, ab_strides2, c_strides1, c_strides2, topk_weights, topk_ids, per_act_token, num_runs)", # noqa: E501
globals=globals,
label=label,
sub_label=sub_label,

View File

@@ -45,8 +45,6 @@ void moe_permute(
auto copy_topk_ids = topk_ids.clone();  // copy topk_ids for preprocess
auto permuted_experts_id = torch::empty_like(topk_ids);
auto sorted_row_idx = torch::empty_like(inv_permuted_idx);
- auto align_expert_first_token_offset =
- torch::zeros_like(expert_first_token_offset);
CubKeyValueSorter sorter{};
int64_t* valid_num_ptr = nullptr;
@@ -85,12 +83,14 @@ void moe_permute(
});
// get m_indices and update expert_first_token_offset with align block
- getMIndices(get_ptr<int64_t>(expert_first_token_offset),
- get_ptr<int64_t>(align_expert_first_token_offset),
- get_ptr<int>(m_indices), n_local_expert, align_block_size_value,
- stream);
+ // this is only required for DeepGemm and not required for CUTLASS group gemm
if (align_block_size.has_value()) {
- // update align_expert_first_token_offset
+ auto align_expert_first_token_offset =
+ torch::zeros_like(expert_first_token_offset);
+ getMIndices(get_ptr<int64_t>(expert_first_token_offset),
+ get_ptr<int64_t>(align_expert_first_token_offset),
+ get_ptr<int>(m_indices), n_local_expert, align_block_size_value,
+ stream);
expert_first_token_offset.copy_(align_expert_first_token_offset);
}
}
@@ -195,19 +195,14 @@ void moe_permute(const torch::Tensor& input, const torch::Tensor& topk_weights,
torch::Tensor& expert_first_token_offset,
torch::Tensor& src_row_id2dst_row_id_map,
torch::Tensor& m_indices) {
- TORCH_CHECK(false, "moe_unpermute is not supported on CUDA < 12.0");
+ TORCH_CHECK(false, "moe_permute is not supported on CUDA < 12.0");
}
- void moe_unpermute(const torch::Tensor& input,
- const torch::Tensor& topk_weights, torch::Tensor& topk_ids,
- const torch::Tensor& token_expert_indices,
- const std::optional<torch::Tensor>& expert_map,
- int64_t n_expert, int64_t n_local_expert, int64_t topk,
- const std::optional<int64_t>& align_block_size,
- torch::Tensor& permuted_input,
- torch::Tensor& expert_first_token_offset,
- torch::Tensor& src_row_id2dst_row_id_map,
- torch::Tensor& m_indices) {
+ void moe_unpermute(
+ const torch::Tensor& permuted_hidden_states,
+ const torch::Tensor& topk_weights, const torch::Tensor& inv_permuted_idx,
+ const std::optional<torch::Tensor>& expert_first_token_offset, int64_t topk,
+ torch::Tensor& hidden_states) {
TORCH_CHECK(false, "moe_unpermute is not supported on CUDA < 12.0");
}

View File

@@ -229,6 +229,11 @@ void get_cutlass_moe_mm_data(
const int64_t num_experts, const int64_t n, const int64_t k,
const std::optional<torch::Tensor>& blockscale_offsets);
+ void get_cutlass_moe_mm_problem_sizes(
+ const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1,
+ torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n,
+ const int64_t k, const std::optional<torch::Tensor>& blockscale_offsets);
void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1,
torch::Tensor& problem_sizes2,

View File

@@ -10,7 +10,7 @@
template <typename ElementAB, typename ElementC, typename ElementAccumulator>
__global__ void get_group_gemm_starts(
- int32_t* expert_offsets, ElementAB** a_offsets, ElementAB** b_offsets,
+ int64_t* expert_offsets, ElementAB** a_offsets, ElementAB** b_offsets,
ElementC** out_offsets, ElementAccumulator** a_scales_offsets,
ElementAccumulator** b_scales_offsets, ElementAB* a_base_as_int,
ElementAB* b_base_as_int, ElementC* out_base_as_int,
@@ -34,7 +34,7 @@ __global__ void get_group_gemm_starts(
else if (out_tensors.dtype() == TENSOR_C_TYPE) { \
get_group_gemm_starts<cutlass::float_e4m3_t, C_TYPE, float> \
<<<1, num_experts, 0, stream>>>( \
- static_cast<int32_t*>(expert_offsets.data_ptr()), \
+ static_cast<int64_t*>(expert_offsets.data_ptr()), \
static_cast<cutlass::float_e4m3_t**>(a_ptrs.data_ptr()), \
static_cast<cutlass::float_e4m3_t**>(b_ptrs.data_ptr()), \
static_cast<C_TYPE**>(out_ptrs.data_ptr()), \
@@ -61,6 +61,8 @@ void run_get_group_gemm_starts(
TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn);
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
+ // expect int64_t to avoid overflow during offset calculations
+ TORCH_CHECK(expert_offsets.dtype() == torch::kInt64);
int num_experts = static_cast<int>(expert_offsets.size(0));
bool per_act_token = a_scales.numel() != 1;
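The int64 requirement is easiest to see with quick arithmetic: the per-expert offsets index rows of the flattened (num_tokens * topk, hidden) activation buffer, and once a row offset is scaled by the hidden dimension to form an element offset, the product can exceed the int32 range. A rough sketch with purely illustrative sizes (not taken from this PR):

# Illustrative only: why int32 expert offsets can overflow once they are
# scaled by the hidden dimension inside get_group_gemm_starts-style math.
num_tokens, topk, hidden = 65536, 8, 7168   # hypothetical large batch
rows = num_tokens * topk                    # 524288 permuted rows
last_element_offset = rows * hidden         # 3758096384 elements
assert last_element_offset > 2**31 - 1      # would wrap in 32-bit arithmetic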

View File

@@ -104,6 +104,53 @@ __global__ void compute_arg_sorts(const int32_t* __restrict__ topk_ids,
}
}
+ namespace {
+ inline void launch_compute_problem_sizes(const torch::Tensor& topk_ids,
+ torch::Tensor& problem_sizes1,
+ torch::Tensor& problem_sizes2,
+ torch::Tensor& atomic_buffer,
+ int64_t num_experts, int64_t n,
+ int64_t k, cudaStream_t stream,
+ const bool swap_ab) {
+ int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel());
+ const int32_t* topk_ptr = static_cast<const int32_t*>(topk_ids.data_ptr());
+ int32_t* ps1_ptr = static_cast<int32_t*>(problem_sizes1.data_ptr());
+ int32_t* ps2_ptr = static_cast<int32_t*>(problem_sizes2.data_ptr());
+ int32_t* atomic_ptr = static_cast<int32_t*>(atomic_buffer.data_ptr());
+ if (swap_ab) {
+ compute_problem_sizes<true><<<num_experts, num_threads, 0, stream>>>(
+ topk_ptr, ps1_ptr, ps2_ptr, atomic_ptr,
+ static_cast<int>(topk_ids.numel()), static_cast<int>(n),
+ static_cast<int>(k));
+ } else {
+ compute_problem_sizes<false><<<num_experts, num_threads, 0, stream>>>(
+ topk_ptr, ps1_ptr, ps2_ptr, atomic_ptr,
+ static_cast<int>(topk_ids.numel()), static_cast<int>(n),
+ static_cast<int>(k));
+ }
+ }
+ }  // namespace
+ void get_cutlass_moe_mm_problem_sizes_caller(
+ const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1,
+ torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n,
+ const int64_t k, const std::optional<torch::Tensor>& blockscale_offsets) {
+ auto stream = at::cuda::getCurrentCUDAStream(topk_ids.device().index());
+ auto options_int32 =
+ torch::TensorOptions().dtype(torch::kInt32).device(topk_ids.device());
+ torch::Tensor atomic_buffer = torch::zeros(num_experts, options_int32);
+ // Swap-AB should be disabled for FP4 path
+ bool may_swap_ab = (!blockscale_offsets.has_value()) &&
+ (topk_ids.numel() <= SWAP_AB_THRESHOLD);
+ launch_compute_problem_sizes(topk_ids, problem_sizes1, problem_sizes2,
+ atomic_buffer, num_experts, n, k, stream,
+ may_swap_ab);
+ }
void get_cutlass_moe_mm_data_caller(
const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
@@ -121,21 +168,9 @@ void get_cutlass_moe_mm_data_caller(
bool may_swap_ab = (!blockscale_offsets.has_value()) &&
(topk_ids.numel() <= SWAP_AB_THRESHOLD);
- if (may_swap_ab) {
- compute_problem_sizes<true><<<num_experts, num_threads, 0, stream>>>(
- static_cast<const int32_t*>(topk_ids.data_ptr()),
- static_cast<int32_t*>(problem_sizes1.data_ptr()),
- static_cast<int32_t*>(problem_sizes2.data_ptr()),
- static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(), n,
- k);
- } else {
- compute_problem_sizes<false><<<num_experts, num_threads, 0, stream>>>(
- static_cast<const int32_t*>(topk_ids.data_ptr()),
- static_cast<int32_t*>(problem_sizes1.data_ptr()),
- static_cast<int32_t*>(problem_sizes2.data_ptr()),
- static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(), n,
- k);
- }
+ launch_compute_problem_sizes(topk_ids, problem_sizes1, problem_sizes2,
+ atomic_buffer, num_experts, n, k, stream,
+ may_swap_ab);
if (blockscale_offsets.has_value()) {
// fp4 path

View File

@@ -76,6 +76,11 @@ void get_cutlass_moe_mm_data_caller(
const int64_t num_experts, const int64_t n, const int64_t k,
const std::optional<torch::Tensor>& blockscale_offsets);
+ void get_cutlass_moe_mm_problem_sizes_caller(
+ const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1,
+ torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n,
+ const int64_t k, const std::optional<torch::Tensor>& blockscale_offsets);
void get_cutlass_pplx_moe_mm_data_caller(torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1,
torch::Tensor& problem_sizes2,
@@ -293,6 +298,25 @@ void get_cutlass_moe_mm_data(
version_num, ". Required capability: 90 or 100");
}
+ void get_cutlass_moe_mm_problem_sizes(
+ const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1,
+ torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n,
+ const int64_t k, const std::optional<torch::Tensor>& blockscale_offsets) {
+ int32_t version_num = get_sm_version_num();
+ #if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
+ (defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100)
+ get_cutlass_moe_mm_problem_sizes_caller(topk_ids, problem_sizes1,
+ problem_sizes2, num_experts, n, k,
+ blockscale_offsets);
+ return;
+ #endif
+ TORCH_CHECK_NOT_IMPLEMENTED(
+ false,
+ "No compiled get_cutlass_moe_mm_problem_sizes: no cutlass_scaled_mm "
+ "kernel for CUDA device capability: ",
+ version_num, ". Required capability: 90 or 100");
+ }
void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1,
torch::Tensor& problem_sizes2,

View File

@@ -440,6 +440,19 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
{stride_tag});
ops.impl("get_cutlass_moe_mm_data", torch::kCUDA, &get_cutlass_moe_mm_data);
+ // A function that computes problem sizes for each expert's multiplication
+ // used by the two mms called from fused MoE operation. It takes topk_ids as
+ // an input, and computes problem_sizes1 and problem_sizes2 only.
+ ops.def(
+ "get_cutlass_moe_mm_problem_sizes(Tensor topk_ids, "
+ " Tensor! problem_sizes1, "
+ " Tensor! problem_sizes2, "
+ " int num_experts, int n, int k, "
+ " Tensor? blockscale_offsets) -> ()",
+ {stride_tag});
+ ops.impl("get_cutlass_moe_mm_problem_sizes", torch::kCUDA,
+ &get_cutlass_moe_mm_problem_sizes);
// A function that computes data required to run fused MoE with w8a8 grouped
// GEMM and PPLX. It takes expert_num_tokens and non_zero_expert_idxs
// as an input, and computes expert_offsets (token start indices of each

View File

@@ -207,6 +207,10 @@ def run_8_bit(moe_tensors: MOETensors8Bit,
'topk_ids': topk_ids,
'w1_scale': moe_tensors.w1_scale,
'w2_scale': moe_tensors.w2_scale,
+ 'ab_strides1': moe_tensors.ab_strides1,
+ 'ab_strides2': moe_tensors.ab_strides2,
+ 'c_strides1': moe_tensors.c_strides1,
+ 'c_strides2': moe_tensors.c_strides2,
'per_act_token': per_act_token,
'a1_scale': None #moe_tensors.a_scale
}
@@ -424,8 +428,8 @@ def test_run_cutlass_moe_fp8(
topk_ids[0][1] = 1
workspace13_shape = (m * topk, max(2 * n, k))
- workspace2_shape = (m * topk, n)
+ workspace2_shape = (m * topk, max(n, k))
- output_shape = (m * topk, k)
+ output_shape = (m, k)
workspace13 = torch.empty(prod(workspace13_shape),
device="cuda",
@@ -440,6 +444,11 @@ def test_run_cutlass_moe_fp8(
expert_map[start:end] = list(range(num_local_experts))
expert_map = torch.tensor(expert_map, dtype=torch.int32, device="cuda")
+ ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64)
+ ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64)
+ c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64)
+ c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64)
activation = lambda o, i: torch.ops._C.silu_and_mul(o, i)
a1q, a1q_scale = moe_kernel_quantize_input(mt.a, mt.a_scale,
torch.float8_e4m3fn,
@@ -448,8 +457,9 @@ def test_run_cutlass_moe_fp8(
func = lambda output: run_cutlass_moe_fp8(
output, a1q, mt.w1_q, mt.w2_q, topk_ids, activation,
global_num_experts, expert_map, mt.w1_scale, mt.w2_scale,
- a1q_scale, None, workspace13, workspace2, None, mt.a.dtype,
- per_act_token, per_out_channel, False)
+ a1q_scale, None, ab_strides1, ab_strides2, c_strides1, c_strides2,
+ workspace13, workspace2, None, mt.a.dtype, per_act_token,
+ per_out_channel, False, topk_weights)
workspace13.random_()
output_random_workspace = torch.empty(output_shape,

View File

@@ -238,7 +238,11 @@ def test_moe_permute_unpermute(n_token: int, n_hidden: int, topk: int,
atol=0,
rtol=0)
# check mindice
- torch.testing.assert_close(gold_m_indices, m_indices, atol=0, rtol=0)
+ # current kernel usage assumes deepgemm requires align_block_size
+ # when it's not provided then we don't compute m_indices (for cutlass)
+ if align_block_size is not None:
+ torch.testing.assert_close(gold_m_indices, m_indices, atol=0, rtol=0)
# check permuted_hidden_states, only valid token
torch.testing.assert_close(gold_permuted_hidden_states[valid_row_idx],
permuted_hidden_states[valid_row_idx],

View File

@@ -76,6 +76,7 @@ def pplx_cutlass_moe(
assert torch.cuda.current_device() == pgi.local_rank
num_tokens, hidden_dim = a.shape
+ intermediate_dim = w2.shape[2]
num_experts = w1.shape[0]
block_size = hidden_dim # TODO support more cases
device = pgi.device
@@ -124,8 +125,27 @@ def pplx_cutlass_moe(
num_local_experts=num_local_experts,
num_dispatchers=num_dispatchers)
+ ab_strides1 = torch.full((num_local_experts, ),
+ hidden_dim,
+ device="cuda",
+ dtype=torch.int64)
+ ab_strides2 = torch.full((num_local_experts, ),
+ intermediate_dim,
+ device="cuda",
+ dtype=torch.int64)
+ c_strides1 = torch.full((num_local_experts, ),
+ 2 * intermediate_dim,
+ device="cuda",
+ dtype=torch.int64)
+ c_strides2 = torch.full((num_local_experts, ),
+ hidden_dim,
+ device="cuda",
+ dtype=torch.int64)
experts = CutlassBatchedExpertsFp8(num_local_experts, num_dispatchers,
- out_dtype, per_act_token, per_out_ch)
+ out_dtype, per_act_token, per_out_ch,
+ ab_strides1, ab_strides2, c_strides1,
+ c_strides2)
fused_cutlass_experts = FusedMoEModularKernel(
prepare_finalize,

View File

@@ -535,7 +535,7 @@ def test_cutlass_fp8_group_gemm(num_experts: int, per_act_token: bool,
expert_offsets = torch.zeros((num_experts + 1),
device=device,
- dtype=torch.int32)
+ dtype=torch.int64)
problem_sizes = torch.zeros((num_experts, 3),
device=device,

View File

@@ -844,6 +844,28 @@ def get_cutlass_moe_mm_data(topk_ids: torch.Tensor,
blockscale_offsets)
+ def get_cutlass_moe_mm_problem_sizes(
+ topk_ids: torch.Tensor,
+ problem_sizes1: torch.Tensor,
+ problem_sizes2: torch.Tensor,
+ num_experts: int,
+ n: int,
+ k: int,
+ blockscale_offsets: Optional[torch.Tensor] = None):
+ """
+ Compute only the per-expert problem sizes needed by the two grouped matrix
+ multiplications used in CUTLASS-based fused MoE.
+ The function takes in topk_ids (token→expert mapping) and computes:
+ - problem_sizes1, problem_sizes2: M×N×K sizes of each expert's
+ multiplication for the two grouped MMs
+ used in the fused MoE operation.
+ """
+ return torch.ops._C.get_cutlass_moe_mm_problem_sizes(
+ topk_ids, problem_sizes1, problem_sizes2, num_experts, n, k,
+ blockscale_offsets)
def shuffle_rows(input_tensor: torch.Tensor, dst2src_map: torch.Tensor):
"""
Shuffle and expand the input tensor according to the dst2src_map and store the result in output_tensor.
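A minimal usage sketch of the new wrapper, mirroring how run_cutlass_moe_fp8 calls it later in this diff; the sizes below are illustrative assumptions, not values from the PR:

import torch
from vllm import _custom_ops as ops

num_experts, n, k, m, topk = 8, 2048, 7168, 16, 2   # illustrative sizes
topk_ids = torch.randint(0, num_experts, (m, topk),
                         dtype=torch.int32, device="cuda")
# one (M, N, K) triple per expert, for each of the two grouped GEMMs
problem_sizes1 = torch.empty((num_experts, 3), dtype=torch.int32, device="cuda")
problem_sizes2 = torch.empty((num_experts, 3), dtype=torch.int32, device="cuda")
ops.get_cutlass_moe_mm_problem_sizes(topk_ids, problem_sizes1, problem_sizes2,
                                     num_experts, n, k)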

View File

@@ -9,12 +9,13 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
+ from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
+ moe_permute, moe_unpermute)
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate, TopKWeightAndReduceNoOP)
- from vllm.model_executor.layers.fused_moe.utils import (_fp8_perm,
- _fp8_quantize,
+ from vllm.model_executor.layers.fused_moe.utils import (_fp8_quantize,
_resize_cache)
from vllm.scalar_type import scalar_types
@@ -34,6 +35,10 @@ def run_cutlass_moe_fp8(
w2_scale: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
+ ab_strides1: torch.Tensor,
+ ab_strides2: torch.Tensor,
+ c_strides1: torch.Tensor,
+ c_strides2: torch.Tensor,
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_num_tokens: Optional[torch.Tensor],
@@ -41,6 +46,7 @@ def run_cutlass_moe_fp8(
per_act_token: bool,
per_out_ch: bool,
use_batched_format: bool,
+ topk_weights: Optional[torch.Tensor],
):
a1q = hidden_states
@@ -99,6 +105,22 @@ def run_cutlass_moe_fp8(
topk = local_topk_ids.size(1)
local_E = w1.size(0)
+ if use_batched_format:
+ mm1_out = _resize_cache(workspace13, (local_E * padded_M, N * 2))
+ act_out = _resize_cache(workspace2, (local_E * padded_M, N))
+ quant_out = _resize_cache(workspace13.view(dtype=torch.float8_e4m3fn),
+ (local_E * padded_M, N))
+ mm2_out = _resize_cache(workspace2, (local_E * padded_M, K))
+ else:
+ a1q_perm = _resize_cache(workspace2.view(dtype=torch.float8_e4m3fn),
+ (M * topk, K))
+ mm1_out = _resize_cache(workspace13, (M * topk, N * 2))
+ act_out = _resize_cache(workspace2, (M * topk, N))
+ # original workspace are based on input hidden_states dtype (bf16)
+ quant_out = _resize_cache(workspace13.view(dtype=torch.float8_e4m3fn),
+ (M * topk, N))
+ mm2_out = _resize_cache(workspace2, (M * topk, K))
if use_batched_format:
assert expert_num_tokens is not None
@@ -120,11 +142,10 @@ def run_cutlass_moe_fp8(
w2_scale = w2_scale.reshape(w2_scale.size(0), -1)
a1q = a1q.reshape(-1, a1q.size(2))
a1q_scale = a1q_scale.reshape(-1, a1q_scale.size(2)).contiguous()
+ # c3x get_group_gemm_starts expects int64 to avoid overflow
+ # during offset calculations
+ expert_offsets = expert_offsets.to(torch.int64)
else:
- expert_offsets = torch.empty((global_num_experts + 1),
- dtype=torch.int32,
- device=device)
problem_sizes1 = torch.empty((global_num_experts, 3),
dtype=torch.int32,
device=device)
@@ -132,84 +153,57 @@ def run_cutlass_moe_fp8(
dtype=torch.int32,
device=device)
- # With expert_map each Rank processes only a subset of experts. As
- # a result not all of a_map and c2 tensors are filled. We fill it
- # zeros for correctness.
- if expert_map is not None:
- a_map = torch.zeros((local_topk_ids.numel()),
- dtype=torch.int32,
- device=device)
- else:
- a_map = torch.empty((local_topk_ids.numel()),
- dtype=torch.int32,
- device=device)
- c_map = torch.empty((local_topk_ids.numel()),
- dtype=torch.int32,
- device=device)
- ops.get_cutlass_moe_mm_data(local_topk_ids, expert_offsets,
- problem_sizes1, problem_sizes2, a_map,
- c_map, global_num_experts, N, K)
- a1q = _fp8_perm(a1q, a_map)
- a1q_scale = a1q_scale[a_map] if per_act_token else a1q_scale
+ num_expert = global_num_experts if expert_map is None \
+ else expert_map.size(0)
+ # permuted a1q reuses workspace2
+ a1q, a1q_scale, expert_offsets, inv_perm, _ = moe_permute(
+ a1q,
+ a1q_scale,
+ topk_ids,
+ num_expert,
+ local_E,
+ expert_map,
+ permuted_hidden_states=a1q_perm)
expert_offsets = expert_offsets[:-1]
- ab_strides1 = torch.full((w1.size(0), ),
- K,
- device=device,
- dtype=torch.int64)
- c_strides1 = torch.full((w1.size(0), ),
- 2 * N,
- device=device,
- dtype=torch.int64)
- ab_strides2 = torch.full((w1.size(0), ),
- N,
- device=device,
- dtype=torch.int64)
- c_strides2 = torch.full((w1.size(0), ),
- K,
- device=device,
- dtype=torch.int64)
- if use_batched_format:
- c1 = _resize_cache(workspace13, (local_E * padded_M, N * 2))
- c2 = _resize_cache(workspace2, (local_E * padded_M, N))
- c3 = _resize_cache(workspace13, (local_E * padded_M, K))
- else:
- c1 = _resize_cache(workspace13, (M * topk, N * 2))
- c2 = _resize_cache(workspace2, (M * topk, N))
- c3 = _resize_cache(workspace13, (M * topk, K))
+ ops.get_cutlass_moe_mm_problem_sizes(local_topk_ids, problem_sizes1,
+ problem_sizes2,
+ global_num_experts, N, K)
if not per_act_token and (expert_map is not None or use_batched_format):
# this is necessary to avoid imprecise scale calculation caused by
# random data in the unused workspace. The workspace is unused when
# this rank handles only partial tokens, or when it is batched .
- c1.fill_(0)
+ mm1_out.fill_(0)
- ops.cutlass_moe_mm(c1, a1q, w1, a1q_scale, w1_scale, expert_offsets,
+ ops.cutlass_moe_mm(mm1_out, a1q, w1, a1q_scale, w1_scale, expert_offsets,
problem_sizes1, ab_strides1, ab_strides1, c_strides1,
per_act_token, per_out_ch)
- activation_callable(c2, c1)
+ activation_callable(act_out, mm1_out)
a2q, a2q_scale = ops.scaled_fp8_quant(
- c2, a2_scale, use_per_token_if_dynamic=per_act_token)
+ act_out,
+ a2_scale,
+ use_per_token_if_dynamic=per_act_token,
+ output=quant_out)
if expert_map is not None:
- c3.fill_(0)
+ mm2_out.fill_(0)
- ops.cutlass_moe_mm(c3, a2q, w2, a2q_scale, w2_scale, expert_offsets,
+ ops.cutlass_moe_mm(mm2_out, a2q, w2, a2q_scale, w2_scale, expert_offsets,
problem_sizes2, ab_strides2, ab_strides2, c_strides2,
per_act_token, per_out_ch)
if use_batched_format:
- output.copy_(c3.reshape(local_E, padded_M, K), non_blocking=True)
+ output.copy_(mm2_out.reshape(local_E, padded_M, K), non_blocking=True)
else:
- # We can't do this inplace because output may point to the same tensor
- # as c3.
- output.copy_(c3[c_map].view(M * topk, K), non_blocking=True)
+ # for non-chunking mode the output is resized from workspace13
+ # so we need to make sure mm2_out uses workspace2.
+ moe_unpermute(out=output,
+ permuted_hidden_states=mm2_out,
+ topk_weights=topk_weights,
+ inv_permuted_idx=inv_perm)
class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
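To make the restructured non-batched path easier to follow, here is a self-contained, unquantized toy that mirrors the same permute -> grouped GEMM 1 -> SiLU-and-mul -> grouped GEMM 2 -> weighted unpermute structure with plain dense ops; it is purely illustrative and shares no code with the kernel path above:

import torch

def toy_cutlass_moe_structure(a, w1, w2, topk_ids, topk_weights):
    """Mirror the permute -> gemm1 -> act -> gemm2 -> unpermute flow with
    plain dense ops (no quantization, no CUTLASS)."""
    m, k = a.shape
    e = w1.size(0)                     # w1: (E, 2N, K), w2: (E, K, N)
    topk = topk_ids.size(1)
    # permute: sort the (token, k) pairs by expert id
    flat_experts = topk_ids.reshape(-1)
    perm = torch.argsort(flat_experts, stable=True)    # permuted row -> source row
    a_perm = a.repeat_interleave(topk, dim=0)[perm]    # (M * topk, K)
    out_perm = torch.empty(m * topk, k, dtype=a.dtype, device=a.device)
    # "grouped" GEMMs: one expert segment at a time
    for eid in range(e):
        rows = flat_experts[perm] == eid
        h = a_perm[rows] @ w1[eid].t()                 # (rows, 2N)
        gate, up = h.chunk(2, dim=-1)
        act = torch.nn.functional.silu(gate) * up      # (rows, N)
        out_perm[rows] = act @ w2[eid].t()             # (rows, K)
    # unpermute: back to source order, apply routing weights, reduce over topk
    inv = torch.empty_like(perm)
    inv[perm] = torch.arange(perm.numel(), device=perm.device)
    out = out_perm[inv].view(m, topk, k)
    return (out * topk_weights.unsqueeze(-1)).sum(dim=1)

a = torch.randn(4, 16)
w1, w2 = torch.randn(3, 8, 16), torch.randn(3, 16, 4)
ids = torch.randint(0, 3, (4, 2))
weights = torch.softmax(torch.randn(4, 2), dim=-1)
print(toy_cutlass_moe_structure(a, w1, w2, ids, weights).shape)  # torch.Size([4, 16])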
@@ -219,6 +213,10 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
out_dtype: Optional[torch.dtype],
per_act_token_quant: bool,
per_out_ch_quant: bool,
+ ab_strides1: torch.Tensor,
+ ab_strides2: torch.Tensor,
+ c_strides1: torch.Tensor,
+ c_strides2: torch.Tensor,
block_shape: Optional[list[int]] = None,
):
super().__init__(
@@ -229,6 +227,10 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
block_shape=block_shape,
))
self.out_dtype = out_dtype
+ self.ab_strides1 = ab_strides1
+ self.ab_strides2 = ab_strides2
+ self.c_strides1 = c_strides1
+ self.c_strides2 = c_strides2
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
# Let PrepareAndFinalize::finalize() decide the impl.
@@ -272,10 +274,11 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
run_cutlass_moe_fp8(
output, hidden_states, w1, w2, topk_ids, activation_callable,
global_num_experts, expert_map, w1_scale, w2_scale, a1q_scale,
- a2_scale, workspace13, workspace2, expert_num_tokens,
+ a2_scale, self.ab_strides1, self.ab_strides2, self.c_strides1,
+ self.c_strides2, workspace13, workspace2, expert_num_tokens,
self.out_dtype if self.out_dtype is not None else in_dtype,
self.per_act_token_quant, self.per_out_ch_quant,
- use_batched_format)
+ use_batched_format, topk_weights)
class CutlassExpertsFp8(CutlassExpertsFp8Base):
@@ -285,12 +288,20 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base):
out_dtype: Optional[torch.dtype],
per_act_token_quant: bool,
per_out_ch_quant: bool,
+ ab_strides1: torch.Tensor,
+ ab_strides2: torch.Tensor,
+ c_strides1: torch.Tensor,
+ c_strides2: torch.Tensor,
block_shape: Optional[list[int]] = None,
):
super().__init__(
out_dtype,
per_act_token_quant,
per_out_ch_quant,
+ ab_strides1,
+ ab_strides2,
+ c_strides1,
+ c_strides2,
block_shape,
)
@@ -307,6 +318,10 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base):
def supports_expert_map(self) -> bool:
return True
+ def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
+ # topk weights and reduction are fused in moe_unpermute cuda kernel
+ return TopKWeightAndReduceNoOP()
def workspace_shapes(
self,
a: torch.Tensor,
@@ -320,8 +335,8 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base):
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
workspace1 = (M * topk, max(N, K))
- workspace2 = (M * topk, N // 2)
+ workspace2 = (M * topk, max(N // 2, K))
- output = (M * topk, K)
+ output = (M, K)
return (workspace1, workspace2, output,
self.out_dtype if self.out_dtype is not None else a.dtype)
@@ -335,12 +350,20 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
out_dtype: Optional[torch.dtype],
per_act_token_quant: bool,
per_out_ch_quant: bool,
+ ab_strides1: torch.Tensor,
+ ab_strides2: torch.Tensor,
+ c_strides1: torch.Tensor,
+ c_strides2: torch.Tensor,
block_shape: Optional[list[int]] = None,
):
super().__init__(
out_dtype,
per_act_token_quant,
per_out_ch_quant,
+ ab_strides1,
+ ab_strides2,
+ c_strides1,
+ c_strides2,
block_shape,
)
assert max_experts_per_worker > 0
@@ -378,7 +401,8 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
assert num_dp is not None
workspace1 = (self.max_experts_per_worker, padded_M * num_dp,
max(N, K))
- workspace2 = (self.max_experts_per_worker, padded_M * num_dp, (N // 2))
+ workspace2 = (self.max_experts_per_worker, padded_M * num_dp,
+ max(N // 2, K))
output = (self.max_experts_per_worker, padded_M, K)
return (workspace1, workspace2, output,
self.out_dtype if self.out_dtype is not None else a.dtype)
@@ -392,6 +416,10 @@ def cutlass_moe_fp8(
topk_ids: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
+ ab_strides1: torch.Tensor,
+ ab_strides2: torch.Tensor,
+ c_strides1: torch.Tensor,
+ c_strides2: torch.Tensor,
per_act_token: Optional[bool] = None,
activation: str = "silu",
a1_scale: Optional[torch.Tensor] = None,
@@ -419,6 +447,17 @@ def cutlass_moe_fp8(
Shape: [num_experts] or [num_experts, 2N]
- w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
Shape: [num_experts] or [num_experts, K]
+ - ab_strides1 (torch.Tensor): The input/weight strides for the first gemm.
+ Shape: [num_experts]
+ - ab_strides2 (torch.Tensor): The input/weight strides for the second gemm.
+ Shape: [num_experts]
+ - c_strides1 (torch.Tensor): The output strides for the first gemm.
+ Shape: [num_experts]
+ - c_strides2 (torch.Tensor): The output strides for the second gemm.
+ Shape: [num_experts]
+ - per_act_token (Optional[bool]): Whether the scale is per-token or
+ per-tensor.
+ - activation (str): The activation function to use.
- a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a.
Shape: scalar or [M]
- a2_scale (Optional[torch.Tensor]): The optional fp32 scale to
@@ -450,6 +489,10 @@ def cutlass_moe_fp8(
out_dtype=a.dtype,
per_act_token_quant=per_act_token,
per_out_ch_quant=per_out_ch,
+ ab_strides1=ab_strides1,
+ ab_strides2=ab_strides2,
+ c_strides1=c_strides1,
+ c_strides2=c_strides2,
),
)
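One plausible reading of the new stride arguments, consistent with the values used throughout this diff and assuming row-major layouts: gemm1 multiplies per-expert activations A1 of shape (rows_e, K) by W1_e of shape (2N, K) into C1 of shape (rows_e, 2N), so A1 and W1_e share leading dimension K while C1 has leading dimension 2N; gemm2 multiplies A2 (rows_e, N) by W2_e (K, N) into C2 (rows_e, K), giving leading dimensions N and K. A small sketch of the resulting per-expert stride tensors (sizes are illustrative):

import torch

e, n, k = 8, 2048, 7168                                   # illustrative sizes
# gemm1: A1 (rows_e, K) x W1_e (2N, K) -> C1 (rows_e, 2N)
ab_strides1 = torch.full((e,), k, dtype=torch.int64)      # ld of A1 and W1_e
c_strides1 = torch.full((e,), 2 * n, dtype=torch.int64)   # ld of C1
# gemm2: A2 (rows_e, N) x W2_e (K, N) -> C2 (rows_e, K)
ab_strides2 = torch.full((e,), n, dtype=torch.int64)      # ld of A2 and W2_e
c_strides2 = torch.full((e,), k, dtype=torch.int64)       # ld of C2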

View File

@@ -82,7 +82,8 @@ def moe_permute(
n_local_expert: int = -1,
expert_map: Optional[torch.Tensor] = None,
align_block_size: Optional[int] = None,
- fill_invalid_expert: int = -1
+ fill_invalid_expert: int = -1,
+ permuted_hidden_states: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor,
torch.Tensor]:
"""
@@ -100,9 +101,12 @@ def moe_permute(
- align_block_size (Optional[int]): align group gemm block size for deepgemm
- fill_invalid_expert(int): fill expert id in m_indices for invalid expert
to workaround DeepGemm unsupported -1 in m_indices
+ - permuted_hidden_states (Optional[torch.Tensor]): Optional output tensor.
+ If None, the output tensor will be created in this function.
Returns:
- permuted_hidden_states (torch.Tensor): permuted activation.
- - a1q_scale (Optional[torch.Tensor]): quant scale for hidden_states
+ - a1q_scale (Optional[torch.Tensor]): permuted quant scale for hidden_states
+ if original scale not per-tensor scaling
- expert_first_token_offset (torch.Tensor): offset of the first token
of each expert for standard grouped gemm. if enable 'align_block_size'
expert_first_token_offset will align up to 'align_block_size'.
@@ -122,11 +126,16 @@ def moe_permute(
1) // align_block_size * align_block_size
if n_local_expert == -1:
n_local_expert = n_expert
- permuted_hidden_states = torch.empty(
- (permuted_row_size, n_hidden),
- dtype=hidden_states.dtype,
- device=hidden_states.device,
- )
+ if permuted_hidden_states is None:
+ permuted_hidden_states = torch.empty(
+ (permuted_row_size, n_hidden),
+ dtype=hidden_states.dtype,
+ device=hidden_states.device,
+ )
+ assert permuted_hidden_states.size() == (permuted_row_size, n_hidden), (
+ f"Expected permuted hidden states to be {(permuted_row_size, n_hidden)}"
+ f" but got {permuted_hidden_states.size()}")
token_expert_indices = torch.arange(0,
n_token * topk,
dtype=torch.int32,
@@ -153,7 +162,8 @@ def moe_permute(
align_block_size, permuted_hidden_states,
expert_first_token_offset, inv_permuted_idx,
permuted_idx, m_indices)
- if a1q_scale is not None:
+ if a1q_scale is not None and a1q_scale.dim() > 1:
a1q_scale = a1q_scale[permuted_idx.clamp(max=n_token * topk - 1) //
topk]
return (permuted_hidden_states, a1q_scale, expert_first_token_offset,
@@ -185,6 +195,7 @@ def moe_unpermute(
n_hidden = permuted_hidden_states.size(-1)
assert (n_hidden * permuted_hidden_states.element_size()
) % 16 == 0, "unpermue kernel need hidden dim align to 16B"
torch.ops._moe_C.moe_unpermute(permuted_hidden_states, topk_weights,
inv_permuted_idx, expert_first_token_offset,
topk, out)
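A short sketch of the new optional output-buffer path for moe_permute, with the argument order taken from the run_cutlass_moe_fp8 call site earlier in this diff; all sizes are illustrative, and the preallocated buffer stands in for a workspace view. Without align_block_size the permuted buffer is simply (n_token * topk, n_hidden), matching the assertion added above:

import torch
from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import moe_permute

n_token, n_hidden, topk, n_expert = 32, 4096, 2, 8        # illustrative sizes
hidden_states = torch.randn(n_token, n_hidden, device="cuda", dtype=torch.bfloat16)
topk_ids = torch.randint(0, n_expert, (n_token, topk),
                         device="cuda", dtype=torch.int32)
# caller-provided output buffer, e.g. a _resize_cache view of a workspace
prealloc = torch.empty(n_token * topk, n_hidden,
                       device="cuda", dtype=hidden_states.dtype)
permuted, _, expert_first_token_offset, inv_permuted_idx, _ = moe_permute(
    hidden_states, None, topk_ids, n_expert,
    permuted_hidden_states=prealloc)
assert permuted.data_ptr() == prealloc.data_ptr()   # result lives in the buffer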

View File

@@ -669,6 +669,25 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
from vllm.model_executor.layers.fused_moe import fused_experts
self.fused_experts_func = fused_experts
+ if self.use_cutlass:
+ device = layer.w13_weight.device
+ # ab_strides1 and c_strides2 are the same
+ self.ab_strides1_c_strides2 = torch.full(
+ (layer.local_num_experts, ),
+ layer.hidden_size,
+ device=device,
+ dtype=torch.int64)
+ self.ab_strides2 = torch.full(
+ (layer.local_num_experts, ),
+ layer.intermediate_size_per_partition,
+ device=device,
+ dtype=torch.int64)
+ self.c_strides1 = torch.full(
+ (layer.local_num_experts, ),
+ 2 * layer.intermediate_size_per_partition,
+ device=device,
+ dtype=torch.int64)
def select_gemm_impl(
self,
prepare_finalize: FusedMoEPrepareAndFinalize,
@@ -693,6 +712,10 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
moe.in_dtype,
self.input_quant.strategy == QuantizationStrategy.TOKEN,
self.weight_quant.strategy == QuantizationStrategy.CHANNEL,
+ ab_strides1=self.ab_strides1_c_strides2,
+ ab_strides2=self.ab_strides2,
+ c_strides1=self.c_strides1,
+ c_strides2=self.ab_strides1_c_strides2,
)
else:
logger.debug("CutlassExpertsFp8(%s)", self.__class__.__name__)
@@ -700,6 +723,10 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
moe.in_dtype,
self.input_quant.strategy == QuantizationStrategy.TOKEN,
self.weight_quant.strategy == QuantizationStrategy.CHANNEL,
+ ab_strides1=self.ab_strides1_c_strides2,
+ ab_strides2=self.ab_strides2,
+ c_strides1=self.c_strides1,
+ c_strides2=self.ab_strides1_c_strides2,
)
self.disable_expert_map = (num_dispatchers > 1
@@ -822,6 +849,10 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
expert_map=None if self.disable_expert_map else expert_map,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
+ ab_strides1=self.ab_strides1_c_strides2,
+ ab_strides2=self.ab_strides2,
+ c_strides1=self.c_strides1,
+ c_strides2=self.ab_strides1_c_strides2,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
)