[Kernel] Integrate CUTLASS MoE kernel with PPLX (#18762)

Signed-off-by: ElizaWszola <ewszola@redhat.com>
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
ElizaWszola 2025-06-07 03:26:11 +02:00 committed by GitHub
parent 6e0cd10f72
commit 84166fee97
26 changed files with 918 additions and 409 deletions


@ -543,8 +543,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# CUTLASS MoE kernels # CUTLASS MoE kernels
# The MoE kernel cutlass_moe_mm requires CUDA 12.3 or later (and only works # The MoE kernel cutlass_moe_mm requires CUDA 12.3 or later (and only works
# on Hopper). get_cutlass_moe_mm_data should only be compiled if it's possible # on Hopper). get_cutlass_(pplx_)moe_mm_data should only be compiled
# to compile MoE kernels that use its output. # if it's possible to compile MoE kernels that use its output.
cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;10.0a" "${CUDA_ARCHS}") cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;10.0a" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS) if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS)
set(SRCS "csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu" set(SRCS "csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu"


@ -7,8 +7,8 @@ from benchmark_shapes import WEIGHT_SHAPES_MOE
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp8
from vllm.model_executor.layers.fused_moe.fused_moe import ( from vllm.model_executor.layers.fused_moe.fused_moe import (
cutlass_moe_fp8,
fused_experts, fused_experts,
fused_topk, fused_topk,
) )
@ -70,18 +70,9 @@ def bench_run(
w1_scale = torch.empty((num_experts, 1, 1), device="cuda", dtype=torch.float32) w1_scale = torch.empty((num_experts, 1, 1), device="cuda", dtype=torch.float32)
w2_scale = torch.empty((num_experts, 1, 1), device="cuda", dtype=torch.float32) w2_scale = torch.empty((num_experts, 1, 1), device="cuda", dtype=torch.float32)
ab_strides1 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64)
c_strides1 = torch.full((num_experts,), 2 * n, device="cuda", dtype=torch.int64)
ab_strides2 = torch.full((num_experts,), n, device="cuda", dtype=torch.int64)
c_strides2 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64)
for expert in range(num_experts): for expert in range(num_experts):
w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(w1[expert]) w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(w1[expert])
w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant(w2[expert]) w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant(w2[expert])
w1_q_notransp = w1_q.clone()
w2_q_notransp = w2_q.clone()
w1_q = w1_q.transpose(1, 2)
w2_q = w2_q.transpose(1, 2)
score = torch.randn((m, num_experts), device="cuda", dtype=dtype) score = torch.randn((m, num_experts), device="cuda", dtype=dtype)
@ -122,10 +113,6 @@ def bench_run(
w2_scale: torch.Tensor, w2_scale: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
ab_strides1: torch.Tensor,
c_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides2: torch.Tensor,
num_repeats: int, num_repeats: int,
): ):
for _ in range(num_repeats): for _ in range(num_repeats):
@ -133,14 +120,10 @@ def bench_run(
a, a,
w1, w1,
w2, w2,
w1_scale,
w2_scale,
topk_weights, topk_weights,
topk_ids, topk_ids,
ab_strides1, w1_scale,
c_strides1, w2_scale,
ab_strides2,
c_strides2,
a1_scale=a_scale, a1_scale=a_scale,
) )
@ -153,10 +136,6 @@ def bench_run(
w2_scale: torch.Tensor, w2_scale: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
ab_strides1: torch.Tensor,
c_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides2: torch.Tensor,
): ):
with set_current_vllm_config( with set_current_vllm_config(
VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1)) VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
@ -165,14 +144,10 @@ def bench_run(
a, a,
w1_q, w1_q,
w2_q, w2_q,
w1_scale,
w2_scale,
topk_weights, topk_weights,
topk_ids, topk_ids,
ab_strides1, w1_scale,
c_strides1, w2_scale,
ab_strides2,
c_strides2,
a1_scale=a_scale, a1_scale=a_scale,
) )
@ -218,10 +193,6 @@ def bench_run(
w2_scale, w2_scale,
topk_weights, topk_weights,
topk_ids, topk_ids,
ab_strides1,
c_strides1,
ab_strides2,
c_strides2,
) )
torch.cuda.synchronize() torch.cuda.synchronize()
@ -230,8 +201,8 @@ def bench_run(
with torch.cuda.graph(triton_graph, stream=triton_stream): with torch.cuda.graph(triton_graph, stream=triton_stream):
run_triton_from_graph( run_triton_from_graph(
a, a,
w1_q_notransp, w1_q,
w2_q_notransp, w2_q,
topk_weights, topk_weights,
topk_ids, topk_ids,
w1_scale, w1_scale,
@ -250,18 +221,12 @@ def bench_run(
"w2": w2, "w2": w2,
"score": score, "score": score,
"topk": topk, "topk": topk,
"w1_q_notransp": w1_q_notransp,
"w2_q_notransp": w2_q_notransp,
# Cutlass params # Cutlass params
"a_scale": a_scale, "a_scale": a_scale,
"w1_q": w1_q, "w1_q": w1_q,
"w2_q": w2_q, "w2_q": w2_q,
"w1_scale": w1_scale, "w1_scale": w1_scale,
"w2_scale": w2_scale, "w2_scale": w2_scale,
"ab_strides1": ab_strides1,
"c_strides1": c_strides1,
"ab_strides2": ab_strides2,
"c_strides2": c_strides2,
# cuda graph params # cuda graph params
"cutlass_graph": cutlass_graph, "cutlass_graph": cutlass_graph,
"triton_graph": triton_graph, "triton_graph": triton_graph,
@ -279,8 +244,8 @@ def bench_run(
# Warmup # Warmup
run_triton_moe( run_triton_moe(
a, a,
w1_q_notransp, w1_q,
w2_q_notransp, w2_q,
topk_weights, topk_weights,
topk_ids, topk_ids,
w1_scale, w1_scale,
@ -291,7 +256,7 @@ def bench_run(
results.append( results.append(
benchmark.Timer( benchmark.Timer(
stmt="run_triton_moe(a, w1_q_notransp, w2_q_notransp, topk_weights, topk_ids, w1_scale, w2_scale, a_scale, num_runs)", # noqa: E501 stmt="run_triton_moe(a, w1_q, w2_q, topk_weights, topk_ids, w1_scale, w2_scale, a_scale, num_runs)", # noqa: E501
globals=globals, globals=globals,
label=label, label=label,
sub_label=sub_label, sub_label=sub_label,
@ -322,16 +287,12 @@ def bench_run(
w2_scale, w2_scale,
topk_weights, topk_weights,
topk_ids, topk_ids,
ab_strides1,
c_strides1,
ab_strides2,
c_strides2,
num_warmup, num_warmup,
) )
results.append( results.append(
benchmark.Timer( benchmark.Timer(
stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, ab_strides1, c_strides1, ab_strides2, c_strides2, num_runs)", # noqa: E501 stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, num_runs)", # noqa: E501
globals=globals, globals=globals,
label=label, label=label,
sub_label=sub_label, sub_label=sub_label,
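The net effect of the benchmark changes above is that cutlass_moe_fp8 no longer takes the per-expert stride tensors, and the explicit weight transposes were dropped from the benchmark. A minimal, hedged sketch of the updated call (not part of the PR; shapes and values are illustrative and mirror bench_run above):

import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp8
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk

m, k, n, num_experts, topk = 64, 1024, 2048, 8, 2
dtype = torch.half
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((num_experts, 2 * n, k), device="cuda", dtype=dtype) / 10
w2 = torch.randn((num_experts, k, n), device="cuda", dtype=dtype) / 10

# Per-expert fp8 quantization with per-tensor scales, as in bench_run.
w1_q = torch.empty_like(w1, dtype=torch.float8_e4m3fn)
w2_q = torch.empty_like(w2, dtype=torch.float8_e4m3fn)
w1_scale = torch.empty((num_experts, 1, 1), device="cuda", dtype=torch.float32)
w2_scale = torch.empty((num_experts, 1, 1), device="cuda", dtype=torch.float32)
for e in range(num_experts):
    w1_q[e], w1_scale[e] = ops.scaled_fp8_quant(w1[e])
    w2_q[e], w2_scale[e] = ops.scaled_fp8_quant(w2[e])

score = torch.randn((m, num_experts), device="cuda", dtype=dtype)
topk_weights, topk_ids, _ = fused_topk(a, score, topk, renormalize=False)

# No transposes and no ab_strides*/c_strides* arguments anymore; the strides
# are derived inside the CUTLASS MoE path.
out = cutlass_moe_fp8(a, w1_q, w2_q, topk_weights, topk_ids,
                      w1_scale, w2_scale)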


@ -236,7 +236,8 @@ void cutlass_moe_mm(
torch::Tensor const& b_tensors, torch::Tensor const& a_scales, torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets, torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
torch::Tensor const& b_strides, torch::Tensor const& c_strides); torch::Tensor const& b_strides, torch::Tensor const& c_strides,
bool per_act_token, bool per_out_ch);
void cutlass_fp4_group_mm( void cutlass_fp4_group_mm(
torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b, torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
@ -251,6 +252,14 @@ void get_cutlass_moe_mm_data(
const int64_t num_experts, const int64_t n, const int64_t k, const int64_t num_experts, const int64_t n, const int64_t k,
const std::optional<torch::Tensor>& blockscale_offsets); const std::optional<torch::Tensor>& blockscale_offsets);
void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1,
torch::Tensor& problem_sizes2,
const torch::Tensor& expert_num_tokens,
const int64_t num_local_experts,
const int64_t padded_m, const int64_t n,
const int64_t k);
void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a, void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b, torch::Tensor const& b,
torch::Tensor const& a_scales, torch::Tensor const& a_scales,


@ -84,7 +84,8 @@ void run_cutlass_moe_mm_sm90(
torch::Tensor const& b_tensors, torch::Tensor const& a_scales, torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets, torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
torch::Tensor const& b_strides, torch::Tensor const& c_strides) { torch::Tensor const& b_strides, torch::Tensor const& c_strides,
bool per_act_token, bool per_out_ch) {
TORCH_CHECK(a_tensors.size(0) > 0, "No input A tensors provided."); TORCH_CHECK(a_tensors.size(0) > 0, "No input A tensors provided.");
TORCH_CHECK(b_tensors.size(0) > 0, "No input B tensors provided."); TORCH_CHECK(b_tensors.size(0) > 0, "No input B tensors provided.");
TORCH_CHECK(out_tensors.size(0) > 0, "No output tensors provided."); TORCH_CHECK(out_tensors.size(0) > 0, "No output tensors provided.");
@ -113,19 +114,23 @@ void run_cutlass_moe_mm_sm90(
if (n >= 8192) { if (n >= 8192) {
cutlass_group_gemm_caller<Cutlass3xGemmN8192>( cutlass_group_gemm_caller<Cutlass3xGemmN8192>(
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
problem_sizes, a_strides, b_strides, c_strides); problem_sizes, a_strides, b_strides, c_strides, per_act_token,
per_out_ch);
} else if (k >= 8192) { } else if (k >= 8192) {
cutlass_group_gemm_caller<Cutlass3xGemmK8192>( cutlass_group_gemm_caller<Cutlass3xGemmK8192>(
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
problem_sizes, a_strides, b_strides, c_strides); problem_sizes, a_strides, b_strides, c_strides, per_act_token,
per_out_ch);
} else if (m <= 16) { } else if (m <= 16) {
cutlass_group_gemm_caller<Cutlass3xGemmM16>( cutlass_group_gemm_caller<Cutlass3xGemmM16>(
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
problem_sizes, a_strides, b_strides, c_strides); problem_sizes, a_strides, b_strides, c_strides, per_act_token,
per_out_ch);
} else { } else {
cutlass_group_gemm_caller<Cutlass3xGemmDefault>( cutlass_group_gemm_caller<Cutlass3xGemmDefault>(
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
problem_sizes, a_strides, b_strides, c_strides); problem_sizes, a_strides, b_strides, c_strides, per_act_token,
per_out_ch);
} }
} }
@ -134,15 +139,18 @@ void dispatch_moe_mm_sm90(
torch::Tensor const& b_tensors, torch::Tensor const& a_scales, torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets, torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
torch::Tensor const& b_strides, torch::Tensor const& c_strides) { torch::Tensor const& b_strides, torch::Tensor const& c_strides,
bool per_act_token, bool per_out_ch) {
if (out_tensors.dtype() == torch::kBFloat16) { if (out_tensors.dtype() == torch::kBFloat16) {
run_cutlass_moe_mm_sm90<cutlass::float_e4m3_t, cutlass::bfloat16_t>( run_cutlass_moe_mm_sm90<cutlass::float_e4m3_t, cutlass::bfloat16_t>(
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
problem_sizes, a_strides, b_strides, c_strides); problem_sizes, a_strides, b_strides, c_strides, per_act_token,
per_out_ch);
} else { } else {
run_cutlass_moe_mm_sm90<cutlass::float_e4m3_t, cutlass::half_t>( run_cutlass_moe_mm_sm90<cutlass::float_e4m3_t, cutlass::half_t>(
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
problem_sizes, a_strides, b_strides, c_strides); problem_sizes, a_strides, b_strides, c_strides, per_act_token,
per_out_ch);
} }
} }
@ -153,8 +161,9 @@ void cutlass_moe_mm_sm90(
torch::Tensor const& b_tensors, torch::Tensor const& a_scales, torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets, torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
torch::Tensor const& b_strides, torch::Tensor const& c_strides) { torch::Tensor const& b_strides, torch::Tensor const& c_strides,
bool per_act_token, bool per_out_ch) {
dispatch_moe_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, b_scales, dispatch_moe_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
expert_offsets, problem_sizes, a_strides, b_strides, expert_offsets, problem_sizes, a_strides, b_strides,
c_strides); c_strides, per_act_token, per_out_ch);
} }
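As a reading aid, the shape-based dispatch above can be summarized by a small Python sketch (the config names come from the C++ template instantiations in this file; per_act_token and per_out_ch are simply threaded through to whichever config is chosen):

def select_sm90_moe_gemm_config(m: int, n: int, k: int) -> str:
    # Mirrors the if/else ladder in run_cutlass_moe_mm_sm90: first match wins.
    if n >= 8192:
        return "Cutlass3xGemmN8192"
    if k >= 8192:
        return "Cutlass3xGemmK8192"
    if m <= 16:
        return "Cutlass3xGemmM16"
    return "Cutlass3xGemmDefault"

# Example: a small decode batch with n = 4096, k = 14336 picks the K8192 config.
assert select_sm90_moe_gemm_config(8, 4096, 14336) == "Cutlass3xGemmK8192"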


@ -76,7 +76,8 @@ void cutlass_group_gemm_caller(
torch::Tensor const& b_tensors, torch::Tensor const& a_scales, torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets, torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
torch::Tensor const& b_strides, torch::Tensor const& c_strides) { torch::Tensor const& b_strides, torch::Tensor const& c_strides,
bool per_act_token, bool per_out_ch) {
using ElementAB = typename Gemm::ElementAB; using ElementAB = typename Gemm::ElementAB;
using ElementD = typename Gemm::ElementD; using ElementD = typename Gemm::ElementD;
@ -84,9 +85,6 @@ void cutlass_group_gemm_caller(
int k_size = a_tensors.size(1); int k_size = a_tensors.size(1);
int n_size = out_tensors.size(1); int n_size = out_tensors.size(1);
bool per_act_token = a_scales.numel() != 1;
bool per_out_ch = b_scales.numel() != num_experts;
auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index()); auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index());
auto options_int = auto options_int =


@ -7,7 +7,7 @@
constexpr uint64_t THREADS_PER_EXPERT = 512; constexpr uint64_t THREADS_PER_EXPERT = 512;
__global__ void compute_problem_sizes(const int* __restrict__ topk_ids, __global__ void compute_problem_sizes(const uint32_t* __restrict__ topk_ids,
int32_t* problem_sizes1, int32_t* problem_sizes1,
int32_t* problem_sizes2, int32_t* problem_sizes2,
int32_t* atomic_buffer, int32_t* atomic_buffer,
@ -62,7 +62,7 @@ __global__ void compute_expert_blockscale_offsets(
} }
} }
__global__ void compute_arg_sorts(const int* __restrict__ topk_ids, __global__ void compute_arg_sorts(const uint32_t* __restrict__ topk_ids,
const int32_t* __restrict__ expert_offsets, const int32_t* __restrict__ expert_offsets,
int32_t* input_permutation, int32_t* input_permutation,
int32_t* output_permutation, int32_t* output_permutation,
@ -103,7 +103,7 @@ void get_cutlass_moe_mm_data_caller(
int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel()); int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel());
compute_problem_sizes<<<num_experts, num_threads, 0, stream>>>( compute_problem_sizes<<<num_experts, num_threads, 0, stream>>>(
static_cast<const int32_t*>(topk_ids.data_ptr()), static_cast<const uint32_t*>(topk_ids.data_ptr()),
static_cast<int32_t*>(problem_sizes1.data_ptr()), static_cast<int32_t*>(problem_sizes1.data_ptr()),
static_cast<int32_t*>(problem_sizes2.data_ptr()), static_cast<int32_t*>(problem_sizes2.data_ptr()),
static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(), n, k); static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(), n, k);
@ -120,10 +120,44 @@ void get_cutlass_moe_mm_data_caller(
static_cast<int32_t*>(atomic_buffer.data_ptr()), num_experts); static_cast<int32_t*>(atomic_buffer.data_ptr()), num_experts);
} }
compute_arg_sorts<<<num_experts, num_threads, 0, stream>>>( compute_arg_sorts<<<num_experts, num_threads, 0, stream>>>(
static_cast<const int32_t*>(topk_ids.data_ptr()), static_cast<const uint32_t*>(topk_ids.data_ptr()),
static_cast<const int32_t*>(expert_offsets.data_ptr()), static_cast<const int32_t*>(expert_offsets.data_ptr()),
static_cast<int32_t*>(input_permutation.data_ptr()), static_cast<int32_t*>(input_permutation.data_ptr()),
static_cast<int32_t*>(output_permutation.data_ptr()), static_cast<int32_t*>(output_permutation.data_ptr()),
static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(), static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(),
topk_ids.size(1)); topk_ids.size(1));
} }
__global__ void compute_pplx_data(int32_t* expert_offsets,
int32_t* problem_sizes1,
int32_t* problem_sizes2,
const int32_t* __restrict__ expert_num_tokens,
const int padded_m, const int n,
const int k) {
int expert_idx = threadIdx.x;
expert_offsets[expert_idx] = expert_idx * padded_m;
problem_sizes1[expert_idx * 3] = expert_num_tokens[expert_idx];
problem_sizes1[expert_idx * 3 + 1] = 2 * n;
problem_sizes1[expert_idx * 3 + 2] = k;
problem_sizes2[expert_idx * 3] = expert_num_tokens[expert_idx];
problem_sizes2[expert_idx * 3 + 1] = k;
problem_sizes2[expert_idx * 3 + 2] = n;
}
void get_cutlass_pplx_moe_mm_data_caller(torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1,
torch::Tensor& problem_sizes2,
const torch::Tensor& expert_num_tokens,
const int64_t num_local_experts,
const int64_t padded_m,
const int64_t n, const int64_t k) {
auto stream = at::cuda::getCurrentCUDAStream(expert_offsets.device().index());
compute_pplx_data<<<1, num_local_experts, 0, stream>>>(
static_cast<int32_t*>(expert_offsets.data_ptr()),
static_cast<int32_t*>(problem_sizes1.data_ptr()),
static_cast<int32_t*>(problem_sizes2.data_ptr()),
static_cast<const int32_t*>(expert_num_tokens.data_ptr()), padded_m, n,
k);
}
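For clarity, here is a pure-PyTorch sketch of what compute_pplx_data writes (a reference re-statement of the kernel above, not code from the PR): each local expert's rows start at expert_idx * padded_m, and the two problem-size tables describe the (tokens, 2N, K) and (tokens, K, N) grouped GEMMs.

import torch

def pplx_moe_mm_data_ref(expert_num_tokens: torch.Tensor,
                         padded_m: int, n: int, k: int):
    # expert_num_tokens: int32 [num_local_experts], token count per local expert.
    e = expert_num_tokens.shape[0]
    device = expert_num_tokens.device
    # Every expert's block starts at a fixed, padded offset.
    expert_offsets = torch.arange(e, dtype=torch.int32, device=device) * padded_m
    col = lambda v: torch.full((e,), v, dtype=torch.int32, device=device)
    # First grouped GEMM is (tokens_i x 2N x K), the second is (tokens_i x K x N).
    problem_sizes1 = torch.stack([expert_num_tokens, col(2 * n), col(k)], dim=1)
    problem_sizes2 = torch.stack([expert_num_tokens, col(k), col(n)], dim=1)
    return expert_offsets, problem_sizes1, problem_sizes2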


@ -36,7 +36,8 @@ void cutlass_moe_mm_sm90(
torch::Tensor const& b_tensors, torch::Tensor const& a_scales, torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets, torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
torch::Tensor const& b_strides, torch::Tensor const& c_strides); torch::Tensor const& b_strides, torch::Tensor const& c_strides,
bool per_act_token, bool per_out_ch);
#endif #endif
@ -56,6 +57,14 @@ void get_cutlass_moe_mm_data_caller(
torch::Tensor& input_permutation, torch::Tensor& output_permutation, torch::Tensor& input_permutation, torch::Tensor& output_permutation,
const int64_t num_experts, const int64_t n, const int64_t k, const int64_t num_experts, const int64_t n, const int64_t k,
const std::optional<torch::Tensor>& blockscale_offsets); const std::optional<torch::Tensor>& blockscale_offsets);
void get_cutlass_pplx_moe_mm_data_caller(torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1,
torch::Tensor& problem_sizes2,
const torch::Tensor& expert_num_tokens,
const int64_t num_local_experts,
const int64_t padded_m,
const int64_t n, const int64_t k);
#endif #endif
void cutlass_scaled_mm_azp_sm75(torch::Tensor& c, torch::Tensor const& a, void cutlass_scaled_mm_azp_sm75(torch::Tensor& c, torch::Tensor const& a,
@ -207,12 +216,13 @@ void cutlass_moe_mm(
torch::Tensor const& b_tensors, torch::Tensor const& a_scales, torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets, torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
torch::Tensor const& b_strides, torch::Tensor const& c_strides) { torch::Tensor const& b_strides, torch::Tensor const& c_strides,
bool per_act_token, bool per_out_ch) {
int32_t version_num = get_sm_version_num(); int32_t version_num = get_sm_version_num();
#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90 #if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90
cutlass_moe_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, b_scales, cutlass_moe_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
expert_offsets, problem_sizes, a_strides, b_strides, expert_offsets, problem_sizes, a_strides, b_strides,
c_strides); c_strides, per_act_token, per_out_ch);
return; return;
#endif #endif
TORCH_CHECK_NOT_IMPLEMENTED( TORCH_CHECK_NOT_IMPLEMENTED(
@ -245,6 +255,29 @@ void get_cutlass_moe_mm_data(
version_num, ". Required capability: 90"); version_num, ". Required capability: 90");
} }
void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1,
torch::Tensor& problem_sizes2,
const torch::Tensor& expert_num_tokens,
const int64_t num_local_experts,
const int64_t padded_m, const int64_t n,
const int64_t k) {
// This function is currently compiled only when a valid CUTLASS MoE MM kernel
// that can consume its output is also compiled.
int32_t version_num = get_sm_version_num();
#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90
get_cutlass_pplx_moe_mm_data_caller(expert_offsets, problem_sizes1,
problem_sizes2, expert_num_tokens,
num_local_experts, padded_m, n, k);
return;
#endif
TORCH_CHECK_NOT_IMPLEMENTED(
false,
"No compiled get_cutlass_pplx_moe_mm_data: no cutlass_scaled_mm kernel "
"for CUDA device capability: ",
version_num, ". Required capability: 90");
}
void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a, void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b, torch::Tensor const& b,
torch::Tensor const& a_scales, torch::Tensor const& a_scales,


@ -435,7 +435,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"cutlass_moe_mm(Tensor! out_tensors, Tensor a_tensors, Tensor b_tensors, " "cutlass_moe_mm(Tensor! out_tensors, Tensor a_tensors, Tensor b_tensors, "
" Tensor a_scales, Tensor b_scales, Tensor expert_offsets, " " Tensor a_scales, Tensor b_scales, Tensor expert_offsets, "
" Tensor problem_sizes, Tensor a_strides, " " Tensor problem_sizes, Tensor a_strides, "
" Tensor b_strides, Tensor c_strides) -> ()", " Tensor b_strides, Tensor c_strides, bool per_act_token, "
" bool per_out_ch) -> ()",
{stride_tag}); {stride_tag});
ops.impl("cutlass_moe_mm", torch::kCUDA, &cutlass_moe_mm); ops.impl("cutlass_moe_mm", torch::kCUDA, &cutlass_moe_mm);
@ -454,6 +455,22 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
{stride_tag}); {stride_tag});
ops.impl("get_cutlass_moe_mm_data", torch::kCUDA, &get_cutlass_moe_mm_data); ops.impl("get_cutlass_moe_mm_data", torch::kCUDA, &get_cutlass_moe_mm_data);
// A function that computes data required to run fused MoE with w8a8 grouped
// GEMM and PPLX. It takes expert_num_tokens (the number of tokens assigned to
// each local expert) as input and computes expert_offsets (the token start
// index of each expert). In addition, it computes the problem sizes for each
// expert's multiplication used by the two GEMMs called from the fused MoE
// operation.
ops.def(
"get_cutlass_pplx_moe_mm_data(Tensor! expert_offsets, "
" Tensor! problem_sizes1, "
" Tensor! problem_sizes2, "
" Tensor expert_num_tokens, "
" int num_local_experts, int padded_m, "
" int n, int k) -> ()",
{stride_tag});
ops.impl("get_cutlass_pplx_moe_mm_data", torch::kCUDA,
&get_cutlass_pplx_moe_mm_data);
// Check if cutlass scaled_mm supports block quantization (used by DeepSeekV3) // Check if cutlass scaled_mm supports block quantization (used by DeepSeekV3)
ops.def( ops.def(
"cutlass_scaled_mm_supports_block_fp8(int cuda_device_capability) -> " "cutlass_scaled_mm_supports_block_fp8(int cuda_device_capability) -> "


@ -193,14 +193,10 @@ def run_8_bit(moe_tensors: MOETensors8Bit,
kwargs = { kwargs = {
'a': moe_tensors.a, 'a': moe_tensors.a,
'w1_q': moe_tensors.w1_q.transpose(1, 2), # type: ignore[union-attr] 'w1_q': moe_tensors.w1_q, # type: ignore[union-attr]
'w2_q': moe_tensors.w2_q.transpose(1, 2), # type: ignore[union-attr] 'w2_q': moe_tensors.w2_q, # type: ignore[union-attr]
'topk_weights': topk_weights, 'topk_weights': topk_weights,
'topk_ids': topk_ids, 'topk_ids': topk_ids,
'ab_strides1': moe_tensors.ab_strides1,
'c_strides1': moe_tensors.c_strides1,
'ab_strides2': moe_tensors.ab_strides2,
'c_strides2': moe_tensors.c_strides2,
'w1_scale': moe_tensors.w1_scale, 'w1_scale': moe_tensors.w1_scale,
'w2_scale': moe_tensors.w2_scale, 'w2_scale': moe_tensors.w2_scale,
'a1_scale': moe_tensors.a_scale 'a1_scale': moe_tensors.a_scale


@ -0,0 +1,287 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
from tests.pplx_utils import ProcessGroupInfo, parallel_launch
from vllm import _custom_ops as ops
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel)
from vllm.platforms import current_platform
try:
from pplx_kernels import AllToAll
from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id,
nvshmem_finalize, nvshmem_get_unique_id,
nvshmem_init)
has_pplx = True
except ImportError:
has_pplx = False
requires_pplx = pytest.mark.skipif(
not has_pplx,
reason="Requires PPLX kernels",
)
NUM_EXPERTS = [40, 64]
TOP_KS = [6, 8]
def rank_chunk(num, r, w):
rem = num % w
return (num // w) + (1 if r < rem else 0)
def chunk_by_rank(t, r, w):
num = t.shape[0]
chunk = rank_chunk(num, r, w)
rem = num % w
if rem == 0 or r < rem:
return t[(r * chunk):(r + 1) * chunk].contiguous()
else:
long_chunks = (num // w + 1) * rem
short_chunks = (r - rem) * chunk
start = long_chunks + short_chunks
return t[start:start + chunk].contiguous()
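A quick illustration of how these two helpers split rows when the token count does not divide evenly across ranks (pure Python, same semantics as the functions above):

import torch

rows = torch.arange(10).unsqueeze(1)                 # 10 rows, world size 3
sizes = [rank_chunk(10, r, 3) for r in range(3)]     # [4, 3, 3]
chunks = [chunk_by_rank(rows, r, 3) for r in range(3)]
# The first (10 % 3) ranks get one extra row:
# chunks[0] -> rows 0..3, chunks[1] -> rows 4..6, chunks[2] -> rows 7..9.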
def pplx_cutlass_moe(
pgi: ProcessGroupInfo,
dp_size: int,
a: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
a1_scale: torch.Tensor,
out_dtype,
per_act_token: bool,
per_out_ch: bool,
):
from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
PplxPrepareAndFinalize)
assert torch.cuda.current_device() == pgi.local_rank
num_tokens, hidden_dim = a.shape
num_experts = w1.shape[0]
block_size = hidden_dim # TODO support more cases
device = pgi.device
rank = pgi.rank
world_size = pgi.world_size
rank_num_tokens = rank_chunk(num_tokens, rank, world_size)
max_num_tokens = rank_chunk(num_tokens, 0, world_size)
topk = topk_ids.shape[1]
if block_size == hidden_dim:
scale_elems = 4 # hack to circumvent pplx data format requirements
else:
scale_elems = (hidden_dim + block_size - 1) // block_size
ata = AllToAll.internode(
max_num_tokens=max_num_tokens,
num_experts=num_experts,
experts_per_token=topk,
rank=rank,
world_size=pgi.world_size,
dp_size=dp_size,
hidden_dim=hidden_dim,
hidden_dim_bytes=hidden_dim, # because a.dtype.itemsize == 1
hidden_dim_scale_bytes=scale_elems * torch.float32.itemsize,
)
w1 = w1.to(device)
w2 = w2.to(device)
w1_scale = w1_scale.to(device)
w2_scale = w2_scale.to(device)
a1_scale = a1_scale.to(device)
prepare_finalize = PplxPrepareAndFinalize(
ata,
max_num_tokens,
pgi.world_size,
rank,
dp_size,
quant_dtype=torch.float8_e4m3fn,
per_act_token=per_act_token,
)
experts = CutlassExpertsFp8((num_experts + world_size - 1) // world_size,
out_dtype, per_act_token, per_out_ch)
fused_cutlass_experts = FusedMoEModularKernel(
prepare_finalize,
experts,
)
a_chunk = chunk_by_rank(a, rank, world_size).to(device)
chunk_topk_weight = chunk_by_rank(topk_weights, rank,
world_size).to(device)
chunk_topk_ids = chunk_by_rank(topk_ids, rank,
world_size).to(torch.uint32).to(device)
out = fused_cutlass_experts(
a_chunk,
chunk_by_rank(w1, rank, world_size),
chunk_by_rank(w2, rank, world_size),
chunk_topk_weight,
chunk_topk_ids,
global_num_experts=num_experts,
expert_map=None, #TODO
w1_scale=chunk_by_rank(w1_scale, rank, world_size),
w2_scale=chunk_by_rank(w2_scale, rank, world_size),
a1_scale=chunk_by_rank(a1_scale, rank, world_size)
if per_act_token else a1_scale[rank])
torch.cuda.synchronize()
ata.destroy()
return out[:rank_num_tokens]
vllm_config = VllmConfig()
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192
def torch_moe2(a, w1, w2, topk_weight, topk_ids):
M, K = a.shape
topk = topk_ids.shape[1]
a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K)
out = torch.zeros(M * topk, w2.shape[1], dtype=a.dtype, device=a.device)
num_experts = w1.shape[0]
for i in range(num_experts):
mask = (topk_ids == i).view(-1)
if mask.sum():
out[mask] = SiluAndMul()(
a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1)
return (out.view(M, -1, w2.shape[1]) *
topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1)
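torch_moe2 above relies on vLLM's SiluAndMul gated activation; a minimal sketch of the equivalent computation for a single expert (assuming the usual gate/up split of the 2N-wide first projection):

import torch
import torch.nn.functional as F

def silu_and_mul_ref(x: torch.Tensor) -> torch.Tensor:
    # x: [..., 2 * d]; the first half gates the second half.
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]

# Per-expert reference (routing mask omitted):
# h = silu_and_mul_ref(a @ w1[i].t()) @ w2[i].t()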
def _pplx_moe(
pgi: ProcessGroupInfo,
dp_size: int,
a: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
a1_scale: torch.Tensor,
out_dtype,
a_full: torch.Tensor,
w1_full: torch.Tensor,
w2_full: torch.Tensor,
per_act_token: bool,
per_out_ch: bool,
):
uid = nvshmem_get_unique_id(
) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id()
torch.distributed.broadcast(uid, src=0)
nvshmem_init(uid, pgi.rank, pgi.world_size)
with set_current_vllm_config(vllm_config):
torch_output = torch_moe2(a_full, w1_full, w2_full, topk_weights,
topk_ids)
pplx_output = pplx_cutlass_moe(pgi, dp_size, a, w1, w2, w1_scale,
w2_scale, topk_weights, topk_ids,
a1_scale, out_dtype, per_act_token,
per_out_ch)
torch_output = chunk_by_rank(torch_output, pgi.rank,
pgi.world_size).to(pplx_output.device)
# Uncomment if more debugging is needed
# print("PPLX OUT:", pplx_output)
# print("TORCH OUT:", torch_output)
torch.testing.assert_close(pplx_output, torch_output, atol=0.05, rtol=0)
nvshmem_finalize()
@pytest.mark.parametrize("m", [2, 224])
@pytest.mark.parametrize("n", [3072])
@pytest.mark.parametrize("k", [1536])
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("world_dp_size", [[2, 1]]) #, [4, 2]])
@pytest.mark.skipif(
(lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
current_platform.get_device_capability()),
reason="Grouped gemm is not supported on this GPU type.")
@requires_pplx
def test_cutlass_moe_pplx(
m: int,
n: int,
k: int,
e: int,
topk: int,
per_act_token: bool,
per_out_ch: bool,
world_dp_size: tuple[int, int],
):
current_platform.seed_everything(7)
with set_current_vllm_config(vllm_config):
dtype = torch.half
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10.0
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10.0
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10.0
n_b_scales = 2 * n if per_out_ch else 1
k_b_scales = k if per_out_ch else 1
w1_q = torch.empty((e, 2 * n, k),
device="cuda",
dtype=torch.float8_e4m3fn)
w2_q = torch.empty((e, k, n), device="cuda", dtype=torch.float8_e4m3fn)
w1_scale = torch.empty((e, n_b_scales, 1),
device="cuda",
dtype=torch.float32)
w2_scale = torch.empty((e, k_b_scales, 1),
device="cuda",
dtype=torch.float32)
for expert in range(e):
w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(
w1[expert], use_per_token_if_dynamic=per_out_ch)
w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant(
w2[expert], use_per_token_if_dynamic=per_out_ch)
w1_d = torch.empty_like(w1)
w2_d = torch.empty_like(w2)
for expert in range(e):
w1_d[expert] = (w1_q[expert].float() * w1_scale[expert]).half()
w2_d[expert] = (w2_q[expert].float() * w2_scale[expert]).half()
score = torch.randn((m, e), device="cuda", dtype=dtype)
topk_weights, topk_ids, _ = fused_topk(a,
score,
topk,
renormalize=False)
world_size, dp_size = world_dp_size
a_scale1 = torch.randn(
(m if per_act_token else 1, 1), device="cuda",
dtype=torch.float32) / 10.0
if not per_act_token:
a_scale1 = a_scale1.repeat(world_size, 1)
parallel_launch(world_size, _pplx_moe, dp_size, a, w1_q, w2_q,
w1_scale, w2_scale, topk_weights, topk_ids, a_scale1,
dtype, a, w1_d, w2_d, per_act_token, per_out_ch)


@ -4,10 +4,7 @@
Run `pytest tests/kernels/test_pplx_moe.py`. Run `pytest tests/kernels/test_pplx_moe.py`.
""" """
import dataclasses from typing import Optional
import os
import traceback
from typing import Callable, Optional
import pytest import pytest
import torch import torch
@ -21,10 +18,7 @@ try:
except ImportError: except ImportError:
has_pplx = False has_pplx = False
from torch.multiprocessing import ( from tests.pplx_utils import ProcessGroupInfo, parallel_launch
spawn) # pyright: ignore[reportPrivateImportUsage]
from typing_extensions import Concatenate, ParamSpec
from vllm.config import VllmConfig, set_current_vllm_config from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import override_config from vllm.model_executor.layers.fused_moe import override_config
@ -36,6 +30,11 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel) FusedMoEModularKernel)
from vllm.platforms import current_platform from vllm.platforms import current_platform
requires_pplx = pytest.mark.skipif(
not has_pplx,
reason="Requires PPLX kernels",
)
PPLX_PREPARE_COMBOS = [(4, 128, 128), (32, 1024, 512), (64, 1024, 512), PPLX_PREPARE_COMBOS = [(4, 128, 128), (32, 1024, 512), (64, 1024, 512),
(222, 2048, 1024)] (222, 2048, 1024)]
@ -57,122 +56,6 @@ vllm_config = VllmConfig()
vllm_config.scheduler_config.max_num_seqs = 128 vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192 vllm_config.scheduler_config.max_model_len = 8192
P = ParamSpec("P")
requires_pplx = pytest.mark.skipif(
not has_pplx,
reason="Requires PPLX kernels",
)
@dataclasses.dataclass
class ProcessGroupInfo:
world_size: int
world_local_size: int
rank: int
node_rank: int
local_rank: int
device: torch.device
def _worker_parallel_launch(
local_rank: int,
world_size: int,
world_local_size: int,
node_rank: int,
init_method: str,
worker: Callable[Concatenate[ProcessGroupInfo, P], None],
*args: P.args,
**kwargs: P.kwargs,
) -> None:
rank = node_rank * world_local_size + local_rank
torch.cuda.set_device(local_rank)
device = torch.device("cuda", local_rank)
torch.distributed.init_process_group(
backend="cpu:gloo,cuda:nccl",
init_method=init_method,
rank=rank,
world_size=world_size,
device_id=device,
)
barrier = torch.tensor([rank], device=device)
torch.distributed.all_reduce(barrier)
try:
worker(
ProcessGroupInfo(
world_size=world_size,
world_local_size=world_local_size,
rank=rank,
node_rank=node_rank,
local_rank=local_rank,
device=device,
),
*args,
**kwargs,
)
except Exception as ex:
print(ex)
traceback.print_exc()
raise
finally:
torch.distributed.destroy_process_group()
def parallel_launch(
world_size: int,
worker: Callable[Concatenate[ProcessGroupInfo, P], None],
*args: P.args,
**kwargs: P.kwargs,
) -> None:
assert not kwargs
spawn(
_worker_parallel_launch,
args=(
world_size,
world_size,
0,
"tcp://localhost:29500",
worker,
) + args,
nprocs=world_size,
join=True,
)
def parallel_launch_from_env(
worker: Callable[Concatenate[ProcessGroupInfo, P], None],
*args: P.args,
**kwargs: P.kwargs,
) -> None:
"""
Launches a worker function in parallel across all processes in the current
environment. The environment must have the following variables set:
- WORLD_SIZE: The total number of processes.
- WORLD_LOCAL_SIZE: The number of processes on the current node.
- NODE_RANK: The rank of the current
- MASTER_ADDR: The address of the master process.
- MASTER_PORT: The port of the master process.
"""
assert not kwargs
world_size = int(os.environ["WORLD_SIZE"])
world_local_size = int(os.environ["WORLD_LOCAL_SIZE"])
node_rank = int(os.environ["NODE_RANK"])
assert "MASTER_ADDR" in os.environ
assert "MASTER_PORT" in os.environ
spawn(
_worker_parallel_launch,
args=(
world_size,
world_local_size,
node_rank,
"env://",
worker,
) + args,
nprocs=world_local_size,
join=True,
)
def torch_prepare( def torch_prepare(
a: torch.Tensor, a: torch.Tensor,

View File

@ -632,7 +632,8 @@ def test_cutlass_fp8_group_gemm(num_experts: int, per_act_token: bool,
ops.cutlass_moe_mm(out_tensors_stacked, a_tensors_stacked, ops.cutlass_moe_mm(out_tensors_stacked, a_tensors_stacked,
b_tensors_stacked, a_scales_tensors_stacked, b_tensors_stacked, a_scales_tensors_stacked,
b_scales_tensors_stacked, expert_offsets[:-1], b_scales_tensors_stacked, expert_offsets[:-1],
problem_sizes, ab_strides, ab_strides, c_strides) problem_sizes, ab_strides, ab_strides, c_strides,
per_act_token, per_out_ch)
# Validate each group's result against the baseline # Validate each group's result against the baseline
for g in range(num_experts): for g in range(num_experts):

tests/pplx_utils.py (new file, 123 lines)

@ -0,0 +1,123 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses
import os
import traceback
from typing import Callable
import torch
from torch.multiprocessing import (
spawn) # pyright: ignore[reportPrivateImportUsage]
from typing_extensions import Concatenate, ParamSpec
P = ParamSpec("P")
@dataclasses.dataclass
class ProcessGroupInfo:
world_size: int
world_local_size: int
rank: int
node_rank: int
local_rank: int
device: torch.device
def _worker_parallel_launch(
local_rank: int,
world_size: int,
world_local_size: int,
node_rank: int,
init_method: str,
worker: Callable[Concatenate[ProcessGroupInfo, P], None],
*args: P.args,
**kwargs: P.kwargs,
) -> None:
rank = node_rank * world_local_size + local_rank
torch.cuda.set_device(local_rank)
device = torch.device("cuda", local_rank)
torch.distributed.init_process_group(
backend="cpu:gloo,cuda:nccl",
init_method=init_method,
rank=rank,
world_size=world_size,
device_id=device,
)
barrier = torch.tensor([rank], device=device)
torch.distributed.all_reduce(barrier)
try:
worker(
ProcessGroupInfo(
world_size=world_size,
world_local_size=world_local_size,
rank=rank,
node_rank=node_rank,
local_rank=local_rank,
device=device,
),
*args,
**kwargs,
)
except Exception as ex:
print(ex)
traceback.print_exc()
raise
finally:
torch.distributed.destroy_process_group()
def parallel_launch(
world_size: int,
worker: Callable[Concatenate[ProcessGroupInfo, P], None],
*args: P.args,
**kwargs: P.kwargs,
) -> None:
assert not kwargs
spawn(
_worker_parallel_launch,
args=(
world_size,
world_size,
0,
"tcp://localhost:29500",
worker,
) + args,
nprocs=world_size,
join=True,
)
def parallel_launch_from_env(
worker: Callable[Concatenate[ProcessGroupInfo, P], None],
*args: P.args,
**kwargs: P.kwargs,
) -> None:
"""
Launches a worker function in parallel across all processes in the current
environment. The environment must have the following variables set:
- WORLD_SIZE: The total number of processes.
- WORLD_LOCAL_SIZE: The number of processes on the current node.
- NODE_RANK: The rank of the current node.
- MASTER_ADDR: The address of the master process.
- MASTER_PORT: The port of the master process.
"""
assert not kwargs
world_size = int(os.environ["WORLD_SIZE"])
world_local_size = int(os.environ["WORLD_LOCAL_SIZE"])
node_rank = int(os.environ["NODE_RANK"])
assert "MASTER_ADDR" in os.environ
assert "MASTER_PORT" in os.environ
spawn(
_worker_parallel_launch,
args=(
world_size,
world_local_size,
node_rank,
"env://",
worker,
) + args,
nprocs=world_local_size,
join=True,
)
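A minimal, hypothetical usage sketch of the new helpers (the worker below is illustrative and needs one CUDA device per spawned rank):

import torch

from tests.pplx_utils import ProcessGroupInfo, parallel_launch

def _hello_worker(pgi: ProcessGroupInfo, message: str) -> None:
    # Each spawned process arrives here with the process group already set up.
    print(f"rank {pgi.rank}/{pgi.world_size} on {pgi.device}: {message}")
    torch.distributed.barrier()

if __name__ == "__main__":
    # Spawns 2 single-node workers via torch.multiprocessing.spawn.
    parallel_launch(2, _hello_worker, "hello from pplx_utils")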


@ -899,11 +899,36 @@ def shuffle_rows(input_tensor: torch.Tensor, dst2src_map: torch.Tensor):
return output_tensor return output_tensor
def get_cutlass_pplx_moe_mm_data(expert_offsets: torch.Tensor,
problem_sizes1: torch.Tensor,
problem_sizes2: torch.Tensor,
expert_num_tokens: torch.Tensor,
num_local_experts: int, padded_m: int, n: int,
k: int):
"""
Prepare data necessary to perform CUTLASS grouped matrix multiplications
used in CUTLASS-based fused MoE.
The function takes in expert_num_tokens (the token count per local expert)
and uses it to compute:
- expert_offsets: Indices that mark at which token index each expert begins
its computation.
- problem_sizes1, problem_sizes2: MxNxK sizes of each expert's
multiplication in two grouped MMs used in
the fused MoE operation.
"""
return torch.ops._C.get_cutlass_pplx_moe_mm_data(
expert_offsets, problem_sizes1, problem_sizes2, expert_num_tokens,
num_local_experts, padded_m, n, k)
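A short sketch of driving this op from Python, mirroring how run_cutlass_moe_fp8 (further below) allocates the outputs; the concrete sizes here are illustrative only:

import torch
from vllm import _custom_ops as ops

num_local_experts, padded_m, n, k = 8, 256, 2048, 7168
device = "cuda"

# Token count per local expert, as produced by the PPLX dispatch (illustrative).
expert_num_tokens = torch.randint(0, padded_m, (num_local_experts,),
                                  dtype=torch.int32, device=device)

expert_offsets = torch.empty((num_local_experts,),
                             dtype=torch.int32, device=device)
problem_sizes1 = torch.empty((num_local_experts, 3),
                             dtype=torch.int32, device=device)
problem_sizes2 = torch.empty((num_local_experts, 3),
                             dtype=torch.int32, device=device)

ops.get_cutlass_pplx_moe_mm_data(expert_offsets, problem_sizes1, problem_sizes2,
                                 expert_num_tokens, num_local_experts,
                                 padded_m, n, k)
# expert_offsets[i] == i * padded_m, problem_sizes1[i] == [tokens_i, 2*n, k],
# problem_sizes2[i] == [tokens_i, k, n].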
def cutlass_moe_mm(out_tensors: torch.Tensor, a_tensors: torch.Tensor, def cutlass_moe_mm(out_tensors: torch.Tensor, a_tensors: torch.Tensor,
b_tensors: torch.Tensor, a_scales: torch.Tensor, b_tensors: torch.Tensor, a_scales: torch.Tensor,
b_scales: torch.Tensor, expert_offsets: torch.Tensor, b_scales: torch.Tensor, expert_offsets: torch.Tensor,
problem_sizes: torch.Tensor, a_strides: torch.Tensor, problem_sizes: torch.Tensor, a_strides: torch.Tensor,
b_strides: torch.Tensor, c_strides: torch.Tensor): b_strides: torch.Tensor, c_strides: torch.Tensor,
per_act_token: bool, per_out_ch: bool):
""" """
A single grouped matrix multiplication used in CUTLASS-based fused MoE. A single grouped matrix multiplication used in CUTLASS-based fused MoE.
The function executes fp8-quantized OUT = AB matrix multiplication. The function executes fp8-quantized OUT = AB matrix multiplication.
@ -918,7 +943,7 @@ def cutlass_moe_mm(out_tensors: torch.Tensor, a_tensors: torch.Tensor,
return torch.ops._C.cutlass_moe_mm(out_tensors, a_tensors, b_tensors, return torch.ops._C.cutlass_moe_mm(out_tensors, a_tensors, b_tensors,
a_scales, b_scales, expert_offsets, a_scales, b_scales, expert_offsets,
problem_sizes, a_strides, b_strides, problem_sizes, a_strides, b_strides,
c_strides) c_strides, per_act_token, per_out_ch)
def cutlass_fp4_moe_mm(a_tensors: torch.Tensor, b_tensors: torch.Tensor, def cutlass_fp4_moe_mm(a_tensors: torch.Tensor, b_tensors: torch.Tensor,
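The per_act_token/per_out_ch flags used to be inferred inside the C++ grouped-GEMM caller from the scale shapes (see the two lines removed from cutlass_group_gemm_caller above); callers now pass them explicitly. A hedged sketch of that derivation around the updated wrapper:

from vllm import _custom_ops as ops

def call_cutlass_moe_mm(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
                        expert_offsets, problem_sizes, a_strides, b_strides,
                        c_strides, num_experts):
    # Same rule the C++ caller previously applied internally:
    # per-token activation scales and per-output-channel weight scales.
    per_act_token = a_scales.numel() != 1
    per_out_ch = b_scales.numel() != num_experts
    ops.cutlass_moe_mm(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
                       expert_offsets, problem_sizes, a_strides, b_strides,
                       c_strides, per_act_token, per_out_ch)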


@ -39,6 +39,7 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
def workspace_shapes( def workspace_shapes(
self, self,
a: torch.Tensor, a: torch.Tensor,
aq: torch.Tensor,
M: int, M: int,
N: int, N: int,
K: int, K: int,


@ -67,6 +67,7 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
def workspace_shapes( def workspace_shapes(
self, self,
a: torch.Tensor, a: torch.Tensor,
aq: torch.Tensor,
M: int, M: int,
N: int, N: int,
K: int, K: int,
@ -78,11 +79,11 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
# even if we fall back to triton later, e.g. if expert maps are set. # even if we fall back to triton later, e.g. if expert maps are set.
if self.allow_deep_gemm and self.batched_deep_gemm_experts is not None: if self.allow_deep_gemm and self.batched_deep_gemm_experts is not None:
return self.batched_deep_gemm_experts.workspace_shapes( return self.batched_deep_gemm_experts.workspace_shapes(
a, M, N, K, topk, num_experts) a, aq, M, N, K, topk, num_experts)
else: else:
assert self.batched_triton_experts is not None assert self.batched_triton_experts is not None
return self.batched_triton_experts.workspace_shapes( return self.batched_triton_experts.workspace_shapes(
a, M, N, K, topk, num_experts) a, aq, M, N, K, topk, num_experts)
def apply( def apply(
self, self,


@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
""" CUTLASS based Fused MoE kernels.""" """ CUTLASS based Fused MoE kernels."""
from typing import Optional from typing import Callable, Optional
import torch import torch
@ -13,56 +13,24 @@ from vllm.model_executor.layers.fused_moe.utils import _fp8_perm, _resize_cache
from vllm.scalar_type import scalar_types from vllm.scalar_type import scalar_types
class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute): def run_cutlass_moe_fp8(
def __init__(
self,
ab_strides1: torch.Tensor,
c_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides2: torch.Tensor,
out_dtype: torch.dtype,
):
super().__init__()
self.ab_strides1 = ab_strides1
self.c_strides1 = c_strides1
self.ab_strides2 = ab_strides2
self.c_strides2 = c_strides2
self.out_dtype = out_dtype
def workspace_shapes(
self,
a: torch.Tensor,
M: int,
N: int,
K: int,
topk: int,
num_experts: int,
) -> tuple[int, int, torch.dtype]:
# Note that K, N are transposed
N, K = K, N
workspace1 = M * topk * max(2 * N, K)
workspace2 = M * topk * N
return (workspace1, workspace2, self.out_dtype)
def apply(
self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
w1: torch.Tensor, w1: torch.Tensor,
w2: torch.Tensor, w2: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
activation: str, activation_callable: Callable,
global_num_experts: int, global_num_experts: int,
expert_map: Optional[torch.Tensor], expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor], w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor], w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor], a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor, workspace13: torch.Tensor,
workspace2: torch.Tensor, workspace2: torch.Tensor,
expert_num_tokens: Optional[torch.Tensor], expert_num_tokens: Optional[torch.Tensor],
out_dtype: torch.dtype,
per_act_token: bool,
per_out_ch: bool,
) -> torch.Tensor: ) -> torch.Tensor:
a1q = hidden_states a1q = hidden_states
@ -70,38 +38,50 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
assert w2_scale is not None assert w2_scale is not None
assert w1.dtype == torch.float8_e4m3fn assert w1.dtype == torch.float8_e4m3fn
assert w2.dtype == torch.float8_e4m3fn assert w2.dtype == torch.float8_e4m3fn
assert a1q.shape[1] == w1.shape[1], "Hidden size mismatch w1" if expert_num_tokens is None:
assert w1.shape[2] == w2.shape[1] * 2, "Hidden size mismatch w2" assert a1q.shape[1] == w1.shape[2], "Hidden size mismatch w1"
else:
assert a1q.shape[2] == w1.shape[2], "Hidden size mismatch w1"
assert w1.shape[1] == w2.shape[2] * 2, "Hidden size mismatch w2"
assert w1_scale.dim() == 1 or w1_scale.shape[1] == 1 or w1_scale.shape[
1] == w1.shape[1], "W1 scale shape mismatch"
assert w2_scale.dim() == 1 or w2_scale.shape[1] == 1 or w2_scale.shape[
1] == w2.shape[1], "W2 scale shape mismatch"
assert w1.shape[0] == w2.shape[0], "Expert number mismatch" assert w1.shape[0] == w2.shape[0], "Expert number mismatch"
assert a1q_scale is None or a1q_scale.dim( assert a1q_scale is None or a1q_scale.dim(
) == 0 or a1q_scale.shape[0] == 1 or a1q_scale.shape[0] == a1q.shape[ ) == 0 or a1q_scale.shape[0] == 1 or a1q_scale.shape[0] == a1q.shape[
0], "Input scale shape mismatch" 0], "Input scale shape mismatch"
assert w1_scale.dim() == 1 or w1_scale.shape[1] == 1 or w1_scale.shape[
1] == w1.shape[2], "W1 scale shape mismatch"
assert w2_scale.dim() == 1 or w2_scale.shape[1] == 1 or w2_scale.shape[
1] == w2.shape[2], "W2 scale shape mismatch"
assert w1.shape[0] == w2.shape[0], "Weights expert number mismatch" assert w1.shape[0] == w2.shape[0], "Weights expert number mismatch"
assert w1.shape[0] == w1_scale.shape[ assert w1.shape[0] == w1_scale.shape[0], "w1 scales expert number mismatch"
0], "w1 scales expert number mismatch" assert w1.shape[0] == w2_scale.shape[0], "w2 scales expert number mismatch"
assert w1.shape[0] == w2_scale.shape[ assert a2_scale is None or a2_scale.dim(
0], "w2 scales expert number mismatch" ) == 0 or a2_scale.shape[0] == 1 or a2_scale.shape[0] == a1q.shape[
assert a2_scale is None or a1q_scale is None or a2_scale.shape == a1q_scale.shape, "Intermediate scale shape mismatch" # noqa: E501 0], "Intermediate scale shape mismatch"
assert self.ab_strides1.shape[0] == w1.shape[ assert out_dtype in [torch.half, torch.bfloat16], "Invalid output dtype"
0], "AB Strides 1 expert number mismatch" if expert_map is not None:
assert self.c_strides1.shape[0] == w1.shape[ assert expert_num_tokens is None
0], "C Strides 1 expert number mismatch"
assert self.ab_strides2.shape[0] == w2.shape[
0], "AB Strides 2 expert number mismatch"
assert self.c_strides2.shape[0] == w2.shape[
0], "C Strides 2 expert number mismatch"
assert self.out_dtype in [torch.half,
torch.bfloat16], "Invalid output dtype"
M = a1q.shape[0] # We have two modes: PPLX and non-PPLX. We differentiate them by checking
_, N, K = w2.shape # because w1 + w2 are transposed # if expert_num_tokens is None (expert_num_tokens is a tensor which PPLX
# uses to track the number of tokens per expert).
# In the non-PPLX mode, the input tokens are not padded: thus, the shape
# of the input is [total_num_tokens, hidden_size]. The input and output
# require shuffling by a_map and c_map such that the tokens assigned to
# each expert are contiguous.
# In the PPLX mode, the input tokens are padded per expert to ensure that
# the PPLX dispatch and combine functions work correctly: thus, the shape
# of the input is [num_experts, max_num_tokens_per_expert, hidden_size].
# The PPLX input and output require no shuffling by a_map and c_map since
# their tokens are already contiguous for each expert as a result of
# the dispatch function.
is_pplx = expert_num_tokens is not None
M = a1q.shape[0] # no pplx
padded_M = a1q.shape[1] # pplx
_, K, N = w2.shape
device = a1q.device device = a1q.device
assert w1.shape[1] == K assert w1.shape[2] == K
assert global_num_experts != -1 assert global_num_experts != -1
assert a1q_scale is not None assert a1q_scale is not None
@ -113,10 +93,29 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
local_topk_ids = topk_ids local_topk_ids = topk_ids
topk = local_topk_ids.shape[1] topk = local_topk_ids.shape[1]
local_E = w1.shape[0]
per_act_token = a1q_scale.numel() != 1 if a1q_scale is not None else ( if is_pplx:
a2_scale.numel() != 1 if a2_scale is not None else False) expert_offsets = torch.empty((local_E),
dtype=torch.int32,
device=device)
problem_sizes1 = torch.empty((local_E, 3),
dtype=torch.int32,
device=device)
problem_sizes2 = torch.empty((local_E, 3),
dtype=torch.int32,
device=device)
ops.get_cutlass_pplx_moe_mm_data(expert_offsets, problem_sizes1,
problem_sizes2, expert_num_tokens,
local_E, padded_M, N, K)
w1_scale = w1_scale.reshape(w1_scale.shape[0], -1)
w2_scale = w2_scale.reshape(w2_scale.shape[0], -1)
a1q = a1q.reshape(-1, a1q.shape[2])
a1q_scale = a1q_scale.reshape(-1, a1q_scale.shape[2]).contiguous()
else:
expert_offsets = torch.empty((global_num_experts + 1), expert_offsets = torch.empty((global_num_experts + 1),
dtype=torch.int32, dtype=torch.int32,
device=device) device=device)
@ -149,16 +148,39 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
a1q = _fp8_perm(a1q, a_map) a1q = _fp8_perm(a1q, a_map)
a1q_scale = a1q_scale[a_map] if per_act_token else a1q_scale a1q_scale = a1q_scale[a_map] if per_act_token else a1q_scale
expert_offsets = expert_offsets[:-1]
ab_strides1 = torch.full((w1.shape[0], ),
K,
device=device,
dtype=torch.int64)
c_strides1 = torch.full((w1.shape[0], ),
2 * N,
device=device,
dtype=torch.int64)
ab_strides2 = torch.full((w1.shape[0], ),
N,
device=device,
dtype=torch.int64)
c_strides2 = torch.full((w1.shape[0], ),
K,
device=device,
dtype=torch.int64)
if is_pplx:
c1 = _resize_cache(workspace13, (local_E * padded_M, N * 2))
c2 = _resize_cache(workspace2, (local_E * padded_M, N))
c3 = _resize_cache(workspace13, (local_E * padded_M, K))
else:
c1 = _resize_cache(workspace13, (M * topk, N * 2)) c1 = _resize_cache(workspace13, (M * topk, N * 2))
c2 = _resize_cache(workspace2, (M * topk, N)) c2 = _resize_cache(workspace2, (M * topk, N))
c3 = _resize_cache(workspace13, (M * topk, K)) c3 = _resize_cache(workspace13, (M * topk, K))
ops.cutlass_moe_mm(c1, a1q, w1, a1q_scale, w1_scale, ops.cutlass_moe_mm(c1, a1q, w1, a1q_scale, w1_scale, expert_offsets,
expert_offsets[:-1], problem_sizes1, problem_sizes1, ab_strides1, ab_strides1, c_strides1,
self.ab_strides1, self.ab_strides1, self.c_strides1) per_act_token, per_out_ch)
self.activation(activation, c2, c1) activation_callable(c2, c1)
a2q, a2q_scale = ops.scaled_fp8_quant( a2q, a2q_scale = ops.scaled_fp8_quant(
c2, a2_scale, use_per_token_if_dynamic=per_act_token) c2, a2_scale, use_per_token_if_dynamic=per_act_token)
@ -166,33 +188,90 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
if expert_map is not None: if expert_map is not None:
c3.fill_(0) c3.fill_(0)
ops.cutlass_moe_mm(c3, a2q, w2, a2q_scale, w2_scale, ops.cutlass_moe_mm(c3, a2q, w2, a2q_scale, w2_scale, expert_offsets,
expert_offsets[:-1], problem_sizes2, problem_sizes2, ab_strides2, ab_strides2, c_strides2,
self.ab_strides2, self.ab_strides2, self.c_strides2) per_act_token, per_out_ch)
c3 = c3[c_map] if is_pplx:
return c3.reshape(local_E, padded_M, K)
return c3 else:
return c3[c_map].view(M, topk, K)
class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(
self,
max_experts_per_worker: int,
out_dtype: torch.dtype,
per_act_token: bool,
per_out_ch: bool,
):
super().__init__()
self.max_experts_per_worker = max_experts_per_worker
self.out_dtype = out_dtype
self.per_act_token = per_act_token
self.per_out_ch = per_out_ch
def workspace_shapes(
self,
a: torch.Tensor,
aq: torch.Tensor,
M: int,
N: int,
K: int,
topk: int,
num_experts: int,
) -> tuple[int, int, torch.dtype]:
padded_M = aq.shape[1]
workspace1 = self.max_experts_per_worker * padded_M * max(N, K)
workspace2 = self.max_experts_per_worker * padded_M * (N // 2)
return (workspace1, workspace2, self.out_dtype)
def apply(
self,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_num_tokens: Optional[torch.Tensor],
) -> torch.Tensor:
assert w1_zp is None, "w1_zp is not supported in CUTLASS MoE"
assert w2_zp is None, "w2_zp is not supported in CUTLASS MoE"
activation_callable = lambda i, o: self.activation(activation, i, o)
return run_cutlass_moe_fp8(hidden_states, w1, w2, topk_ids,
activation_callable, global_num_experts,
expert_map, w1_scale, w2_scale, a1q_scale,
a2_scale, workspace13, workspace2,
expert_num_tokens, self.out_dtype,
self.per_act_token, self.per_out_ch)
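The rewritten CutlassExpertsFp8 takes the per-worker expert count and the quantization-granularity flags instead of stride tensors. A hedged construction sketch, mirroring the PPLX test earlier in this change (prepare_finalize is assumed to be a PplxPrepareAndFinalize built as in that test):

import torch
from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8
from vllm.model_executor.layers.fused_moe.modular_kernel import (
    FusedMoEModularKernel)

def build_pplx_cutlass_kernel(prepare_finalize, num_experts: int,
                              world_size: int, per_act_token: bool,
                              per_out_ch: bool) -> FusedMoEModularKernel:
    # prepare_finalize: a PplxPrepareAndFinalize set up as in the PPLX test.
    experts = CutlassExpertsFp8(
        max_experts_per_worker=(num_experts + world_size - 1) // world_size,
        out_dtype=torch.half,
        per_act_token=per_act_token,
        per_out_ch=per_out_ch,
    )
    return FusedMoEModularKernel(prepare_finalize, experts)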
#TODO make the grouped gemm kernel consistent with scaled gemm kernel
 def cutlass_moe_fp8(
     a: torch.Tensor,
     w1_q: torch.Tensor,
     w2_q: torch.Tensor,
-    w1_scale: torch.Tensor,
-    w2_scale: torch.Tensor,
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
-    ab_strides1: torch.Tensor,
-    c_strides1: torch.Tensor,
-    ab_strides2: torch.Tensor,
-    c_strides2: torch.Tensor,
+    w1_scale: torch.Tensor,
+    w2_scale: torch.Tensor,
+    activation: str = "silu",
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
-    out_dtype: torch.dtype = torch.half,
     expert_map: Optional[torch.Tensor] = None,
     apply_router_weight_on_input: bool = False,
+    global_num_experts: int = -1,
 ) -> torch.Tensor:
""" """
This function computes a a8w8-quantized Mixture of Experts (MoE) layer This function computes a a8w8-quantized Mixture of Experts (MoE) layer
@ -207,25 +286,17 @@ def cutlass_moe_fp8(
         Shape: [num_experts, K, 2N] (the weights are passed transposed)
     - w2_q (torch.Tensor): The second set of fp8-quantized expert weights.
         Shape: [num_experts, N, K] (the weights are passed transposed)
+    - topk_weights (torch.Tensor): The weights of each token->expert mapping.
+    - topk_ids (torch.Tensor): The token->expert mappings.
     - w1_scale (torch.Tensor): The fp32 scale to dequantize w1_q.
         Shape: [num_experts] or [num_experts, 2N]
     - w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
         Shape: [num_experts] or [num_experts, K]
-    - gating_output (torch.Tensor): The output of the gating operation
-        (before softmax).
-    - topk_weights (torch.Tensor): The weights of each token->expert mapping.
-    - ab_strides1 (torch.Tensor): The input and weights strides of the first
-        grouped gemm.
-    - c_strides1 (torch.Tensor): The output strides of the first grouped gemm.
-    - ab_strides2 (torch.Tensor): The input and weights strides of the second
-        grouped gemm.
-    - c_strides2 (torch.Tensor): The output strides of the second grouped gemm.
     - a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a.
         Shape: scalar or [M]
     - a2_scale (Optional[torch.Tensor]): The optional fp32 scale to
         quantize the intermediate result between the gemms.
         Shape: scalar or [M]
-    - out_dtype (torch.dtype): The output tensor type.
     - expert_map (Optional[torch.Tensor]): In the case of Expert parallel,
         every Rank is responsible for a subset of experts. expert_map is a
         mapping from global expert-id to local expert-id. When expert_map[i]
@ -233,24 +304,27 @@ def cutlass_moe_fp8(
         expert-id i.
     - apply_router_weight_on_input (bool): When true, the topk weights are
         applied directly on the inputs. This is only applicable when topk is 1.
+    - global_num_experts (int): The total number of experts.

     Returns:
     - torch.Tensor: The fp16 output tensor after applying the MoE layer.
     """
     per_act_token = a1_scale.numel() != 1 if a1_scale is not None else (
         a2_scale.numel() != 1 if a2_scale is not None else False)
+    per_out_ch = w1_scale.numel() != w1_q.shape[0]
+    out_dtype = a.dtype

     fn = mk.FusedMoEModularKernel(
         MoEPrepareAndFinalizeNoEP(
+            per_channel_quant=per_act_token,
             quant_dtype=torch.float8_e4m3fn,
-            per_channel_quant=per_act_token,
         ),
         CutlassExpertsFp8(
-            ab_strides1,
-            c_strides1,
-            ab_strides2,
-            c_strides2,
-            out_dtype,
+            max_experts_per_worker=global_num_experts,
+            out_dtype=out_dtype,
+            per_act_token=per_act_token,
+            per_out_ch=per_out_ch,
         ),
     )
@ -260,9 +334,12 @@ def cutlass_moe_fp8(
         w2_q,
         topk_weights,
         topk_ids,
-        expert_map=expert_map,
-        w1_scale=w1_scale,
-        w2_scale=w2_scale,
+        False,
+        activation,
+        global_num_experts if global_num_experts != -1 else w1_q.size(0),
+        expert_map,
+        w1_scale,
+        w2_scale,
         a1_scale=a1_scale,
         a2_scale=a2_scale,
         apply_router_weight_on_input=apply_router_weight_on_input,
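# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of this diff). It shows how a caller
# might drive the reworked cutlass_moe_fp8 entry point; tensor sizes, the
# torch.topk routing and the int32 cast are assumptions made for the example,
# and the weight layout follows the updated benchmark (weights are no longer
# transposed after per-expert quantization).
# ---------------------------------------------------------------------------
import torch

from vllm import _custom_ops as ops
from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp8

num_experts, m, n, k, topk = 8, 16, 1024, 2048, 2
a = torch.randn((m, k), device="cuda", dtype=torch.half) / 10
w1 = torch.randn((num_experts, 2 * n, k), device="cuda", dtype=torch.half) / 10
w2 = torch.randn((num_experts, k, n), device="cuda", dtype=torch.half) / 10

# Quantize each expert's weights to fp8 with a per-expert scale.
w1_q = torch.empty_like(w1, dtype=torch.float8_e4m3fn)
w2_q = torch.empty_like(w2, dtype=torch.float8_e4m3fn)
w1_scale = torch.empty((num_experts, 1, 1), device="cuda", dtype=torch.float32)
w2_scale = torch.empty((num_experts, 1, 1), device="cuda", dtype=torch.float32)
for e in range(num_experts):
    w1_q[e], w1_scale[e] = ops.scaled_fp8_quant(w1[e])
    w2_q[e], w2_scale[e] = ops.scaled_fp8_quant(w2[e])

# Simple routing for illustration; a real model uses its router and, in vLLM,
# int32 expert ids, so the ids are cast here as well.
scores = torch.softmax(torch.randn((m, num_experts), device="cuda",
                                   dtype=torch.half), dim=-1)
topk_weights, topk_ids = torch.topk(scores, topk, dim=-1)
topk_ids = topk_ids.to(torch.int32)

# Note the new argument order (topk_* before the weight scales), the new
# activation/global_num_experts parameters, and that the old ab_strides*/
# c_strides*/out_dtype arguments are gone: strides are handled inside the
# kernel wrapper and the output dtype follows a.dtype.
out = cutlass_moe_fp8(a, w1_q, w2_q, topk_weights, topk_ids,
                      w1_scale, w2_scale, activation="silu",
                      global_num_experts=num_experts)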

View File

@ -73,6 +73,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
     def workspace_shapes(
         self,
         a: torch.Tensor,
+        aq: torch.Tensor,
         M: int,
         N: int,
         K: int,

View File

@ -521,6 +521,7 @@ class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
     def workspace_shapes(
         self,
         a: torch.Tensor,
+        aq: torch.Tensor,
         M: int,
         N: int,
         K: int,
@ -632,6 +633,7 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
     def workspace_shapes(
         self,
         a: torch.Tensor,
+        aq: torch.Tensor,
         M: int,
         N: int,
         K: int,

View File

@ -1545,6 +1545,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
     def workspace_shapes(
         self,
         a: torch.Tensor,
+        aq: torch.Tensor,
         M: int,
         N: int,
         K: int,

View File

@ -9,6 +9,9 @@ from typing import Callable, Optional, Union
 import torch
 import torch.nn.functional as F
+from compressed_tensors.quantization import (QuantizationArgs,
+                                              QuantizationStrategy,
+                                              QuantizationType)
 from torch.nn.parameter import UninitializedParameter

 import vllm.envs as envs
@ -210,6 +213,7 @@ class MoEConfig:
     moe_parallel_config: FusedMoEParallelConfig

     in_dtype: torch.dtype  # The activation type.
+    quant_dtype: torch.dtype = None

     # TODO: add more quantization params, blocked, per-token, etc.
     block_size: int = 128
@ -264,8 +268,22 @@ class FusedMoeWeightScaleSupported(Enum):
    BLOCK = "block"
def get_quant_config_input_activations(
quant_config: Optional[QuantizationConfig]
) -> Optional[QuantizationArgs]:
if (quant_config is not None and hasattr(quant_config, 'target_scheme_map')
and "Linear" in quant_config.target_scheme_map and
"input_activations" in quant_config.target_scheme_map["Linear"]):
return quant_config.target_scheme_map["Linear"].get(
"input_activations")
else:
return None
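# ---------------------------------------------------------------------------
# Usage sketch (assumes quant_config, params_dtype and torch are in scope):
# this is the pattern the FusedMoE constructor uses later in this diff to
# derive the activation quantization dtype from the helper above.
# ---------------------------------------------------------------------------
input_activations = get_quant_config_input_activations(quant_config)
quant_dtype = params_dtype
if (input_activations is not None and input_activations.num_bits == 8
        and input_activations.type == QuantizationType.FLOAT):
    quant_dtype = torch.float8_e4m3fn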
 class FusedMoEMethodBase(QuantizeMethodBase):

+    moe: MoEConfig

     @abstractmethod
     def create_weights(self, layer: torch.nn.Module, num_experts: int,
                        hidden_size: int, intermediate_size_per_partition: int,
@ -277,6 +295,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
         all2all_manager = get_ep_group().device_communicator.all2all_manager
         assert all2all_manager is not None

+        self.moe = moe
         quant_dtype = None
         act_quant_block_size = None
         from vllm.model_executor.layers.quantization.fp8 import Fp8Config
@ -297,12 +316,13 @@ class FusedMoEMethodBase(QuantizeMethodBase):
                 # dp_size actually means tp_size, bug in pplx kernels
                 dp_size=all2all_manager.tp_group.world_size,
                 hidden_dim=moe.hidden_dim,
-                hidden_dim_bytes=moe.hidden_dim * moe.in_dtype.itemsize,
+                hidden_dim_bytes=moe.hidden_dim * moe.quant_dtype.itemsize,
                 # For blocked per token: set to
                 # ceil_div(hidden_dim, block_size) * sizeof(float32)
                 # For per-token: set to sizeof(float32)
-                hidden_dim_scale_bytes=(0 if moe.in_dtype.itemsize != 1 else (
-                    (moe.hidden_dim + moe.block_size - 1) // moe.block_size *
+                hidden_dim_scale_bytes=(
+                    0 if moe.quant_dtype.itemsize != 1 else
+                    ((moe.hidden_dim + moe.block_size - 1) // moe.block_size *
                     torch.float32.itemsize)),
             )
@ -313,6 +333,9 @@ class FusedMoEMethodBase(QuantizeMethodBase):
             handle = all2all_manager.get_handle(all_to_all_args)

+            input_activations = get_quant_config_input_activations(
+                quant_config)
+
             prepare_finalize = PplxPrepareAndFinalize(
                 handle,
                 max_num_tokens=moe.max_num_tokens,
@ -320,7 +343,10 @@ class FusedMoEMethodBase(QuantizeMethodBase):
                 rank=all2all_manager.rank,
                 # dp_size actually means tp_size, bug in pplx kernels
                 dp_size=all2all_manager.tp_group.world_size,
-                quant_dtype=moe.in_dtype,
+                quant_dtype=moe.quant_dtype,
+                per_act_token=(input_activations.strategy
+                               == QuantizationStrategy.TOKEN
+                               if input_activations is not None else False),
             )
        elif moe.use_deepep_ht_kernels:
            assert moe.dp_size == all2all_manager.dp_world_size
@ -365,15 +391,15 @@ class FusedMoEMethodBase(QuantizeMethodBase):
         self.topk_indices_dtype = None
         if prepare_finalize is not None:
             self.topk_indices_dtype = prepare_finalize.topk_indices_dtype()
-            experts = self.select_gemm_impl(prepare_finalize)
+            experts = self.select_gemm_impl(prepare_finalize, moe)
             self.fused_experts = FusedMoEModularKernel(
                 prepare_finalize,
                 experts,
             )
     def select_gemm_impl(
-            self, prepare_finalize: FusedMoEPrepareAndFinalize
-    ) -> FusedMoEPermuteExpertsUnpermute:
+            self, prepare_finalize: FusedMoEPrepareAndFinalize,
+            moe: Optional[MoEConfig]) -> FusedMoEPermuteExpertsUnpermute:
         # based on the all2all implementation, select the appropriate
         # gemm implementation
         raise NotImplementedError(
@ -419,7 +445,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         else:
             self.rocm_aiter_fused_experts = None  # type: ignore

-    def select_gemm_impl(self, prepare_finalize: FusedMoEPrepareAndFinalize):
+    def select_gemm_impl(self, prepare_finalize: FusedMoEPrepareAndFinalize,
+                         moe: Optional[MoEConfig]):

         assert self.fused_experts == fused_experts
@ -809,7 +836,6 @@ class FusedMoE(torch.nn.Module):
        activation: str = "silu",
    ):
        super().__init__()
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.params_dtype = params_dtype
@ -869,14 +895,24 @@ class FusedMoE(torch.nn.Module):
            from vllm_hpu_extension.ops import DynamicFusedMOE
            self.hpu_fused_moe = DynamicFusedMOE(self.global_num_experts)
# Only support float8 for now.
quant_dtype = params_dtype
if quant_config is not None:
input_activations = get_quant_config_input_activations(
quant_config)
if (input_activations is not None
and input_activations.num_bits == 8
and input_activations.type == QuantizationType.FLOAT):
quant_dtype = torch.float8_e4m3fn
         moe = MoEConfig(
             num_experts=self.global_num_experts,
             experts_per_token=top_k,
             hidden_dim=hidden_size,
             num_local_experts=self.local_num_experts,
             moe_parallel_config=self.moe_parallel_config,
-            # TODO (bnell): this needs to be fixed for quantized types.
             in_dtype=params_dtype,
+            quant_dtype=quant_dtype,
             max_num_tokens=MOE_DP_CHUNK_SIZE,
         )
         self.moe_config = moe

View File

@ -175,6 +175,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
     def workspace_shapes(
         self,
         a: torch.Tensor,
+        aq: torch.Tensor,
         M: int,
         N: int,
         K: int,
@ -309,7 +310,7 @@ class FusedMoEModularKernel(torch.nn.Module):
         # Use a1 here to decipher the correct workspace datatype
         workspace13_shape, workspace2_shape, workspace_dtype = (
-            self.fused_experts.workspace_shapes(a1, M, N, K, top_k,
+            self.fused_experts.workspace_shapes(a1, a1q, M, N, K, top_k,
                                                 global_num_experts))

         # We can reuse the memory between cache1 and cache3 because by the time

View File

@ -21,7 +21,8 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
                  rank: int,
                  dp_size: int,
                  quant_dtype: Optional[torch.dtype] = None,
-                 block_shape: Optional[list[int]] = None):
+                 block_shape: Optional[list[int]] = None,
+                 per_act_token: bool = False):
         super().__init__()
         assert max_num_tokens > 0
         self.a2a = a2a
@ -31,6 +32,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
         self.rank = rank
         self.dp_size = dp_size
         self.quant_dtype = quant_dtype
+        self.per_act_token = per_act_token

     def max_num_tokens_per_rank(self) -> Optional[int]:
         return self.max_num_tokens
@ -66,13 +68,14 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
"apply_router_weight_on_input is only implemented for topk=1") "apply_router_weight_on_input is only implemented for topk=1")
a1 = a1 * rank_topk_weights.to(a1.dtype) a1 = a1 * rank_topk_weights.to(a1.dtype)
per_act_token = a1_scale.numel() != 1 if a1_scale is not None else ( repeat_cols = 4
a2_scale.numel() != 1 if a2_scale is not None else False) repeat_rows = 1 if self.per_act_token else a1.shape[0]
a1q, a1q_scale = moe_kernel_quantize_input(
a1, (None if self.per_act_token else a1_scale), self.quant_dtype,
self.per_act_token, self.block_shape)
a1q, a1q_scale = moe_kernel_quantize_input(a1, a1_scale, if a1q_scale is not None:
self.quant_dtype, a1q_scale = a1q_scale.repeat(repeat_rows, repeat_cols)
per_act_token,
self.block_shape)
# rem_experts need to be 0 for pplx to work properly. # rem_experts need to be 0 for pplx to work properly.
rem_experts = num_experts % self.world_size rem_experts = num_experts % self.world_size
@ -100,7 +103,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
                           else 1) * float32_size
             expert_x_scale = torch.empty(
                 (
-                    num_experts,
+                    num_local_experts,
                     expert_x.size(1),
                     (expert_x.size(2) + block_size - 1) // block_size,
                 ),
@ -121,6 +124,8 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
             indices=rank_topk_ids,
             bound_m=bound_m,
         )

+        if expert_x_scale is not None:
+            expert_x_scale = expert_x_scale[:, :, 0:1]

         return expert_x, expert_x_scale, expert_num_tokens, None, None
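# ---------------------------------------------------------------------------
# Standalone sketch (not vLLM code) of the scale handling added above. With
# per-tensor activation quantization the quantizer returns a single scalar
# scale, but the pplx dispatch expects one scale row per token, so the scale
# is repeated across rows (and across 4 columns); on the receiving side only
# the first column is kept (expert_x_scale[:, :, 0:1]).
# ---------------------------------------------------------------------------
import torch


def replicate_scales(a1q_scale: torch.Tensor, num_tokens: int,
                     per_act_token: bool) -> torch.Tensor:
    repeat_cols = 4
    # Per-token quantization already has one scale row per token; a per-tensor
    # scale has a single row that must be broadcast to every token.
    repeat_rows = 1 if per_act_token else num_tokens
    return a1q_scale.repeat(repeat_rows, repeat_cols)


per_tensor_scale = torch.tensor([[0.02]])                  # shape [1, 1]
print(replicate_scales(per_tensor_scale, 8, False).shape)  # torch.Size([8, 4])
per_token_scale = torch.rand(8, 1)                         # shape [M, 1]
print(replicate_scales(per_token_scale, 8, True).shape)    # torch.Size([8, 4])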

View File

@ -37,6 +37,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
     def workspace_shapes(
         self,
         a: torch.Tensor,
+        aq: torch.Tensor,
         M: int,
         N: int,
         K: int,
@ -49,9 +50,9 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
         if self.allow_deep_gemm and _valid_deep_gemm_shape(M, N, K):
             assert self.deep_gemm_expert is not None
             return self.deep_gemm_expert.workspace_shapes(
-                a, M, N, K, topk, num_experts)
+                a, aq, M, N, K, topk, num_experts)
         else:
-            return self.triton_expert.workspace_shapes(a, M, N, K, topk,
+            return self.triton_expert.workspace_shapes(a, aq, M, N, K, topk,
                                                        num_experts)
    def apply(

View File

@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 import enum
+import importlib
 from enum import Enum
 from typing import Callable, Optional
@ -11,7 +12,6 @@ from compressed_tensors.quantization import (ActivationOrdering,
                                               QuantizationStrategy)

 import vllm.envs as envs
-import vllm.model_executor.layers.fused_moe  # noqa
 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
@ -30,6 +30,15 @@ from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
has_pplx = importlib.util.find_spec("pplx_kernels") is not None
if current_platform.is_cuda_alike():
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedPrepareAndFinalize)
if has_pplx:
from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
PplxPrepareAndFinalize)
logger = init_logger(__name__)
@ -77,8 +86,7 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
             else:
                 logger.info_once("Using CompressedTensorsWNA16MarlinMoEMethod")
                 return CompressedTensorsWNA16MarlinMoEMethod(quant_config)
-        elif (quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant)
-              and layer.activation == "silu"):
+        elif quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant):
             return CompressedTensorsW8A8Fp8MoECutlassMethod(quant_config)
         elif quant_config._is_fp8_w8a8(weight_quant, input_quant):
             return CompressedTensorsW8A8Fp8MoEMethod(quant_config)
@ -421,6 +429,11 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
"For FP8 Fused MoE layer, we require either per tensor or " "For FP8 Fused MoE layer, we require either per tensor or "
"channelwise, dynamic per token quantization.") "channelwise, dynamic per token quantization.")
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
cutlass_moe_fp8)
self.fused_experts = cutlass_moe_fp8 # type: ignore
self.disable_expert_map = False
    def create_weights(self, layer: torch.nn.Module, num_experts: int,
                       hidden_size: int, intermediate_size_per_partition: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):
@ -499,25 +512,6 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
         layer.w13_input_scale = None
         layer.w2_input_scale = None

-        device = w13_weight.device
-        # TODO strides can be shared across multiple layers
-        self.ab_strides1 = torch.full((num_experts, ),
-                                      hidden_size,
-                                      device=device,
-                                      dtype=torch.int64)
-        self.c_strides1 = torch.full((num_experts, ),
-                                     2 * intermediate_size_per_partition,
-                                     device=device,
-                                     dtype=torch.int64)
-        self.ab_strides2 = torch.full((num_experts, ),
-                                      intermediate_size_per_partition,
-                                      device=device,
-                                      dtype=torch.int64)
-        self.c_strides2 = torch.full((num_experts, ),
-                                     hidden_size,
-                                     device=device,
-                                     dtype=torch.int64)
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        # Fp8 moe kernels require a single activation scale.
        # We take the max of all the scales in case they differ.
@ -558,6 +552,27 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
        layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
                                                    requires_grad=False)
def select_gemm_impl(self, prepare_finalize, moe):
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
CutlassExpertsFp8)
assert moe is not None
max_experts_per_worker = (
(moe.num_experts + prepare_finalize.world_size - 1) //
prepare_finalize.world_size)
experts = CutlassExpertsFp8(
max_experts_per_worker, moe.in_dtype,
self.input_quant.strategy == QuantizationStrategy.TOKEN,
self.weight_quant.strategy == QuantizationStrategy.CHANNEL)
if has_pplx and isinstance(
prepare_finalize,
(BatchedPrepareAndFinalize, PplxPrepareAndFinalize)):
# no expert_map support in this case
self.disable_expert_map = True
return experts
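# ---------------------------------------------------------------------------
# Worked example (hypothetical numbers) of the ceil-division above: each
# EP/PPLX rank sizes CutlassExpertsFp8 for the largest possible local expert
# count, i.e. the global expert count divided by the world size, rounded up.
# ---------------------------------------------------------------------------
num_experts, world_size = 60, 8
max_experts_per_worker = (num_experts + world_size - 1) // world_size
assert max_experts_per_worker == 8  # seven ranks hold 8 experts, the last one 4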
    def apply(
        self,
        layer: torch.nn.Module,
@ -577,9 +592,6 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
         activation: str = "silu",
     ) -> torch.Tensor:

-        assert activation == "silu", (
-            f"{activation} not supported for Cutlass MoE.")
        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
@ -590,27 +602,22 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
             scoring_func=scoring_func,
-            e_score_correction_bias=e_score_correction_bias)
+            e_score_correction_bias=e_score_correction_bias,
+            indices_type=torch.uint32)

-        from vllm.model_executor.layers.fused_moe import cutlass_moe_fp8
-        return cutlass_moe_fp8(
+        return self.fused_experts(
             x,
-            layer.w13_weight.transpose(1, 2),
-            layer.w2_weight.transpose(1, 2),
-            layer.w13_weight_scale,
-            layer.w2_weight_scale,
+            layer.w13_weight,
+            layer.w2_weight,
             topk_weights,
             topk_ids,
-            self.ab_strides1,
-            self.c_strides1,
-            self.ab_strides2,
-            self.c_strides2,
+            activation=activation,
+            global_num_experts=global_num_experts,
+            expert_map=None if self.disable_expert_map else expert_map,
+            w1_scale=layer.w13_weight_scale,
+            w2_scale=layer.w2_weight_scale,
             a1_scale=layer.w13_input_scale,
             a2_scale=layer.w2_input_scale,
-            out_dtype=x.dtype,
-            expert_map=expert_map,
-            apply_router_weight_on_input=apply_router_weight_on_input,
         )

View File

@ -769,7 +769,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             del layer.w13_input_scale
             del layer.w2_input_scale

-    def select_gemm_impl(self, prepare_finalize):
+    def select_gemm_impl(self, prepare_finalize, moe):
         from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import (  # noqa: E501
             BatchedTritonOrDeepGemmExperts)