[Hardware][NVIDIA] FP4 MoE kernel optimization (#19110)

Signed-off-by: Chiyue Wei <chiyuew@nvidia.com>
Co-authored-by: Chiyue Wei <chiyuew@nvidia.com>
Chiyue Wei 2025-06-05 09:48:26 -07:00 committed by GitHub
parent ec89524f50
commit 61059bee40
12 changed files with 165 additions and 38 deletions


@ -91,7 +91,7 @@ def bench_run(
score = torch.randn((m, num_experts), device=device, dtype=dtype)
topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False)
topk_weights, topk_ids, _ = fused_topk(a, score, topk, renormalize=False)
quant_blocksize = 16
w1_blockscale = torch.empty(


@ -30,4 +30,8 @@ torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
int64_t BLOCK_SIZE_K, int64_t bit);
#endif
bool moe_permute_unpermute_supported();
void shuffle_rows(const torch::Tensor& input_tensor,
const torch::Tensor& dst2src_map,
torch::Tensor& output_tensor);


@ -130,6 +130,62 @@ void moe_unpermute(
});
}
template <typename T>
__global__ void shuffleInputRowsKernel(const T* input,
const int32_t* dst2src_map, T* output,
int64_t num_src_rows,
int64_t num_dst_rows, int64_t num_cols) {
int64_t dest_row_idx = blockIdx.x;
int64_t const source_row_idx = dst2src_map[dest_row_idx];
if (blockIdx.x < num_dst_rows) {
// Load 128-bits per thread
constexpr int64_t ELEM_PER_THREAD = 128 / sizeof(T) / 8;
using DataElem = cutlass::Array<T, ELEM_PER_THREAD>;
// Duplicate and permute rows
auto const* source_row_ptr =
reinterpret_cast<DataElem const*>(input + source_row_idx * num_cols);
auto* dest_row_ptr =
reinterpret_cast<DataElem*>(output + dest_row_idx * num_cols);
int64_t const start_offset = threadIdx.x;
int64_t const stride = blockDim.x;
int64_t const num_elems_in_col = num_cols / ELEM_PER_THREAD;
for (int elem_index = start_offset; elem_index < num_elems_in_col;
elem_index += stride) {
dest_row_ptr[elem_index] = source_row_ptr[elem_index];
}
}
}
void shuffle_rows(const torch::Tensor& input_tensor,
const torch::Tensor& dst2src_map,
torch::Tensor& output_tensor) {
TORCH_CHECK(input_tensor.scalar_type() == output_tensor.scalar_type(),
"Input and output tensors must have the same data type");
auto stream = at::cuda::getCurrentCUDAStream().stream();
int64_t const blocks = output_tensor.size(0);
int64_t const threads = 256;
int64_t const num_dest_rows = output_tensor.size(0);
int64_t const num_src_rows = input_tensor.size(0);
int64_t const num_cols = input_tensor.size(1);
TORCH_CHECK(!(num_cols % (128 / sizeof(input_tensor.scalar_type()) / 8)),
"num_cols must be divisible by 128 / "
"sizeof(input_tensor.scalar_type()) / 8");
MOE_DISPATCH(input_tensor.scalar_type(), [&] {
shuffleInputRowsKernel<scalar_t><<<blocks, threads, 0, stream>>>(
reinterpret_cast<scalar_t*>(input_tensor.data_ptr()),
dst2src_map.data_ptr<int32_t>(),
reinterpret_cast<scalar_t*>(output_tensor.data_ptr()), num_src_rows,
num_dest_rows, num_cols);
});
}
#else
void moe_permute(const torch::Tensor& input, const torch::Tensor& topk_weights,

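Semantically, shuffleInputRowsKernel gathers rows of input into output through dst2src_map: destination row i is a copy of source row dst2src_map[i], so a source row may be duplicated when a token is routed to several experts, and num_cols must be divisible by the 128-bit vector width enforced by the TORCH_CHECK above. A minimal CPU reference of that behaviour (an illustrative sketch, not the vLLM implementation):

    import torch

    def shuffle_rows_reference(input_tensor: torch.Tensor,
                               dst2src_map: torch.Tensor) -> torch.Tensor:
        # Destination row i copies source row dst2src_map[i]; duplicate entries
        # in the map replicate a token for every expert it was routed to.
        return input_tensor.index_select(0, dst2src_map.to(torch.long))

    x = torch.arange(12, dtype=torch.float32).reshape(3, 4)
    mapping = torch.tensor([2, 0, 0, 1], dtype=torch.int32)
    print(shuffle_rows_reference(x, mapping))  # rows 2, 0, 0, 1 of x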

@ -14,12 +14,13 @@
__VA_ARGS__(); \
break; \
}
#define MOE_DISPATCH_FLOAT_CASE(...) \
MOE_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
MOE_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
MOE_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
MOE_DISPATCH_CASE(at::ScalarType::Float8_e5m2, __VA_ARGS__) \
MOE_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__)
#define MOE_DISPATCH_FLOAT_CASE(...) \
MOE_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
MOE_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
MOE_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
MOE_DISPATCH_CASE(at::ScalarType::Float8_e5m2, __VA_ARGS__) \
MOE_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) \
MOE_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)
#define MOE_DISPATCH(TYPE, ...) \
MOE_SWITCH(TYPE, MOE_DISPATCH_FLOAT_CASE(__VA_ARGS__))
@ -39,6 +40,11 @@ template <>
struct ScalarType2CudaType<at::ScalarType::BFloat16> {
using type = __nv_bfloat16;
};
// uint8 for packed fp4
template <>
struct ScalarType2CudaType<at::ScalarType::Byte> {
using type = uint8_t;
};
// #if __CUDA_ARCH__ >= 890
// fp8

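The new Byte case exists because NVFP4 data travels through PyTorch as packed bytes: assuming the usual packing of two 4-bit values per uint8, a logical (m, k) FP4 matrix is stored as an (m, k // 2) torch.uint8 tensor, and mapping at::ScalarType::Byte to uint8_t lets MOE_DISPATCH-based kernels such as shuffle_rows move those rows. A small illustration of that assumption (not vLLM code):

    import torch

    # Assumed packing convention: two fp4 values per byte, so the packed
    # tensor's last dimension is half the logical width.
    m, k = 4, 32
    packed_fp4 = torch.randint(0, 256, (m, k // 2), dtype=torch.uint8)
    assert packed_fp4.element_size() == 1   # one byte carries two fp4 values
    assert packed_fp4.shape == (m, k // 2)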

@ -81,6 +81,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
m.def("moe_permute_unpermute_supported() -> bool");
m.impl("moe_permute_unpermute_supported", &moe_permute_unpermute_supported);
// Row shuffle for MoE
m.def(
"shuffle_rows(Tensor input_tensor, Tensor dst2src_map, Tensor! "
"output_tensor) -> ()");
m.impl("shuffle_rows", torch::kCUDA, &shuffle_rows);
#endif
}


@ -248,7 +248,8 @@ void get_cutlass_moe_mm_data(
const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
torch::Tensor& input_permutation, torch::Tensor& output_permutation,
const int64_t num_experts, const int64_t n, const int64_t k);
const int64_t num_experts, const int64_t n, const int64_t k,
const std::optional<torch::Tensor>& blockscale_offsets);
void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,


@ -45,6 +45,23 @@ __global__ void compute_expert_offsets(
}
}
__global__ void compute_expert_blockscale_offsets(
const int32_t* __restrict__ problem_sizes1, int32_t* expert_offsets,
int32_t* blockscale_offsets, int32_t* atomic_buffer,
const int num_experts) {
int32_t tot_offset = 0;
int32_t tot_offset_round = 0;
expert_offsets[0] = 0;
blockscale_offsets[0] = 0;
for (int i = 0; i < num_experts; ++i) {
atomic_buffer[i] = tot_offset;
tot_offset += problem_sizes1[i * 3];
expert_offsets[i + 1] = tot_offset;
tot_offset_round += (problem_sizes1[i * 3] + (128 - 1)) / 128 * 128;
blockscale_offsets[i + 1] = tot_offset_round;
}
}
__global__ void compute_arg_sorts(const int* __restrict__ topk_ids,
const int32_t* __restrict__ expert_offsets,
int32_t* input_permutation,
@ -77,7 +94,8 @@ void get_cutlass_moe_mm_data_caller(
const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
torch::Tensor& input_permutation, torch::Tensor& output_permutation,
const int64_t num_experts, const int64_t n, const int64_t k) {
const int64_t num_experts, const int64_t n, const int64_t k,
const std::optional<torch::Tensor>& blockscale_offsets) {
auto stream = at::cuda::getCurrentCUDAStream(topk_ids.device().index());
auto options_int32 =
torch::TensorOptions().dtype(torch::kInt32).device(topk_ids.device());
@ -89,10 +107,18 @@ void get_cutlass_moe_mm_data_caller(
static_cast<int32_t*>(problem_sizes1.data_ptr()),
static_cast<int32_t*>(problem_sizes2.data_ptr()),
static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(), n, k);
compute_expert_offsets<<<1, 1, 0, stream>>>(
static_cast<const int32_t*>(problem_sizes1.data_ptr()),
static_cast<int32_t*>(expert_offsets.data_ptr()),
static_cast<int32_t*>(atomic_buffer.data_ptr()), num_experts);
if (blockscale_offsets.has_value()) {
compute_expert_blockscale_offsets<<<1, 1, 0, stream>>>(
static_cast<const int32_t*>(problem_sizes1.data_ptr()),
static_cast<int32_t*>(expert_offsets.data_ptr()),
static_cast<int32_t*>(blockscale_offsets.value().data_ptr()),
static_cast<int32_t*>(atomic_buffer.data_ptr()), num_experts);
} else {
compute_expert_offsets<<<1, 1, 0, stream>>>(
static_cast<const int32_t*>(problem_sizes1.data_ptr()),
static_cast<int32_t*>(expert_offsets.data_ptr()),
static_cast<int32_t*>(atomic_buffer.data_ptr()), num_experts);
}
compute_arg_sorts<<<num_experts, num_threads, 0, stream>>>(
static_cast<const int32_t*>(topk_ids.data_ptr()),
static_cast<const int32_t*>(expert_offsets.data_ptr()),

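compute_expert_blockscale_offsets produces, next to the usual running sum of rows per expert (expert_offsets), a second prefix sum in which each expert's row count is first rounded up to a multiple of 128; the FP4 block scales are laid out with that padding. This matches the host-side rounding and cumsum that this commit removes from cutlass_moe_fp4 further below, now computed on the GPU instead. A host-side reference of the same arithmetic, for illustration only:

    import torch

    def expert_and_blockscale_offsets(problem_sizes1: torch.Tensor):
        # problem_sizes1 has shape (num_experts, 3); column 0 holds the number
        # of rows assigned to each expert (problem_sizes1[i * 3] in the kernel).
        m_per_expert = problem_sizes1[:, 0].to(torch.int64)
        num_experts = m_per_expert.shape[0]

        expert_offsets = torch.zeros(num_experts + 1, dtype=torch.int32)
        expert_offsets[1:] = torch.cumsum(m_per_expert, dim=0)

        # Block-scale rows are padded per expert to a multiple of 128 before
        # the prefix sum, mirroring (m + 127) / 128 * 128 in the kernel.
        rounded = (m_per_expert + 127) // 128 * 128
        blockscale_offsets = torch.zeros(num_experts + 1, dtype=torch.int32)
        blockscale_offsets[1:] = torch.cumsum(rounded, dim=0)
        return expert_offsets, blockscale_offsets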

@ -54,7 +54,8 @@ void get_cutlass_moe_mm_data_caller(
const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
torch::Tensor& input_permutation, torch::Tensor& output_permutation,
const int64_t num_experts, const int64_t n, const int64_t k);
const int64_t num_experts, const int64_t n, const int64_t k,
const std::optional<torch::Tensor>& blockscale_offsets);
#endif
void cutlass_scaled_mm_azp_sm75(torch::Tensor& c, torch::Tensor const& a,
@ -224,7 +225,8 @@ void get_cutlass_moe_mm_data(
const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
torch::Tensor& input_permutation, torch::Tensor& output_permutation,
const int64_t num_experts, const int64_t n, const int64_t k) {
const int64_t num_experts, const int64_t n, const int64_t k,
const std::optional<torch::Tensor>& blockscale_offsets) {
// This function currently gets compiled only if we have a valid cutlass moe
// mm to run it for.
int32_t version_num = get_sm_version_num();
@ -232,7 +234,8 @@ void get_cutlass_moe_mm_data(
(defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM90)
get_cutlass_moe_mm_data_caller(topk_ids, expert_offsets, problem_sizes1,
problem_sizes2, input_permutation,
output_permutation, num_experts, n, k);
output_permutation, num_experts, n, k,
blockscale_offsets);
return;
#endif
TORCH_CHECK_NOT_IMPLEMENTED(


@ -450,7 +450,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor! problem_sizes1, Tensor! problem_sizes2, "
" Tensor! input_permutation, "
" Tensor! output_permutation, int num_experts, "
" int n, int k) -> ()",
" int n, int k, Tensor? blockscale_offsets) -> ()",
{stride_tag});
ops.impl("get_cutlass_moe_mm_data", torch::kCUDA, &get_cutlass_moe_mm_data);

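The trailing Tensor? blockscale_offsets makes the new argument optional at the schema level, so existing callers can keep passing None while the FP4 path supplies a device buffer to be filled. One way to confirm the registered signature (a sketch that assumes a built vLLM wheel; the attribute access follows standard torch.ops conventions):

    import torch
    import vllm._custom_ops  # noqa: F401  (loads the compiled _C extension)

    # Expected to end in "..., int n, int k, Tensor? blockscale_offsets) -> ()".
    print(torch.ops._C.get_cutlass_moe_mm_data.default._schema)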

@ -80,7 +80,10 @@ def test_cutlass_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int,
w2[expert], w2_gs[expert])
score = torch.randn((m, e), device="cuda", dtype=dtype)
topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False)
topk_weights, topk_ids, _ = fused_topk(a,
score,
topk,
renormalize=False)
a1_gs = torch.ones((e, ), device="cuda", dtype=torch.float32)
a2_gs = torch.ones((e, ), device="cuda", dtype=torch.float32)


@ -845,11 +845,16 @@ def cutlass_scaled_sparse_mm(
return out
def get_cutlass_moe_mm_data(
topk_ids: torch.Tensor, expert_offsets: torch.Tensor,
problem_sizes1: torch.Tensor, problem_sizes2: torch.Tensor,
input_permutation: torch.Tensor, output_permutation: torch.Tensor,
num_experts: int, n: int, k: int):
def get_cutlass_moe_mm_data(topk_ids: torch.Tensor,
expert_offsets: torch.Tensor,
problem_sizes1: torch.Tensor,
problem_sizes2: torch.Tensor,
input_permutation: torch.Tensor,
output_permutation: torch.Tensor,
num_experts: int,
n: int,
k: int,
blockscale_offsets: Optional[torch.Tensor] = None):
"""
Prepare data necessary to perform CUTLASS grouped matrix multiplications
used in CUTLASS-based fused MoE.
@ -867,12 +872,31 @@ def get_cutlass_moe_mm_data(
before executing the MMs.
- output_permutation: Permutation that must be used to shuffle the output
after executing the MMs.
- blockscale_offsets: Optional argument passed for fp4 moe. Indices that
mark at which block scale index each expert begins
its computation. The number of block scale rows
computed with expert E is blockscale_offsets[E + 1] -
blockscale_offsets[E]
"""
return torch.ops._C.get_cutlass_moe_mm_data(topk_ids, expert_offsets,
problem_sizes1, problem_sizes2,
input_permutation,
output_permutation,
num_experts, n, k)
num_experts, n, k,
blockscale_offsets)
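
A hedged usage sketch of the extended wrapper, showing the scratch buffers the FP4 path allocates before the call. The (e + 1), (e, 3) and topk_ids.numel() shapes mirror cutlass_moe_fp4 below and are assumptions of this sketch rather than a documented contract:

    import torch
    from vllm import _custom_ops as ops

    def prepare_fp4_moe_mm_data(topk_ids: torch.Tensor, e: int, n: int, k: int):
        """Allocate int32 scratch buffers and let the CUDA kernel fill them."""
        device = topk_ids.device
        expert_offsets = torch.empty(e + 1, dtype=torch.int32, device=device)
        blockscale_offsets = torch.empty(e + 1, dtype=torch.int32, device=device)
        problem_sizes1 = torch.empty((e, 3), dtype=torch.int32, device=device)
        problem_sizes2 = torch.empty((e, 3), dtype=torch.int32, device=device)
        a_map = torch.empty(topk_ids.numel(), dtype=torch.int32, device=device)
        c_map = torch.empty(topk_ids.numel(), dtype=torch.int32, device=device)
        ops.get_cutlass_moe_mm_data(topk_ids, expert_offsets, problem_sizes1,
                                    problem_sizes2, a_map, c_map, e, n, k,
                                    blockscale_offsets)
        return (expert_offsets, blockscale_offsets, problem_sizes1,
                problem_sizes2, a_map, c_map)
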
def shuffle_rows(input_tensor: torch.Tensor, dst2src_map: torch.Tensor):
"""
Shuffle and expand the input tensor according to the dst2src_map and store the result in output_tensor.
This is used in MoE to permute the input tensor before performing grouped matrix multiplications.
"""
num_tokens_permuted = dst2src_map.shape[0]
output_tensor = torch.empty((num_tokens_permuted, input_tensor.shape[1]),
device=input_tensor.device,
dtype=input_tensor.dtype)
torch.ops._moe_C.shuffle_rows(input_tensor, dst2src_map, output_tensor)
return output_tensor
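
With the maps used here, shuffle_rows behaves like a row gather that may also expand its input (output row i equals input_tensor[dst2src_map[i]]), which is what replaces the a[a_map] and c2[c_map] advanced indexing in cutlass_moe_fp4. A small equivalence check, assuming a CUDA build of vLLM:

    import torch
    from vllm import _custom_ops as ops

    x = torch.randn(6, 128, device="cuda", dtype=torch.half)
    dst2src_map = torch.tensor([3, 3, 0, 5, 1, 2, 4, 0],
                               device="cuda", dtype=torch.int32)

    shuffled = ops.shuffle_rows(x, dst2src_map)
    assert shuffled.shape == (dst2src_map.numel(), x.shape[1])
    # The dedicated kernel should match PyTorch advanced indexing.
    torch.testing.assert_close(shuffled, x[dst2src_map.long()])
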
def cutlass_moe_mm(out_tensors: torch.Tensor, a_tensors: torch.Tensor,
@ -1124,14 +1148,12 @@ def scaled_fp4_experts_quant(
expert_offsets: torch.Tensor,
blockscale_offsets: torch.Tensor,
topk: int,
expert_map: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Quantize input tensor to FP4 and return quantized tensor and scale, for
packed MoE Inputs.
Args:
input: The input tensor to be quantized to FP4
expert_map: The expert map tensor
input_tensor: The input tensor to be quantized to FP4
input_global_scale: A scalar scaling factor for the entire tensor.
expert_offsets: The expert offsets tensor
blockscale_offsets: The blockscale offsets tensor
@ -1143,14 +1165,13 @@ def scaled_fp4_experts_quant(
assert input_tensor.ndim == 2, (
f'input.ndim needs to be == 2, but got {input_tensor.ndim}.')
input_tensor = input_tensor[
expert_map] if expert_map is not None else input_tensor
m_numtopk, k = input_tensor.shape
# Control the maximum number of tokens per expert supported by the
# NVFP4 MoE Expert Quantization. This is used to prevent the kernel
# from running out of memory. This value can also be increased to support
# larger models.
MAX_TOKENS_PER_EXPERT = envs.VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE
m_numtopk, k = input_tensor.shape
assert (m_numtopk <= MAX_TOKENS_PER_EXPERT * topk), (
f"m_numtopk must be less than MAX_TOKENS_PER_EXPERT("
f"{MAX_TOKENS_PER_EXPERT})"


@ -333,6 +333,7 @@ def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor,
num_topk = topk_ids.shape[1]
expert_offsets = torch.empty((e + 1), dtype=torch.int32, device=device)
blockscale_offsets = torch.empty((e + 1), dtype=torch.int32, device=device)
# Problem size: (num_experts, (m,2n,k))
problem_sizes1 = torch.empty((e, 3), dtype=torch.int32, device=device)
# Problem size: (num_experts, (m,n,k))
@ -344,12 +345,10 @@ def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor,
# problem shapes should have [m, n, k]
# Note that problem sizes are based on logical number of elements.
ops.get_cutlass_moe_mm_data(topk_ids, expert_offsets, problem_sizes1,
problem_sizes2, a_map, c_map, e, n, k)
problem_sizes2, a_map, c_map, e, n, k,
blockscale_offsets)
tokens_per_expert = problem_sizes1[:, 0]
rounded_tokens_per_expert = (tokens_per_expert + (128 - 1)) // 128 * 128
blockscale_offsets = torch.zeros(e + 1, dtype=torch.int32, device=device)
blockscale_offsets[1:] = torch.cumsum(rounded_tokens_per_expert, dim=0)
a = ops.shuffle_rows(a, a_map)
rep_a_fp4, rep_a_blockscale = ops.scaled_fp4_experts_quant(
a,
@ -357,7 +356,7 @@ def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor,
expert_offsets,
blockscale_offsets,
num_topk,
expert_map=a_map)
)
c1 = ops.cutlass_fp4_moe_mm(rep_a_fp4, w1_fp4, rep_a_blockscale,
w1_blockscale, w1_alphas, problem_sizes1,
@ -378,6 +377,8 @@ def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor,
w2_alphas, problem_sizes2, expert_offsets[:-1],
blockscale_offsets[:-1], out_dtype, device)
del int_fp4, int_blockscale
out = (c2[c_map].view(m, num_topk, k) *
c2 = ops.shuffle_rows(c2, c_map)
out = (c2.view(m, num_topk, k) *
topk_weights.view(m, num_topk, 1).half()).sum(dim=1)
return out.to(dtype=out_dtype)
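
Taken together, the FP4 path now fills blockscale_offsets on the device through get_cutlass_moe_mm_data, permutes the activations once with shuffle_rows before quantization (so scaled_fp4_experts_quant no longer takes an expert_map), and un-permutes the output with shuffle_rows as well. A small check that the new epilogue matches the previous c2[c_map] indexing, assuming a CUDA build of vLLM:

    import torch
    from vllm import _custom_ops as ops

    m, num_topk, k = 4, 2, 256
    c2 = torch.randn(m * num_topk, k, device="cuda", dtype=torch.half)
    c_map = torch.randperm(m * num_topk, device="cuda").to(torch.int32)
    topk_weights = torch.rand(m, num_topk, device="cuda", dtype=torch.float32)

    # New epilogue: dedicated shuffle kernel instead of c2[c_map] indexing.
    out_new = (ops.shuffle_rows(c2, c_map).view(m, num_topk, k) *
               topk_weights.view(m, num_topk, 1).half()).sum(dim=1)
    # Previous epilogue, kept as the reference.
    out_old = (c2[c_map.long()].view(m, num_topk, k) *
               topk_weights.view(m, num_topk, 1).half()).sum(dim=1)
    torch.testing.assert_close(out_new, out_old)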