[Hardware][NVIDIA] FP4 MoE kernel optimization (#19110)
Signed-off-by: Chiyue Wei <chiyuew@nvidia.com> Co-authored-by: Chiyue Wei <chiyuew@nvidia.com>
parent ec89524f50 · commit 61059bee40
@@ -91,7 +91,7 @@ def bench_run(
     score = torch.randn((m, num_experts), device=device, dtype=dtype)

-    topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False)
+    topk_weights, topk_ids, _ = fused_topk(a, score, topk, renormalize=False)

     quant_blocksize = 16
     w1_blockscale = torch.empty(
@@ -30,4 +30,8 @@ torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
                              int64_t BLOCK_SIZE_K, int64_t bit);
 #endif

 bool moe_permute_unpermute_supported();
+
+void shuffle_rows(const torch::Tensor& input_tensor,
+                  const torch::Tensor& dst2src_map,
+                  torch::Tensor& output_tensor);
@@ -130,6 +130,62 @@ void moe_unpermute(
   });
 }

+template <typename T>
+__global__ void shuffleInputRowsKernel(const T* input,
+                                       const int32_t* dst2src_map, T* output,
+                                       int64_t num_src_rows,
+                                       int64_t num_dst_rows, int64_t num_cols) {
+  int64_t dest_row_idx = blockIdx.x;
+  int64_t const source_row_idx = dst2src_map[dest_row_idx];
+
+  if (blockIdx.x < num_dst_rows) {
+    // Load 128-bits per thread
+    constexpr int64_t ELEM_PER_THREAD = 128 / sizeof(T) / 8;
+    using DataElem = cutlass::Array<T, ELEM_PER_THREAD>;
+
+    // Duplicate and permute rows
+    auto const* source_row_ptr =
+        reinterpret_cast<DataElem const*>(input + source_row_idx * num_cols);
+    auto* dest_row_ptr =
+        reinterpret_cast<DataElem*>(output + dest_row_idx * num_cols);
+
+    int64_t const start_offset = threadIdx.x;
+    int64_t const stride = blockDim.x;
+    int64_t const num_elems_in_col = num_cols / ELEM_PER_THREAD;
+
+    for (int elem_index = start_offset; elem_index < num_elems_in_col;
+         elem_index += stride) {
+      dest_row_ptr[elem_index] = source_row_ptr[elem_index];
+    }
+  }
+}
+
+void shuffle_rows(const torch::Tensor& input_tensor,
+                  const torch::Tensor& dst2src_map,
+                  torch::Tensor& output_tensor) {
+  TORCH_CHECK(input_tensor.scalar_type() == output_tensor.scalar_type(),
+              "Input and output tensors must have the same data type");
+
+  auto stream = at::cuda::getCurrentCUDAStream().stream();
+  int64_t const blocks = output_tensor.size(0);
+  int64_t const threads = 256;
+  int64_t const num_dest_rows = output_tensor.size(0);
+  int64_t const num_src_rows = input_tensor.size(0);
+  int64_t const num_cols = input_tensor.size(1);
+
+  TORCH_CHECK(!(num_cols % (128 / sizeof(input_tensor.scalar_type()) / 8)),
+              "num_cols must be divisible by 128 / "
+              "sizeof(input_tensor.scalar_type()) / 8");
+
+  MOE_DISPATCH(input_tensor.scalar_type(), [&] {
+    shuffleInputRowsKernel<scalar_t><<<blocks, threads, 0, stream>>>(
+        reinterpret_cast<scalar_t*>(input_tensor.data_ptr()),
+        dst2src_map.data_ptr<int32_t>(),
+        reinterpret_cast<scalar_t*>(output_tensor.data_ptr()), num_src_rows,
+        num_dest_rows, num_cols);
+  });
+}
+
 #else

 void moe_permute(const torch::Tensor& input, const torch::Tensor& topk_weights,
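For readers skimming the diff: the new kernel copies whole rows, where destination row i of the output is source row dst2src_map[i] of the input (rows may be duplicated to expand tokens). A minimal PyTorch sketch of the same semantics, using hypothetical names and purely for illustration:

import torch

def shuffle_rows_reference(input_tensor: torch.Tensor,
                           dst2src_map: torch.Tensor) -> torch.Tensor:
    # Destination row i is a copy of source row dst2src_map[i]; rows may be
    # duplicated, so the output can have more rows than the input.
    return input_tensor.index_select(0, dst2src_map.to(torch.long))

# Example: 3 source rows expanded/permuted into 4 destination rows.
x = torch.arange(12, dtype=torch.float32).reshape(3, 4)
mapping = torch.tensor([2, 0, 0, 1], dtype=torch.int32)
print(shuffle_rows_reference(x, mapping))

The CUDA version gets its throughput from the 128-bit vectorized loads and from using one thread block per destination row; the gather above is only a correctness reference.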
@@ -14,12 +14,13 @@
     __VA_ARGS__();            \
     break;                    \
   }
-#define MOE_DISPATCH_FLOAT_CASE(...)                            \
-  MOE_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)         \
-  MOE_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)          \
-  MOE_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)      \
-  MOE_DISPATCH_CASE(at::ScalarType::Float8_e5m2, __VA_ARGS__)   \
-  MOE_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__)
+#define MOE_DISPATCH_FLOAT_CASE(...)                            \
+  MOE_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)         \
+  MOE_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)          \
+  MOE_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)      \
+  MOE_DISPATCH_CASE(at::ScalarType::Float8_e5m2, __VA_ARGS__)   \
+  MOE_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) \
+  MOE_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)

 #define MOE_DISPATCH(TYPE, ...) \
   MOE_SWITCH(TYPE, MOE_DISPATCH_FLOAT_CASE(__VA_ARGS__))
@@ -39,6 +40,11 @@ template <>
 struct ScalarType2CudaType<at::ScalarType::BFloat16> {
   using type = __nv_bfloat16;
 };
+// uint8 for packed fp4
+template <>
+struct ScalarType2CudaType<at::ScalarType::Byte> {
+  using type = uint8_t;
+};

 // #if __CUDA_ARCH__ >= 890
 // fp8
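The new Byte case exists because NVFP4 data moves through these kernels packed, two 4-bit values per uint8 (per the "uint8 for packed fp4" comment above). A rough sketch of such a packing convention, assuming low-nibble-first order; this is an illustration only, not the encoder the kernels actually use:

import torch

def pack_fp4_codes(codes: torch.Tensor) -> torch.Tensor:
    # codes: integer tensor of 4-bit code points (0..15), last dim even.
    # Two adjacent codes share one uint8: low nibble first, high nibble second.
    lo = codes[..., 0::2].to(torch.uint8)
    hi = codes[..., 1::2].to(torch.uint8)
    return lo | (hi << 4)

codes = torch.tensor([[1, 15, 7, 2]])
print(pack_fp4_codes(codes))  # tensor([[241,  39]], dtype=torch.uint8)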
@@ -81,6 +81,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
   m.def("moe_permute_unpermute_supported() -> bool");
   m.impl("moe_permute_unpermute_supported", &moe_permute_unpermute_supported);

+  // Row shuffle for MoE
+  m.def(
+      "shuffle_rows(Tensor input_tensor, Tensor dst2src_map, Tensor! "
+      "output_tensor) -> ()");
+  m.impl("shuffle_rows", torch::kCUDA, &shuffle_rows);
+
 #endif
 }
@@ -248,7 +248,8 @@ void get_cutlass_moe_mm_data(
     const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
     torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
     torch::Tensor& input_permutation, torch::Tensor& output_permutation,
-    const int64_t num_experts, const int64_t n, const int64_t k);
+    const int64_t num_experts, const int64_t n, const int64_t k,
+    const std::optional<torch::Tensor>& blockscale_offsets);

 void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a,
                            torch::Tensor const& b,
@@ -45,6 +45,23 @@ __global__ void compute_expert_offsets(
   }
 }

+__global__ void compute_expert_blockscale_offsets(
+    const int32_t* __restrict__ problem_sizes1, int32_t* expert_offsets,
+    int32_t* blockscale_offsets, int32_t* atomic_buffer,
+    const int num_experts) {
+  int32_t tot_offset = 0;
+  int32_t tot_offset_round = 0;
+  expert_offsets[0] = 0;
+  blockscale_offsets[0] = 0;
+  for (int i = 0; i < num_experts; ++i) {
+    atomic_buffer[i] = tot_offset;
+    tot_offset += problem_sizes1[i * 3];
+    expert_offsets[i + 1] = tot_offset;
+    tot_offset_round += (problem_sizes1[i * 3] + (128 - 1)) / 128 * 128;
+    blockscale_offsets[i + 1] = tot_offset_round;
+  }
+}
+
 __global__ void compute_arg_sorts(const int* __restrict__ topk_ids,
                                   const int32_t* __restrict__ expert_offsets,
                                   int32_t* input_permutation,
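compute_expert_blockscale_offsets mirrors compute_expert_offsets, but additionally rounds each expert's row count up to a multiple of 128 before accumulating, so every expert's block scales start at a 128-row-aligned offset. A host-side sketch of the same arithmetic, reproducing what the old Python path computed with cumsum (names are illustrative):

import torch

def blockscale_offsets_reference(tokens_per_expert: torch.Tensor) -> torch.Tensor:
    # Round each expert's token count up to a multiple of 128, then prefix-sum,
    # matching what compute_expert_blockscale_offsets produces on device.
    rounded = (tokens_per_expert + 127) // 128 * 128
    offsets = torch.zeros(tokens_per_expert.numel() + 1, dtype=torch.int32)
    offsets[1:] = torch.cumsum(rounded, dim=0)
    return offsets

print(blockscale_offsets_reference(torch.tensor([5, 130, 0, 128])))
# tensor([  0, 128, 384, 384, 512], dtype=torch.int32)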
@@ -77,7 +94,8 @@ void get_cutlass_moe_mm_data_caller(
     const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
     torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
     torch::Tensor& input_permutation, torch::Tensor& output_permutation,
-    const int64_t num_experts, const int64_t n, const int64_t k) {
+    const int64_t num_experts, const int64_t n, const int64_t k,
+    const std::optional<torch::Tensor>& blockscale_offsets) {
   auto stream = at::cuda::getCurrentCUDAStream(topk_ids.device().index());
   auto options_int32 =
       torch::TensorOptions().dtype(torch::kInt32).device(topk_ids.device());
@@ -89,10 +107,18 @@ void get_cutlass_moe_mm_data_caller(
       static_cast<int32_t*>(problem_sizes1.data_ptr()),
       static_cast<int32_t*>(problem_sizes2.data_ptr()),
       static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(), n, k);
-  compute_expert_offsets<<<1, 1, 0, stream>>>(
-      static_cast<const int32_t*>(problem_sizes1.data_ptr()),
-      static_cast<int32_t*>(expert_offsets.data_ptr()),
-      static_cast<int32_t*>(atomic_buffer.data_ptr()), num_experts);
+  if (blockscale_offsets.has_value()) {
+    compute_expert_blockscale_offsets<<<1, 1, 0, stream>>>(
+        static_cast<const int32_t*>(problem_sizes1.data_ptr()),
+        static_cast<int32_t*>(expert_offsets.data_ptr()),
+        static_cast<int32_t*>(blockscale_offsets.value().data_ptr()),
+        static_cast<int32_t*>(atomic_buffer.data_ptr()), num_experts);
+  } else {
+    compute_expert_offsets<<<1, 1, 0, stream>>>(
+        static_cast<const int32_t*>(problem_sizes1.data_ptr()),
+        static_cast<int32_t*>(expert_offsets.data_ptr()),
+        static_cast<int32_t*>(atomic_buffer.data_ptr()), num_experts);
+  }
   compute_arg_sorts<<<num_experts, num_threads, 0, stream>>>(
       static_cast<const int32_t*>(topk_ids.data_ptr()),
       static_cast<const int32_t*>(expert_offsets.data_ptr()),
@@ -54,7 +54,8 @@ void get_cutlass_moe_mm_data_caller(
     const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
     torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
     torch::Tensor& input_permutation, torch::Tensor& output_permutation,
-    const int64_t num_experts, const int64_t n, const int64_t k);
+    const int64_t num_experts, const int64_t n, const int64_t k,
+    const std::optional<torch::Tensor>& blockscale_offsets);
 #endif

 void cutlass_scaled_mm_azp_sm75(torch::Tensor& c, torch::Tensor const& a,
@@ -224,7 +225,8 @@ void get_cutlass_moe_mm_data(
     const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
     torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
     torch::Tensor& input_permutation, torch::Tensor& output_permutation,
-    const int64_t num_experts, const int64_t n, const int64_t k) {
+    const int64_t num_experts, const int64_t n, const int64_t k,
+    const std::optional<torch::Tensor>& blockscale_offsets) {
   // This function currently gets compiled only if we have a valid cutlass moe
   // mm to run it for.
   int32_t version_num = get_sm_version_num();
@@ -232,7 +234,8 @@ void get_cutlass_moe_mm_data(
     (defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM90)
   get_cutlass_moe_mm_data_caller(topk_ids, expert_offsets, problem_sizes1,
                                  problem_sizes2, input_permutation,
-                                 output_permutation, num_experts, n, k);
+                                 output_permutation, num_experts, n, k,
+                                 blockscale_offsets);
   return;
 #endif
   TORCH_CHECK_NOT_IMPLEMENTED(
@@ -450,7 +450,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       " Tensor! problem_sizes1, Tensor! problem_sizes2, "
       " Tensor! input_permutation, "
       " Tensor! output_permutation, int num_experts, "
-      " int n, int k) -> ()",
+      " int n, int k, Tensor? blockscale_offsets) -> ()",
       {stride_tag});
   ops.impl("get_cutlass_moe_mm_data", torch::kCUDA, &get_cutlass_moe_mm_data);
@@ -80,7 +80,10 @@ def test_cutlass_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int,
                                           w2[expert], w2_gs[expert])

     score = torch.randn((m, e), device="cuda", dtype=dtype)
-    topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False)
+    topk_weights, topk_ids, _ = fused_topk(a,
+                                           score,
+                                           topk,
+                                           renormalize=False)

     a1_gs = torch.ones((e, ), device="cuda", dtype=torch.float32)
     a2_gs = torch.ones((e, ), device="cuda", dtype=torch.float32)
@@ -845,11 +845,16 @@ def cutlass_scaled_sparse_mm(
     return out


-def get_cutlass_moe_mm_data(
-        topk_ids: torch.Tensor, expert_offsets: torch.Tensor,
-        problem_sizes1: torch.Tensor, problem_sizes2: torch.Tensor,
-        input_permutation: torch.Tensor, output_permutation: torch.Tensor,
-        num_experts: int, n: int, k: int):
+def get_cutlass_moe_mm_data(topk_ids: torch.Tensor,
+                            expert_offsets: torch.Tensor,
+                            problem_sizes1: torch.Tensor,
+                            problem_sizes2: torch.Tensor,
+                            input_permutation: torch.Tensor,
+                            output_permutation: torch.Tensor,
+                            num_experts: int,
+                            n: int,
+                            k: int,
+                            blockscale_offsets: Optional[torch.Tensor] = None):
     """
     Prepare data necessary to perform CUTLASS grouped matrix multiplications
     used in CUTLASS-based fused MoE.
@@ -867,12 +872,31 @@ def get_cutlass_moe_mm_data(
                           before executing the MMs.
     - output_permutation: Permutation that must be used to shuffle the output
                           after executing the MMs.
+    - blockscale_offsets: Optional argument passed for fp4 moe. Indices that
+                          mark at which block scale index each expert begins
+                          its computation. The number of block scale rows
+                          computed with expert E is blockscale_offsets[E + 1] -
+                          blockscale_offsets[E]
     """
     return torch.ops._C.get_cutlass_moe_mm_data(topk_ids, expert_offsets,
                                                 problem_sizes1, problem_sizes2,
                                                 input_permutation,
                                                 output_permutation,
-                                                num_experts, n, k)
+                                                num_experts, n, k,
+                                                blockscale_offsets)
+
+
+def shuffle_rows(input_tensor: torch.Tensor, dst2src_map: torch.Tensor):
+    """
+    Shuffle and expand the input tensor according to the dst2src_map and store the result in output_tensor.
+    This is used in MoE to permute the input tensor before performing grouped matrix multiplications.
+    """
+    num_tokens_permuted = dst2src_map.shape[0]
+    output_tensor = torch.empty((num_tokens_permuted, input_tensor.shape[1]),
+                                device=input_tensor.device,
+                                dtype=input_tensor.dtype)
+    torch.ops._moe_C.shuffle_rows(input_tensor, dst2src_map, output_tensor)
+    return output_tensor


 def cutlass_moe_mm(out_tensors: torch.Tensor, a_tensors: torch.Tensor,
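A rough usage sketch of the new Python wrapper; shapes and values are made up for illustration, and the op requires a CUDA build of vLLM with these kernels compiled in:

import torch
from vllm import _custom_ops as ops  # assumes a CUDA build of vLLM

# Illustrative shapes: 8 tokens, hidden size 256, top-2 routing -> 16 permuted rows.
a = torch.randn(8, 256, device="cuda", dtype=torch.half)
a_map = torch.randint(0, 8, (16, ), device="cuda", dtype=torch.int32)

permuted = ops.shuffle_rows(a, a_map)          # runs the CUDA kernel above
assert torch.equal(permuted, a[a_map.long()])  # should match a plain gather

The result matches plain fancy indexing; the dedicated kernel exists so the permute/expand step copies rows with vectorized loads instead of going through a generic gather.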
@@ -1124,14 +1148,12 @@ def scaled_fp4_experts_quant(
     expert_offsets: torch.Tensor,
     blockscale_offsets: torch.Tensor,
     topk: int,
-    expert_map: Optional[torch.Tensor] = None,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     """
     Quantize input tensor to FP4 and return quantized tensor and scale, for
     packed MoE Inputs.
     Args:
-        input: The input tensor to be quantized to FP4
-        expert_map: The expert map tensor
+        input_tensor: The input tensor to be quantized to FP4
+        input_global_scale: A scalar scaling factor for the entire tensor.
         expert_offsets: The expert offsets tensor
         blockscale_offsets: The blockscale offsets tensor
@@ -1143,14 +1165,13 @@ def scaled_fp4_experts_quant(
     assert input_tensor.ndim == 2, (
         f'input.ndim needs to be == 2, but got {input_tensor.ndim}.')

-    input_tensor = input_tensor[
-        expert_map] if expert_map is not None else input_tensor
-    m_numtopk, k = input_tensor.shape
     # Control the maximum number of tokens per expert supported by the
     # NVFP4 MoE Expert Quantization. This is used to prevent the kernel
     # from running out of memory. This value can also be increased to support
     # larger models.
     MAX_TOKENS_PER_EXPERT = envs.VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE
+    m_numtopk, k = input_tensor.shape
+
     assert (m_numtopk <= MAX_TOKENS_PER_EXPERT * topk), (
         f"m_numtopk must be less than MAX_TOKENS_PER_EXPERT("
         f"{MAX_TOKENS_PER_EXPERT})"
@@ -333,6 +333,7 @@ def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor,
     num_topk = topk_ids.shape[1]

     expert_offsets = torch.empty((e + 1), dtype=torch.int32, device=device)
+    blockscale_offsets = torch.empty((e + 1), dtype=torch.int32, device=device)
     # Problem size: (num_experts, (m,2n,k))
     problem_sizes1 = torch.empty((e, 3), dtype=torch.int32, device=device)
     # Problem size: (num_experts, (m,n,k))
@@ -344,12 +345,10 @@ def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor,
     # problem shapes should have [m, n, k]
     # Note that problem sizes are based on logical number of elements.
     ops.get_cutlass_moe_mm_data(topk_ids, expert_offsets, problem_sizes1,
-                                problem_sizes2, a_map, c_map, e, n, k)
+                                problem_sizes2, a_map, c_map, e, n, k,
+                                blockscale_offsets)

-    tokens_per_expert = problem_sizes1[:, 0]
-    rounded_tokens_per_expert = (tokens_per_expert + (128 - 1)) // 128 * 128
-    blockscale_offsets = torch.zeros(e + 1, dtype=torch.int32, device=device)
-    blockscale_offsets[1:] = torch.cumsum(rounded_tokens_per_expert, dim=0)
+    a = ops.shuffle_rows(a, a_map)

     rep_a_fp4, rep_a_blockscale = ops.scaled_fp4_experts_quant(
         a,
@@ -357,7 +356,7 @@ def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor,
         expert_offsets,
         blockscale_offsets,
         num_topk,
-        expert_map=a_map)
+    )

     c1 = ops.cutlass_fp4_moe_mm(rep_a_fp4, w1_fp4, rep_a_blockscale,
                                 w1_blockscale, w1_alphas, problem_sizes1,
@@ -378,6 +377,8 @@ def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor,
                                 w2_alphas, problem_sizes2, expert_offsets[:-1],
                                 blockscale_offsets[:-1], out_dtype, device)
     del int_fp4, int_blockscale
-    out = (c2[c_map].view(m, num_topk, k) *
+
+    c2 = ops.shuffle_rows(c2, c_map)
+    out = (c2.view(m, num_topk, k) *
            topk_weights.view(m, num_topk, 1).half()).sum(dim=1)
     return out.to(dtype=out_dtype)
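After the second grouped GEMM, the expert outputs are shuffled back into token order with c_map and combined with the top-k routing weights. A small standalone sketch of that final reduction, with illustrative shapes:

import torch

m, num_topk, k = 4, 2, 8
c2 = torch.randn(m * num_topk, k, dtype=torch.half)   # unpermuted expert outputs
topk_weights = torch.rand(m, num_topk, dtype=torch.float32)

# Weighted sum over the top-k expert outputs for each token,
# mirroring the tail of cutlass_moe_fp4 above.
out = (c2.view(m, num_topk, k) *
       topk_weights.view(m, num_topk, 1).half()).sum(dim=1)
print(out.shape)  # torch.Size([4, 8])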