mirror of https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-13 22:45:50 +08:00

[Kernel] Integrate CUTLASS MoE kernel with PPLX (#18762)

Signed-off-by: ElizaWszola <ewszola@redhat.com>
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>

parent 6e0cd10f72
commit 84166fee97
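The modular-kernel composition that this change enables is easiest to see in the new test added further down; the sketch below is a hypothetical helper (not code from this commit) whose constructor argument order is taken from tests/kernels/moe/test_pplx_cutlass_moe.py:

    import torch
    from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8
    from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
    from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import PplxPrepareAndFinalize


    def build_pplx_cutlass_kernel(ata, max_num_tokens: int, world_size: int,
                                  rank: int, dp_size: int, num_local_experts: int,
                                  out_dtype: torch.dtype, per_act_token: bool,
                                  per_out_ch: bool) -> FusedMoEModularKernel:
        # `ata` is a pplx_kernels AllToAll handle; its construction is shown in
        # the new test file added by this commit.
        prepare_finalize = PplxPrepareAndFinalize(ata, max_num_tokens, world_size,
                                                  rank, dp_size,
                                                  quant_dtype=torch.float8_e4m3fn,
                                                  per_act_token=per_act_token)
        experts = CutlassExpertsFp8(num_local_experts, out_dtype,
                                    per_act_token, per_out_ch)
        return FusedMoEModularKernel(prepare_finalize, experts)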
@@ -543,8 +543,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
   # CUTLASS MoE kernels

   # The MoE kernel cutlass_moe_mm requires CUDA 12.3 or later (and only works
-  # on Hopper). get_cutlass_moe_mm_data should only be compiled if it's possible
-  # to compile MoE kernels that use its output.
+  # on Hopper). get_cutlass_(pplx_)moe_mm_data should only be compiled
+  # if it's possible to compile MoE kernels that use its output.
   cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;10.0a" "${CUDA_ARCHS}")
   if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS)
     set(SRCS "csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu"
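Since the comment above ties these kernels to CUDA 12.3+ and Hopper (SM90), a runtime guard along the lines of the sketch below can confirm the compiled path is usable; it is modeled on the skip condition used in the new test added by this commit, not on any documented API beyond that:

    from vllm import _custom_ops as ops
    from vllm.platforms import current_platform

    capability = current_platform.get_device_capability()
    if capability is None or not ops.cutlass_group_gemm_supported(capability.to_int()):
        print("CUTLASS grouped-GEMM MoE kernels are not available on this GPU")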
@@ -7,8 +7,8 @@ from benchmark_shapes import WEIGHT_SHAPES_MOE

 from vllm import _custom_ops as ops
 from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
+from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp8
 from vllm.model_executor.layers.fused_moe.fused_moe import (
-    cutlass_moe_fp8,
     fused_experts,
     fused_topk,
 )
@@ -70,18 +70,9 @@ def bench_run(
     w1_scale = torch.empty((num_experts, 1, 1), device="cuda", dtype=torch.float32)
     w2_scale = torch.empty((num_experts, 1, 1), device="cuda", dtype=torch.float32)

-    ab_strides1 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64)
-    c_strides1 = torch.full((num_experts,), 2 * n, device="cuda", dtype=torch.int64)
-    ab_strides2 = torch.full((num_experts,), n, device="cuda", dtype=torch.int64)
-    c_strides2 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64)
-
     for expert in range(num_experts):
         w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(w1[expert])
         w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant(w2[expert])
-    w1_q_notransp = w1_q.clone()
-    w2_q_notransp = w2_q.clone()
-    w1_q = w1_q.transpose(1, 2)
-    w2_q = w2_q.transpose(1, 2)

     score = torch.randn((m, num_experts), device="cuda", dtype=dtype)

@@ -122,10 +113,6 @@ def bench_run(
         w2_scale: torch.Tensor,
         topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
-        ab_strides1: torch.Tensor,
-        c_strides1: torch.Tensor,
-        ab_strides2: torch.Tensor,
-        c_strides2: torch.Tensor,
         num_repeats: int,
     ):
         for _ in range(num_repeats):
@@ -133,14 +120,10 @@ def bench_run(
                 a,
                 w1,
                 w2,
-                w1_scale,
-                w2_scale,
                 topk_weights,
                 topk_ids,
-                ab_strides1,
-                c_strides1,
-                ab_strides2,
-                c_strides2,
+                w1_scale,
+                w2_scale,
                 a1_scale=a_scale,
             )

@@ -153,10 +136,6 @@ def bench_run(
         w2_scale: torch.Tensor,
         topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
-        ab_strides1: torch.Tensor,
-        c_strides1: torch.Tensor,
-        ab_strides2: torch.Tensor,
-        c_strides2: torch.Tensor,
     ):
         with set_current_vllm_config(
                 VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
@@ -165,14 +144,10 @@ def bench_run(
                 a,
                 w1_q,
                 w2_q,
-                w1_scale,
-                w2_scale,
                 topk_weights,
                 topk_ids,
-                ab_strides1,
-                c_strides1,
-                ab_strides2,
-                c_strides2,
+                w1_scale,
+                w2_scale,
                 a1_scale=a_scale,
             )

@@ -218,10 +193,6 @@ def bench_run(
             w2_scale,
             topk_weights,
             topk_ids,
-            ab_strides1,
-            c_strides1,
-            ab_strides2,
-            c_strides2,
         )
     torch.cuda.synchronize()

@@ -230,8 +201,8 @@ def bench_run(
     with torch.cuda.graph(triton_graph, stream=triton_stream):
         run_triton_from_graph(
             a,
-            w1_q_notransp,
-            w2_q_notransp,
+            w1_q,
+            w2_q,
             topk_weights,
             topk_ids,
             w1_scale,
@@ -250,18 +221,12 @@ def bench_run(
         "w2": w2,
         "score": score,
         "topk": topk,
-        "w1_q_notransp": w1_q_notransp,
-        "w2_q_notransp": w2_q_notransp,
         # Cutlass params
         "a_scale": a_scale,
         "w1_q": w1_q,
         "w2_q": w2_q,
         "w1_scale": w1_scale,
         "w2_scale": w2_scale,
-        "ab_strides1": ab_strides1,
-        "c_strides1": c_strides1,
-        "ab_strides2": ab_strides2,
-        "c_strides2": c_strides2,
         # cuda graph params
         "cutlass_graph": cutlass_graph,
         "triton_graph": triton_graph,
@@ -279,8 +244,8 @@ def bench_run(
     # Warmup
     run_triton_moe(
         a,
-        w1_q_notransp,
-        w2_q_notransp,
+        w1_q,
+        w2_q,
         topk_weights,
         topk_ids,
         w1_scale,
@@ -291,7 +256,7 @@ def bench_run(

     results.append(
         benchmark.Timer(
-            stmt="run_triton_moe(a, w1_q_notransp, w2_q_notransp, topk_weights, topk_ids, w1_scale, w2_scale, a_scale, num_runs)",  # noqa: E501
+            stmt="run_triton_moe(a, w1_q, w2_q, topk_weights, topk_ids, w1_scale, w2_scale, a_scale, num_runs)",  # noqa: E501
             globals=globals,
             label=label,
             sub_label=sub_label,
@@ -322,16 +287,12 @@ def bench_run(
         w2_scale,
         topk_weights,
         topk_ids,
-        ab_strides1,
-        c_strides1,
-        ab_strides2,
-        c_strides2,
         num_warmup,
     )

     results.append(
         benchmark.Timer(
-            stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, ab_strides1, c_strides1, ab_strides2, c_strides2, num_runs)",  # noqa: E501
+            stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, num_runs)",  # noqa: E501
             globals=globals,
             label=label,
             sub_label=sub_label,
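With the stride tensors and the *_notransp weight copies removed above, the benchmark's CUTLASS call reduces to roughly the following (argument order as it appears in the hunks above; treating everything but a1_scale as positional is an assumption):

    out = cutlass_moe_fp8(
        a,             # fp16/bf16 activations
        w1_q, w2_q,    # fp8 expert weights, no manual .transpose(1, 2) needed
        topk_weights, topk_ids,
        w1_scale, w2_scale,
        a1_scale=a_scale,
    )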
csrc/ops.h (11 changes)

@@ -236,7 +236,8 @@ void cutlass_moe_mm(
     torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
     torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
     torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
-    torch::Tensor const& b_strides, torch::Tensor const& c_strides);
+    torch::Tensor const& b_strides, torch::Tensor const& c_strides,
+    bool per_act_token, bool per_out_ch);

 void cutlass_fp4_group_mm(
     torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
@@ -251,6 +252,14 @@ void get_cutlass_moe_mm_data(
     const int64_t num_experts, const int64_t n, const int64_t k,
     const std::optional<torch::Tensor>& blockscale_offsets);

+void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets,
+                                  torch::Tensor& problem_sizes1,
+                                  torch::Tensor& problem_sizes2,
+                                  const torch::Tensor& expert_num_tokens,
+                                  const int64_t num_local_experts,
+                                  const int64_t padded_m, const int64_t n,
+                                  const int64_t k);
+
 void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a,
                            torch::Tensor const& b,
                            torch::Tensor const& a_scales,
@@ -84,7 +84,8 @@ void run_cutlass_moe_mm_sm90(
     torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
     torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
     torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
-    torch::Tensor const& b_strides, torch::Tensor const& c_strides) {
+    torch::Tensor const& b_strides, torch::Tensor const& c_strides,
+    bool per_act_token, bool per_out_ch) {
   TORCH_CHECK(a_tensors.size(0) > 0, "No input A tensors provided.");
   TORCH_CHECK(b_tensors.size(0) > 0, "No input B tensors provided.");
   TORCH_CHECK(out_tensors.size(0) > 0, "No output tensors provided.");
@@ -113,19 +114,23 @@ void run_cutlass_moe_mm_sm90(
   if (n >= 8192) {
     cutlass_group_gemm_caller<Cutlass3xGemmN8192>(
         out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
-        problem_sizes, a_strides, b_strides, c_strides);
+        problem_sizes, a_strides, b_strides, c_strides, per_act_token,
+        per_out_ch);
   } else if (k >= 8192) {
     cutlass_group_gemm_caller<Cutlass3xGemmK8192>(
         out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
-        problem_sizes, a_strides, b_strides, c_strides);
+        problem_sizes, a_strides, b_strides, c_strides, per_act_token,
+        per_out_ch);
   } else if (m <= 16) {
     cutlass_group_gemm_caller<Cutlass3xGemmM16>(
         out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
-        problem_sizes, a_strides, b_strides, c_strides);
+        problem_sizes, a_strides, b_strides, c_strides, per_act_token,
+        per_out_ch);
   } else {
     cutlass_group_gemm_caller<Cutlass3xGemmDefault>(
         out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
-        problem_sizes, a_strides, b_strides, c_strides);
+        problem_sizes, a_strides, b_strides, c_strides, per_act_token,
+        per_out_ch);
   }
 }

@@ -134,15 +139,18 @@ void dispatch_moe_mm_sm90(
     torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
     torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
     torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
-    torch::Tensor const& b_strides, torch::Tensor const& c_strides) {
+    torch::Tensor const& b_strides, torch::Tensor const& c_strides,
+    bool per_act_token, bool per_out_ch) {
   if (out_tensors.dtype() == torch::kBFloat16) {
     run_cutlass_moe_mm_sm90<cutlass::float_e4m3_t, cutlass::bfloat16_t>(
         out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
-        problem_sizes, a_strides, b_strides, c_strides);
+        problem_sizes, a_strides, b_strides, c_strides, per_act_token,
+        per_out_ch);
   } else {
     run_cutlass_moe_mm_sm90<cutlass::float_e4m3_t, cutlass::half_t>(
         out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
-        problem_sizes, a_strides, b_strides, c_strides);
+        problem_sizes, a_strides, b_strides, c_strides, per_act_token,
+        per_out_ch);
   }
 }

@@ -153,8 +161,9 @@ void cutlass_moe_mm_sm90(
     torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
     torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
     torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
-    torch::Tensor const& b_strides, torch::Tensor const& c_strides) {
+    torch::Tensor const& b_strides, torch::Tensor const& c_strides,
+    bool per_act_token, bool per_out_ch) {
   dispatch_moe_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
                        expert_offsets, problem_sizes, a_strides, b_strides,
-                       c_strides);
+                       c_strides, per_act_token, per_out_ch);
 }

@@ -76,7 +76,8 @@ void cutlass_group_gemm_caller(
     torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
     torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
     torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
-    torch::Tensor const& b_strides, torch::Tensor const& c_strides) {
+    torch::Tensor const& b_strides, torch::Tensor const& c_strides,
+    bool per_act_token, bool per_out_ch) {
   using ElementAB = typename Gemm::ElementAB;
   using ElementD = typename Gemm::ElementD;

@@ -84,9 +85,6 @@ void cutlass_group_gemm_caller(
   int k_size = a_tensors.size(1);
   int n_size = out_tensors.size(1);

-  bool per_act_token = a_scales.numel() != 1;
-  bool per_out_ch = b_scales.numel() != num_experts;
-
   auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index());

   auto options_int =
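per_act_token and per_out_ch are no longer derived inside cutlass_group_gemm_caller; callers pass them in explicitly. A small helper equivalent to the two removed lines (the helper name is illustrative, not vLLM API) could look like:

    import torch

    def quant_granularity_flags(a_scales: torch.Tensor, b_scales: torch.Tensor,
                                num_experts: int) -> tuple[bool, bool]:
        # Mirrors the logic removed from cutlass_group_gemm_caller above.
        per_act_token = a_scales.numel() != 1          # per-token vs. per-tensor activation scale
        per_out_ch = b_scales.numel() != num_experts   # per-output-channel vs. per-expert weight scale
        return per_act_token, per_out_ch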
@@ -7,7 +7,7 @@

 constexpr uint64_t THREADS_PER_EXPERT = 512;

-__global__ void compute_problem_sizes(const int* __restrict__ topk_ids,
+__global__ void compute_problem_sizes(const uint32_t* __restrict__ topk_ids,
                                       int32_t* problem_sizes1,
                                       int32_t* problem_sizes2,
                                       int32_t* atomic_buffer,
@@ -62,7 +62,7 @@ __global__ void compute_expert_blockscale_offsets(
   }
 }

-__global__ void compute_arg_sorts(const int* __restrict__ topk_ids,
+__global__ void compute_arg_sorts(const uint32_t* __restrict__ topk_ids,
                                   const int32_t* __restrict__ expert_offsets,
                                   int32_t* input_permutation,
                                   int32_t* output_permutation,
@@ -103,7 +103,7 @@ void get_cutlass_moe_mm_data_caller(

   int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel());
   compute_problem_sizes<<<num_experts, num_threads, 0, stream>>>(
-      static_cast<const int32_t*>(topk_ids.data_ptr()),
+      static_cast<const uint32_t*>(topk_ids.data_ptr()),
       static_cast<int32_t*>(problem_sizes1.data_ptr()),
       static_cast<int32_t*>(problem_sizes2.data_ptr()),
       static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(), n, k);
@@ -120,10 +120,44 @@ void get_cutlass_moe_mm_data_caller(
       static_cast<int32_t*>(atomic_buffer.data_ptr()), num_experts);
   }
   compute_arg_sorts<<<num_experts, num_threads, 0, stream>>>(
-      static_cast<const int32_t*>(topk_ids.data_ptr()),
+      static_cast<const uint32_t*>(topk_ids.data_ptr()),
       static_cast<const int32_t*>(expert_offsets.data_ptr()),
       static_cast<int32_t*>(input_permutation.data_ptr()),
       static_cast<int32_t*>(output_permutation.data_ptr()),
       static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(),
       topk_ids.size(1));
 }
+
+__global__ void compute_pplx_data(int32_t* expert_offsets,
+                                  int32_t* problem_sizes1,
+                                  int32_t* problem_sizes2,
+                                  const int32_t* __restrict__ expert_num_tokens,
+                                  const int padded_m, const int n,
+                                  const int k) {
+  int expert_idx = threadIdx.x;
+
+  expert_offsets[expert_idx] = expert_idx * padded_m;
+  problem_sizes1[expert_idx * 3] = expert_num_tokens[expert_idx];
+  problem_sizes1[expert_idx * 3 + 1] = 2 * n;
+  problem_sizes1[expert_idx * 3 + 2] = k;
+  problem_sizes2[expert_idx * 3] = expert_num_tokens[expert_idx];
+  problem_sizes2[expert_idx * 3 + 1] = k;
+  problem_sizes2[expert_idx * 3 + 2] = n;
+}
+
+void get_cutlass_pplx_moe_mm_data_caller(torch::Tensor& expert_offsets,
+                                         torch::Tensor& problem_sizes1,
+                                         torch::Tensor& problem_sizes2,
+                                         const torch::Tensor& expert_num_tokens,
+                                         const int64_t num_local_experts,
+                                         const int64_t padded_m,
+                                         const int64_t n, const int64_t k) {
+  auto stream = at::cuda::getCurrentCUDAStream(expert_offsets.device().index());
+
+  compute_pplx_data<<<1, num_local_experts, 0, stream>>>(
+      static_cast<int32_t*>(expert_offsets.data_ptr()),
+      static_cast<int32_t*>(problem_sizes1.data_ptr()),
+      static_cast<int32_t*>(problem_sizes2.data_ptr()),
+      static_cast<const int32_t*>(expert_num_tokens.data_ptr()), padded_m, n,
+      k);
+}
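For reference, the following Python sketch reproduces what compute_pplx_data writes for each local expert (shapes are inferred from the kernel indexing above):

    import torch

    def pplx_moe_mm_data_reference(expert_num_tokens: torch.Tensor,
                                   num_local_experts: int,
                                   padded_m: int, n: int, k: int):
        expert_offsets = torch.arange(num_local_experts, dtype=torch.int32) * padded_m
        problem_sizes1 = torch.zeros((num_local_experts, 3), dtype=torch.int32)
        problem_sizes2 = torch.zeros((num_local_experts, 3), dtype=torch.int32)
        problem_sizes1[:, 0] = expert_num_tokens   # M: tokens routed to this expert
        problem_sizes1[:, 1] = 2 * n               # N of the first grouped GEMM
        problem_sizes1[:, 2] = k                   # K of the first grouped GEMM
        problem_sizes2[:, 0] = expert_num_tokens
        problem_sizes2[:, 1] = k                   # N of the second grouped GEMM
        problem_sizes2[:, 2] = n                   # K of the second grouped GEMM
        return expert_offsets, problem_sizes1, problem_sizes2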
@@ -36,7 +36,8 @@ void cutlass_moe_mm_sm90(
     torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
     torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
     torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
-    torch::Tensor const& b_strides, torch::Tensor const& c_strides);
+    torch::Tensor const& b_strides, torch::Tensor const& c_strides,
+    bool per_act_token, bool per_out_ch);

 #endif

@@ -56,6 +57,14 @@ void get_cutlass_moe_mm_data_caller(
     torch::Tensor& input_permutation, torch::Tensor& output_permutation,
     const int64_t num_experts, const int64_t n, const int64_t k,
     const std::optional<torch::Tensor>& blockscale_offsets);
+
+void get_cutlass_pplx_moe_mm_data_caller(torch::Tensor& expert_offsets,
+                                         torch::Tensor& problem_sizes1,
+                                         torch::Tensor& problem_sizes2,
+                                         const torch::Tensor& expert_num_tokens,
+                                         const int64_t num_local_experts,
+                                         const int64_t padded_m,
+                                         const int64_t n, const int64_t k);
 #endif

 void cutlass_scaled_mm_azp_sm75(torch::Tensor& c, torch::Tensor const& a,

@@ -207,12 +216,13 @@ void cutlass_moe_mm(
     torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
     torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
     torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
-    torch::Tensor const& b_strides, torch::Tensor const& c_strides) {
+    torch::Tensor const& b_strides, torch::Tensor const& c_strides,
+    bool per_act_token, bool per_out_ch) {
   int32_t version_num = get_sm_version_num();
 #if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90
   cutlass_moe_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
                       expert_offsets, problem_sizes, a_strides, b_strides,
-                      c_strides);
+                      c_strides, per_act_token, per_out_ch);
   return;
 #endif
   TORCH_CHECK_NOT_IMPLEMENTED(
@@ -245,6 +255,29 @@ void get_cutlass_moe_mm_data(
       version_num, ". Required capability: 90");
 }

+void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets,
+                                  torch::Tensor& problem_sizes1,
+                                  torch::Tensor& problem_sizes2,
+                                  const torch::Tensor& expert_num_tokens,
+                                  const int64_t num_local_experts,
+                                  const int64_t padded_m, const int64_t n,
+                                  const int64_t k) {
+  // This function currently gets compiled only if we have a valid cutlass moe
+  // mm to run it for.
+  int32_t version_num = get_sm_version_num();
+#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90
+  get_cutlass_pplx_moe_mm_data_caller(expert_offsets, problem_sizes1,
+                                      problem_sizes2, expert_num_tokens,
+                                      num_local_experts, padded_m, n, k);
+  return;
+#endif
+  TORCH_CHECK_NOT_IMPLEMENTED(
+      false,
+      "No compiled get_cutlass_pplx_moe_mm_data: no cutlass_scaled_mm kernel "
+      "for CUDA device capability: ",
+      version_num, ". Required capability: 90");
+}
+
 void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,
                            torch::Tensor const& b,
                            torch::Tensor const& a_scales,
@@ -435,7 +435,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "cutlass_moe_mm(Tensor! out_tensors, Tensor a_tensors, Tensor b_tensors, "
       " Tensor a_scales, Tensor b_scales, Tensor expert_offsets, "
       " Tensor problem_sizes, Tensor a_strides, "
-      " Tensor b_strides, Tensor c_strides) -> ()",
+      " Tensor b_strides, Tensor c_strides, bool per_act_token, "
+      " bool per_out_ch) -> ()",
       {stride_tag});
   ops.impl("cutlass_moe_mm", torch::kCUDA, &cutlass_moe_mm);

@@ -454,6 +455,22 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       {stride_tag});
   ops.impl("get_cutlass_moe_mm_data", torch::kCUDA, &get_cutlass_moe_mm_data);

+  // A function that computes data required to run fused MoE with w8a8 grouped
+  // GEMM and PPLX. It takes expert_num_tokens and non_zero_expert_idxs
+  // as an input, and computes expert_offsets (token start indices of each
+  // expert). In addition to this, it computes problem sizes for each expert's
+  // multiplication used by the two mms called from fused MoE operation.
+  ops.def(
+      "get_cutlass_pplx_moe_mm_data(Tensor! expert_offsets, "
+      " Tensor! problem_sizes1, "
+      " Tensor! problem_sizes2, "
+      " Tensor expert_num_tokens, "
+      " int num_local_experts, int padded_m, "
+      " int n, int k) -> ()",
+      {stride_tag});
+  ops.impl("get_cutlass_pplx_moe_mm_data", torch::kCUDA,
+           &get_cutlass_pplx_moe_mm_data);
+
   // Check if cutlass scaled_mm supports block quantization (used by DeepSeekV3)
   ops.def(
       "cutlass_scaled_mm_supports_block_fp8(int cuda_device_capability) -> "
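The `Tensor!` annotations in the schema above mark arguments that the op writes in place; the op itself returns (). A minimal direct invocation of the registered op could look like the sketch below (output shapes are assumptions based on the compute_pplx_data kernel, and the call needs the compiled SM90 extension plus a CUDA device):

    import torch

    num_local_experts, padded_m, n, k = 8, 256, 3072, 1536
    expert_num_tokens = torch.randint(0, padded_m, (num_local_experts,),
                                      dtype=torch.int32, device="cuda")
    # The three Tensor! arguments are filled in place.
    expert_offsets = torch.empty(num_local_experts, dtype=torch.int32, device="cuda")
    problem_sizes1 = torch.empty((num_local_experts, 3), dtype=torch.int32, device="cuda")
    problem_sizes2 = torch.empty((num_local_experts, 3), dtype=torch.int32, device="cuda")
    torch.ops._C.get_cutlass_pplx_moe_mm_data(expert_offsets, problem_sizes1,
                                              problem_sizes2, expert_num_tokens,
                                              num_local_experts, padded_m, n, k)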
@@ -193,14 +193,10 @@ def run_8_bit(moe_tensors: MOETensors8Bit,

     kwargs = {
         'a': moe_tensors.a,
-        'w1_q': moe_tensors.w1_q.transpose(1, 2),  # type: ignore[union-attr]
-        'w2_q': moe_tensors.w2_q.transpose(1, 2),  # type: ignore[union-attr]
+        'w1_q': moe_tensors.w1_q,  # type: ignore[union-attr]
+        'w2_q': moe_tensors.w2_q,  # type: ignore[union-attr]
         'topk_weights': topk_weights,
         'topk_ids': topk_ids,
-        'ab_strides1': moe_tensors.ab_strides1,
-        'c_strides1': moe_tensors.c_strides1,
-        'ab_strides2': moe_tensors.ab_strides2,
-        'c_strides2': moe_tensors.c_strides2,
         'w1_scale': moe_tensors.w1_scale,
         'w2_scale': moe_tensors.w2_scale,
         'a1_scale': moe_tensors.a_scale
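The updated kwargs above mean the CUTLASS path now consumes the fp8 expert weights in their natural layout, with no explicit stride tensors. The shapes below mirror the setup used in the new PPLX test added by this commit (example sizes; the per-output-channel scale shapes correspond to the per_out_ch=True case):

    import torch

    e, n, k = 8, 3072, 1536
    w1_q = torch.empty((e, 2 * n, k), dtype=torch.float8_e4m3fn)   # no .transpose(1, 2)
    w2_q = torch.empty((e, k, n), dtype=torch.float8_e4m3fn)
    w1_scale = torch.empty((e, 2 * n, 1), dtype=torch.float32)
    w2_scale = torch.empty((e, k, 1), dtype=torch.float32)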
tests/kernels/moe/test_pplx_cutlass_moe.py (new file, 287 lines)

@@ -0,0 +1,287 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import pytest
+import torch
+
+from tests.pplx_utils import ProcessGroupInfo, parallel_launch
+from vllm import _custom_ops as ops
+from vllm.config import VllmConfig, set_current_vllm_config
+from vllm.model_executor.layers.activation import SiluAndMul
+from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8
+from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
+from vllm.model_executor.layers.fused_moe.modular_kernel import (
+    FusedMoEModularKernel)
+from vllm.platforms import current_platform
+
+try:
+    from pplx_kernels import AllToAll
+    from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id,
+                                      nvshmem_finalize, nvshmem_get_unique_id,
+                                      nvshmem_init)
+    has_pplx = True
+except ImportError:
+    has_pplx = False
+
+requires_pplx = pytest.mark.skipif(
+    not has_pplx,
+    reason="Requires PPLX kernels",
+)
+
+NUM_EXPERTS = [40, 64]
+TOP_KS = [6, 8]
+
+
+def rank_chunk(num, r, w):
+    rem = num % w
+    return (num // w) + (1 if r < rem else 0)
+
+
+def chunk_by_rank(t, r, w):
+    num = t.shape[0]
+    chunk = rank_chunk(num, r, w)
+    rem = num % w
+    if rem == 0 or r < rem:
+        return t[(r * chunk):(r + 1) * chunk].contiguous()
+    else:
+        long_chunks = (num // w + 1) * rem
+        short_chunks = (r - rem) * chunk
+        start = long_chunks + short_chunks
+        return t[start:start + chunk].contiguous()
+
+
+def pplx_cutlass_moe(
+    pgi: ProcessGroupInfo,
+    dp_size: int,
+    a: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    w1_scale: torch.Tensor,
+    w2_scale: torch.Tensor,
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    a1_scale: torch.Tensor,
+    out_dtype,
+    per_act_token: bool,
+    per_out_ch: bool,
+):
+    from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
+        PplxPrepareAndFinalize)
+    assert torch.cuda.current_device() == pgi.local_rank
+
+    num_tokens, hidden_dim = a.shape
+    num_experts = w1.shape[0]
+    block_size = hidden_dim  # TODO support more cases
+    device = pgi.device
+    rank = pgi.rank
+    world_size = pgi.world_size
+    rank_num_tokens = rank_chunk(num_tokens, rank, world_size)
+    max_num_tokens = rank_chunk(num_tokens, 0, world_size)
+    topk = topk_ids.shape[1]
+
+    if block_size == hidden_dim:
+        scale_elems = 4  # hack to circumvent pplx data format requirements
+    else:
+        scale_elems = (hidden_dim + block_size - 1) // block_size
+
+    ata = AllToAll.internode(
+        max_num_tokens=max_num_tokens,
+        num_experts=num_experts,
+        experts_per_token=topk,
+        rank=rank,
+        world_size=pgi.world_size,
+        dp_size=dp_size,
+        hidden_dim=hidden_dim,
+        hidden_dim_bytes=hidden_dim,  # because a.dtype.itemsize == 1
+        hidden_dim_scale_bytes=scale_elems * torch.float32.itemsize,
+    )
+
+    w1 = w1.to(device)
+    w2 = w2.to(device)
+    w1_scale = w1_scale.to(device)
+    w2_scale = w2_scale.to(device)
+    a1_scale = a1_scale.to(device)
+
+    prepare_finalize = PplxPrepareAndFinalize(
+        ata,
+        max_num_tokens,
+        pgi.world_size,
+        rank,
+        dp_size,
+        quant_dtype=torch.float8_e4m3fn,
+        per_act_token=per_act_token,
+    )
+
+    experts = CutlassExpertsFp8((num_experts + world_size - 1) // world_size,
+                                out_dtype, per_act_token, per_out_ch)
+
+    fused_cutlass_experts = FusedMoEModularKernel(
+        prepare_finalize,
+        experts,
+    )
+
+    a_chunk = chunk_by_rank(a, rank, world_size).to(device)
+    chunk_topk_weight = chunk_by_rank(topk_weights, rank,
+                                      world_size).to(device)
+    chunk_topk_ids = chunk_by_rank(topk_ids, rank,
+                                   world_size).to(torch.uint32).to(device)
+
+    out = fused_cutlass_experts(
+        a_chunk,
+        chunk_by_rank(w1, rank, world_size),
+        chunk_by_rank(w2, rank, world_size),
+        chunk_topk_weight,
+        chunk_topk_ids,
+        global_num_experts=num_experts,
+        expert_map=None,  #TODO
+        w1_scale=chunk_by_rank(w1_scale, rank, world_size),
+        w2_scale=chunk_by_rank(w2_scale, rank, world_size),
+        a1_scale=chunk_by_rank(a1_scale, rank, world_size)
+        if per_act_token else a1_scale[rank])
+
+    torch.cuda.synchronize()
+
+    ata.destroy()
+
+    return out[:rank_num_tokens]
+
+
+vllm_config = VllmConfig()
+vllm_config.scheduler_config.max_num_seqs = 128
+vllm_config.scheduler_config.max_model_len = 8192
+
+
+def torch_moe2(a, w1, w2, topk_weight, topk_ids):
+    M, K = a.shape
+    topk = topk_ids.shape[1]
+    a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K)
+    out = torch.zeros(M * topk, w2.shape[1], dtype=a.dtype, device=a.device)
+    num_experts = w1.shape[0]
+    for i in range(num_experts):
+        mask = (topk_ids == i).view(-1)
+        if mask.sum():
+            out[mask] = SiluAndMul()(
+                a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1)
+
+    return (out.view(M, -1, w2.shape[1]) *
+            topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1)
+
+
+def _pplx_moe(
+    pgi: ProcessGroupInfo,
+    dp_size: int,
+    a: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    w1_scale: torch.Tensor,
+    w2_scale: torch.Tensor,
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    a1_scale: torch.Tensor,
+    out_dtype,
+    a_full: torch.Tensor,
+    w1_full: torch.Tensor,
+    w2_full: torch.Tensor,
+    per_act_token: bool,
+    per_out_ch: bool,
+):
+    uid = nvshmem_get_unique_id(
+    ) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id()
+    torch.distributed.broadcast(uid, src=0)
+    nvshmem_init(uid, pgi.rank, pgi.world_size)
+
+    with set_current_vllm_config(vllm_config):
+        torch_output = torch_moe2(a_full, w1_full, w2_full, topk_weights,
+                                  topk_ids)
+        pplx_output = pplx_cutlass_moe(pgi, dp_size, a, w1, w2, w1_scale,
+                                       w2_scale, topk_weights, topk_ids,
+                                       a1_scale, out_dtype, per_act_token,
+                                       per_out_ch)
+
+        torch_output = chunk_by_rank(torch_output, pgi.rank,
+                                     pgi.world_size).to(pplx_output.device)
+
+    # Uncomment if more debugging is needed
+    # print("PPLX OUT:", pplx_output)
+    # print("TORCH OUT:", torch_output)
+
+    torch.testing.assert_close(pplx_output, torch_output, atol=0.05, rtol=0)
+
+    nvshmem_finalize()
+
+
+@pytest.mark.parametrize("m", [2, 224])
+@pytest.mark.parametrize("n", [3072])
+@pytest.mark.parametrize("k", [1536])
+@pytest.mark.parametrize("e", NUM_EXPERTS)
+@pytest.mark.parametrize("topk", TOP_KS)
+@pytest.mark.parametrize("per_act_token", [True, False])
+@pytest.mark.parametrize("per_out_ch", [True, False])
+@pytest.mark.parametrize("world_dp_size", [[2, 1]])  #, [4, 2]])
+@pytest.mark.skipif(
+    (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
+        current_platform.get_device_capability()),
+    reason="Grouped gemm is not supported on this GPU type.")
+@requires_pplx
+def test_cutlass_moe_pplx(
+    m: int,
+    n: int,
+    k: int,
+    e: int,
+    topk: int,
+    per_act_token: bool,
+    per_out_ch: bool,
+    world_dp_size: tuple[int, int],
+):
+    current_platform.seed_everything(7)
+
+    with set_current_vllm_config(vllm_config):
+
+        dtype = torch.half
+
+        a = torch.randn((m, k), device="cuda", dtype=dtype) / 10.0
+        w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10.0
+        w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10.0
+
+        n_b_scales = 2 * n if per_out_ch else 1
+        k_b_scales = k if per_out_ch else 1
+
+        w1_q = torch.empty((e, 2 * n, k),
+                           device="cuda",
+                           dtype=torch.float8_e4m3fn)
+        w2_q = torch.empty((e, k, n), device="cuda", dtype=torch.float8_e4m3fn)
+        w1_scale = torch.empty((e, n_b_scales, 1),
+                               device="cuda",
+                               dtype=torch.float32)
+        w2_scale = torch.empty((e, k_b_scales, 1),
+                               device="cuda",
+                               dtype=torch.float32)
+
+        for expert in range(e):
+            w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(
+                w1[expert], use_per_token_if_dynamic=per_out_ch)
+            w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant(
+                w2[expert], use_per_token_if_dynamic=per_out_ch)
+
+        w1_d = torch.empty_like(w1)
+        w2_d = torch.empty_like(w2)
+        for expert in range(e):
+            w1_d[expert] = (w1_q[expert].float() * w1_scale[expert]).half()
+            w2_d[expert] = (w2_q[expert].float() * w2_scale[expert]).half()
+
+        score = torch.randn((m, e), device="cuda", dtype=dtype)
+        topk_weights, topk_ids, _ = fused_topk(a,
+                                               score,
+                                               topk,
+                                               renormalize=False)
+
+        world_size, dp_size = world_dp_size
+        a_scale1 = torch.randn(
+            (m if per_act_token else 1, 1), device="cuda",
+            dtype=torch.float32) / 10.0
+        if not per_act_token:
+            a_scale1 = a_scale1.repeat(world_size, 1)
+
+        parallel_launch(world_size, _pplx_moe, dp_size, a, w1_q, w2_q,
+                        w1_scale, w2_scale, topk_weights, topk_ids, a_scale1,
+                        dtype, a, w1_d, w2_d, per_act_token, per_out_ch)
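The test above is invoked with `pytest tests/kernels/moe/test_pplx_cutlass_moe.py` and needs at least two GPUs plus the pplx_kernels package, per its skip conditions. Its chunk_by_rank helper splits rows across ranks as in this small worked example (values verified against the helper's arithmetic):

    import torch

    t = torch.arange(10).unsqueeze(1)   # 10 rows, world size 3
    # rank_chunk(10, r, 3) gives 4, 3, 3 rows for ranks 0, 1, 2
    # chunk_by_rank(t, 0, 3) -> rows 0..3
    # chunk_by_rank(t, 1, 3) -> rows 4..6
    # chunk_by_rank(t, 2, 3) -> rows 7..9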
@@ -4,10 +4,7 @@

 Run `pytest tests/kernels/test_pplx_moe.py`.
 """
-import dataclasses
-import os
-import traceback
-from typing import Callable, Optional
+from typing import Optional

 import pytest
 import torch
@@ -21,10 +18,7 @@ try:
 except ImportError:
     has_pplx = False

-from torch.multiprocessing import (
-    spawn)  # pyright: ignore[reportPrivateImportUsage]
-from typing_extensions import Concatenate, ParamSpec
+from tests.pplx_utils import ProcessGroupInfo, parallel_launch

 from vllm.config import VllmConfig, set_current_vllm_config
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe import override_config
@@ -36,6 +30,11 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
     FusedMoEModularKernel)
 from vllm.platforms import current_platform

+requires_pplx = pytest.mark.skipif(
+    not has_pplx,
+    reason="Requires PPLX kernels",
+)
+
 PPLX_PREPARE_COMBOS = [(4, 128, 128), (32, 1024, 512), (64, 1024, 512),
                        (222, 2048, 1024)]

@@ -57,122 +56,6 @@ vllm_config = VllmConfig()
 vllm_config.scheduler_config.max_num_seqs = 128
 vllm_config.scheduler_config.max_model_len = 8192

-P = ParamSpec("P")
-
-requires_pplx = pytest.mark.skipif(
-    not has_pplx,
-    reason="Requires PPLX kernels",
-)
-
-
-@dataclasses.dataclass
-class ProcessGroupInfo:
-    world_size: int
-    world_local_size: int
-    rank: int
-    node_rank: int
-    local_rank: int
-    device: torch.device
-
-
-def _worker_parallel_launch(
-    local_rank: int,
-    world_size: int,
-    world_local_size: int,
-    node_rank: int,
-    init_method: str,
-    worker: Callable[Concatenate[ProcessGroupInfo, P], None],
-    *args: P.args,
-    **kwargs: P.kwargs,
-) -> None:
-    rank = node_rank * world_local_size + local_rank
-    torch.cuda.set_device(local_rank)
-    device = torch.device("cuda", local_rank)
-    torch.distributed.init_process_group(
-        backend="cpu:gloo,cuda:nccl",
-        init_method=init_method,
-        rank=rank,
-        world_size=world_size,
-        device_id=device,
-    )
-    barrier = torch.tensor([rank], device=device)
-    torch.distributed.all_reduce(barrier)
-
-    try:
-        worker(
-            ProcessGroupInfo(
-                world_size=world_size,
-                world_local_size=world_local_size,
-                rank=rank,
-                node_rank=node_rank,
-                local_rank=local_rank,
-                device=device,
-            ),
-            *args,
-            **kwargs,
-        )
-    except Exception as ex:
-        print(ex)
-        traceback.print_exc()
-        raise
-    finally:
-        torch.distributed.destroy_process_group()
-
-
-def parallel_launch(
-    world_size: int,
-    worker: Callable[Concatenate[ProcessGroupInfo, P], None],
-    *args: P.args,
-    **kwargs: P.kwargs,
-) -> None:
-    assert not kwargs
-    spawn(
-        _worker_parallel_launch,
-        args=(
-            world_size,
-            world_size,
-            0,
-            "tcp://localhost:29500",
-            worker,
-        ) + args,
-        nprocs=world_size,
-        join=True,
-    )
-
-
-def parallel_launch_from_env(
-    worker: Callable[Concatenate[ProcessGroupInfo, P], None],
-    *args: P.args,
-    **kwargs: P.kwargs,
-) -> None:
-    """
-    Launches a worker function in parallel across all processes in the current
-    environment. The environment must have the following variables set:
-    - WORLD_SIZE: The total number of processes.
-    - WORLD_LOCAL_SIZE: The number of processes on the current node.
-    - NODE_RANK: The rank of the current
-    - MASTER_ADDR: The address of the master process.
-    - MASTER_PORT: The port of the master process.
-    """
-    assert not kwargs
-    world_size = int(os.environ["WORLD_SIZE"])
-    world_local_size = int(os.environ["WORLD_LOCAL_SIZE"])
-    node_rank = int(os.environ["NODE_RANK"])
-    assert "MASTER_ADDR" in os.environ
-    assert "MASTER_PORT" in os.environ
-    spawn(
-        _worker_parallel_launch,
-        args=(
-            world_size,
-            world_local_size,
-            node_rank,
-            "env://",
-            worker,
-        ) + args,
-        nprocs=world_local_size,
-        join=True,
-    )
-
-
 def torch_prepare(
     a: torch.Tensor,
@@ -632,7 +632,8 @@ def test_cutlass_fp8_group_gemm(num_experts: int, per_act_token: bool,
     ops.cutlass_moe_mm(out_tensors_stacked, a_tensors_stacked,
                        b_tensors_stacked, a_scales_tensors_stacked,
                        b_scales_tensors_stacked, expert_offsets[:-1],
-                       problem_sizes, ab_strides, ab_strides, c_strides)
+                       problem_sizes, ab_strides, ab_strides, c_strides,
+                       per_act_token, per_out_ch)

     # Validate each group's result against the baseline
     for g in range(num_experts):
tests/pplx_utils.py (new file, 123 lines)

@@ -0,0 +1,123 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import dataclasses
+import os
+import traceback
+from typing import Callable
+
+import torch
+from torch.multiprocessing import (
+    spawn)  # pyright: ignore[reportPrivateImportUsage]
+from typing_extensions import Concatenate, ParamSpec
+
+P = ParamSpec("P")
+
+
+@dataclasses.dataclass
+class ProcessGroupInfo:
+    world_size: int
+    world_local_size: int
+    rank: int
+    node_rank: int
+    local_rank: int
+    device: torch.device
+
+
+def _worker_parallel_launch(
+    local_rank: int,
+    world_size: int,
+    world_local_size: int,
+    node_rank: int,
+    init_method: str,
+    worker: Callable[Concatenate[ProcessGroupInfo, P], None],
+    *args: P.args,
+    **kwargs: P.kwargs,
+) -> None:
+    rank = node_rank * world_local_size + local_rank
+    torch.cuda.set_device(local_rank)
+    device = torch.device("cuda", local_rank)
+    torch.distributed.init_process_group(
+        backend="cpu:gloo,cuda:nccl",
+        init_method=init_method,
+        rank=rank,
+        world_size=world_size,
+        device_id=device,
+    )
+    barrier = torch.tensor([rank], device=device)
+    torch.distributed.all_reduce(barrier)
+
+    try:
+        worker(
+            ProcessGroupInfo(
+                world_size=world_size,
+                world_local_size=world_local_size,
+                rank=rank,
+                node_rank=node_rank,
+                local_rank=local_rank,
+                device=device,
+            ),
+            *args,
+            **kwargs,
+        )
+    except Exception as ex:
+        print(ex)
+        traceback.print_exc()
+        raise
+    finally:
+        torch.distributed.destroy_process_group()
+
+
+def parallel_launch(
+    world_size: int,
+    worker: Callable[Concatenate[ProcessGroupInfo, P], None],
+    *args: P.args,
+    **kwargs: P.kwargs,
+) -> None:
+    assert not kwargs
+    spawn(
+        _worker_parallel_launch,
+        args=(
+            world_size,
+            world_size,
+            0,
+            "tcp://localhost:29500",
+            worker,
+        ) + args,
+        nprocs=world_size,
+        join=True,
+    )
+
+
+def parallel_launch_from_env(
+    worker: Callable[Concatenate[ProcessGroupInfo, P], None],
+    *args: P.args,
+    **kwargs: P.kwargs,
+) -> None:
+    """
+    Launches a worker function in parallel across all processes in the current
+    environment. The environment must have the following variables set:
+    - WORLD_SIZE: The total number of processes.
+    - WORLD_LOCAL_SIZE: The number of processes on the current node.
+    - NODE_RANK: The rank of the current
+    - MASTER_ADDR: The address of the master process.
+    - MASTER_PORT: The port of the master process.
+    """
+    assert not kwargs
+    world_size = int(os.environ["WORLD_SIZE"])
+    world_local_size = int(os.environ["WORLD_LOCAL_SIZE"])
+    node_rank = int(os.environ["NODE_RANK"])
+    assert "MASTER_ADDR" in os.environ
+    assert "MASTER_PORT" in os.environ
+    spawn(
+        _worker_parallel_launch,
+        args=(
+            world_size,
+            world_local_size,
+            node_rank,
+            "env://",
+            worker,
+        ) + args,
+        nprocs=world_local_size,
+        join=True,
+    )
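A minimal usage sketch for the helpers in tests/pplx_utils.py (the worker body is made up for illustration; parallel_launch's signature comes from the file above and requires one GPU per spawned process):

    import torch
    from tests.pplx_utils import ProcessGroupInfo, parallel_launch

    def _worker(pgi: ProcessGroupInfo, message: str) -> None:
        # Runs once per spawned process with the gloo/nccl process group set up.
        print(f"rank {pgi.rank}/{pgi.world_size} on {pgi.device}: {message}")
        torch.distributed.barrier()

    if __name__ == "__main__":
        parallel_launch(2, _worker, "hello")  # spawns 2 processes on 2 local GPUs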
@@ -899,11 +899,36 @@ def shuffle_rows(input_tensor: torch.Tensor, dst2src_map: torch.Tensor):
     return output_tensor


+def get_cutlass_pplx_moe_mm_data(expert_offsets: torch.Tensor,
+                                 problem_sizes1: torch.Tensor,
+                                 problem_sizes2: torch.Tensor,
+                                 expert_num_tokens: torch.Tensor,
+                                 num_local_experts: int, padded_m: int, n: int,
+                                 k: int):
+    """
+    Prepare data necessary to perform CUTLASS grouped matrix multiplications
+    used in CUTLASS-based fused MoE.
+
+    The function takes in expert_num_tokens (token count per expert) and
+    non_zero_expert_idxs (consecutive indices of experts with non-zero token
+    counts) and uses them to compute:
+    - expert_offsets: Indices that mark at which token index each expert begins
+      its computation.
+    - problem_sizes1, problem_sizes2: MxNxK sizes of each expert's
+      multiplication in two grouped MMs used in
+      the fused MoE operation.
+    """
+    return torch.ops._C.get_cutlass_pplx_moe_mm_data(
+        expert_offsets, problem_sizes1, problem_sizes2, expert_num_tokens,
+        num_local_experts, padded_m, n, k)
+
+
 def cutlass_moe_mm(out_tensors: torch.Tensor, a_tensors: torch.Tensor,
                    b_tensors: torch.Tensor, a_scales: torch.Tensor,
                    b_scales: torch.Tensor, expert_offsets: torch.Tensor,
                    problem_sizes: torch.Tensor, a_strides: torch.Tensor,
-                   b_strides: torch.Tensor, c_strides: torch.Tensor):
+                   b_strides: torch.Tensor, c_strides: torch.Tensor,
+                   per_act_token: bool, per_out_ch: bool):
     """
     A single grouped matrix multiplication used in CUTLASS-based fused MoE.
     The function executes fp8-quantized OUT = AB matrix multiplication.
@@ -918,7 +943,7 @@ def cutlass_moe_mm(out_tensors: torch.Tensor, a_tensors: torch.Tensor,
     return torch.ops._C.cutlass_moe_mm(out_tensors, a_tensors, b_tensors,
                                        a_scales, b_scales, expert_offsets,
                                        problem_sizes, a_strides, b_strides,
-                                       c_strides)
+                                       c_strides, per_act_token, per_out_ch)


 def cutlass_fp4_moe_mm(a_tensors: torch.Tensor, b_tensors: torch.Tensor,
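Putting the two wrappers together, a PPLX-style caller first builds the per-expert problem descriptions and then runs the grouped GEMM with the explicit quantization flags. The helper below is illustrative only (the real orchestration lives in run_cutlass_moe_fp8 / CutlassExpertsFp8 elsewhere in this commit):

    import torch
    from vllm import _custom_ops as ops

    def pplx_first_grouped_mm(c1: torch.Tensor, a_q: torch.Tensor, w1_q: torch.Tensor,
                              a_scales: torch.Tensor, w1_scales: torch.Tensor,
                              expert_num_tokens: torch.Tensor,
                              ab_strides1: torch.Tensor, c_strides1: torch.Tensor,
                              num_local_experts: int, padded_m: int, n: int, k: int,
                              per_act_token: bool, per_out_ch: bool) -> None:
        # Hypothetical helper, not part of vLLM: prepare PPLX MoE MM data, then
        # run the first grouped GEMM with the new per_act_token/per_out_ch flags.
        device = a_q.device
        expert_offsets = torch.empty(num_local_experts, dtype=torch.int32, device=device)
        problem_sizes1 = torch.empty((num_local_experts, 3), dtype=torch.int32, device=device)
        problem_sizes2 = torch.empty((num_local_experts, 3), dtype=torch.int32, device=device)
        ops.get_cutlass_pplx_moe_mm_data(expert_offsets, problem_sizes1, problem_sizes2,
                                         expert_num_tokens, num_local_experts,
                                         padded_m, n, k)
        ops.cutlass_moe_mm(c1, a_q, w1_q, a_scales, w1_scales, expert_offsets,
                           problem_sizes1, ab_strides1, ab_strides1, c_strides1,
                           per_act_token, per_out_ch)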
@@ -39,6 +39,7 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
     def workspace_shapes(
         self,
         a: torch.Tensor,
+        aq: torch.Tensor,
         M: int,
         N: int,
         K: int,

@@ -67,6 +67,7 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
     def workspace_shapes(
         self,
         a: torch.Tensor,
+        aq: torch.Tensor,
         M: int,
         N: int,
         K: int,
@@ -78,11 +79,11 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
         # even if we fall back to triton later, e.g. if expert maps are set.
         if self.allow_deep_gemm and self.batched_deep_gemm_experts is not None:
             return self.batched_deep_gemm_experts.workspace_shapes(
-                a, M, N, K, topk, num_experts)
+                a, aq, M, N, K, topk, num_experts)
         else:
             assert self.batched_triton_experts is not None
             return self.batched_triton_experts.workspace_shapes(
-                a, M, N, K, topk, num_experts)
+                a, aq, M, N, K, topk, num_experts)

     def apply(
         self,
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """ CUTLASS based Fused MoE kernels."""
-from typing import Optional
+from typing import Callable, Optional

 import torch

@@ -13,56 +13,24 @@ from vllm.model_executor.layers.fused_moe.utils import _fp8_perm, _resize_cache
 from vllm.scalar_type import scalar_types


-class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
-
-    def __init__(
-        self,
-        ab_strides1: torch.Tensor,
-        c_strides1: torch.Tensor,
-        ab_strides2: torch.Tensor,
-        c_strides2: torch.Tensor,
-        out_dtype: torch.dtype,
-    ):
-        super().__init__()
-        self.ab_strides1 = ab_strides1
-        self.c_strides1 = c_strides1
-        self.ab_strides2 = ab_strides2
-        self.c_strides2 = c_strides2
-        self.out_dtype = out_dtype
-
-    def workspace_shapes(
-        self,
-        a: torch.Tensor,
-        M: int,
-        N: int,
-        K: int,
-        topk: int,
-        num_experts: int,
-    ) -> tuple[int, int, torch.dtype]:
-        # Note that K, N are transposed
-        N, K = K, N
-        workspace1 = M * topk * max(2 * N, K)
-        workspace2 = M * topk * N
-        return (workspace1, workspace2, self.out_dtype)
-
-    def apply(
-        self,
-        hidden_states: torch.Tensor,
-        w1: torch.Tensor,
-        w2: torch.Tensor,
-        topk_ids: torch.Tensor,
-        activation: str,
-        global_num_experts: int,
-        expert_map: Optional[torch.Tensor],
-        w1_scale: Optional[torch.Tensor],
-        w2_scale: Optional[torch.Tensor],
-        w1_zp: Optional[torch.Tensor],
-        w2_zp: Optional[torch.Tensor],
-        a1q_scale: Optional[torch.Tensor],
-        a2_scale: Optional[torch.Tensor],
-        workspace13: torch.Tensor,
-        workspace2: torch.Tensor,
-        expert_num_tokens: Optional[torch.Tensor],
-    ) -> torch.Tensor:
+def run_cutlass_moe_fp8(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    topk_ids: torch.Tensor,
+    activation_callable: Callable,
+    global_num_experts: int,
+    expert_map: Optional[torch.Tensor],
+    w1_scale: Optional[torch.Tensor],
+    w2_scale: Optional[torch.Tensor],
+    a1q_scale: Optional[torch.Tensor],
+    a2_scale: Optional[torch.Tensor],
+    workspace13: torch.Tensor,
+    workspace2: torch.Tensor,
+    expert_num_tokens: Optional[torch.Tensor],
+    out_dtype: torch.dtype,
+    per_act_token: bool,
+    per_out_ch: bool,
+) -> torch.Tensor:
     a1q = hidden_states

@@ -70,38 +38,50 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
     assert w2_scale is not None
     assert w1.dtype == torch.float8_e4m3fn
     assert w2.dtype == torch.float8_e4m3fn
-    assert a1q.shape[1] == w1.shape[1], "Hidden size mismatch w1"
-    assert w1.shape[2] == w2.shape[1] * 2, "Hidden size mismatch w2"
+    if expert_num_tokens is None:
+        assert a1q.shape[1] == w1.shape[2], "Hidden size mismatch w1"
+    else:
+        assert a1q.shape[2] == w1.shape[2], "Hidden size mismatch w1"
+    assert w1.shape[1] == w2.shape[2] * 2, "Hidden size mismatch w2"
+    assert w1_scale.dim() == 1 or w1_scale.shape[1] == 1 or w1_scale.shape[
+        1] == w1.shape[1], "W1 scale shape mismatch"
+    assert w2_scale.dim() == 1 or w2_scale.shape[1] == 1 or w2_scale.shape[
+        1] == w2.shape[1], "W2 scale shape mismatch"
     assert w1.shape[0] == w2.shape[0], "Expert number mismatch"
     assert a1q_scale is None or a1q_scale.dim(
     ) == 0 or a1q_scale.shape[0] == 1 or a1q_scale.shape[0] == a1q.shape[
         0], "Input scale shape mismatch"
-    assert w1_scale.dim() == 1 or w1_scale.shape[1] == 1 or w1_scale.shape[
-        1] == w1.shape[2], "W1 scale shape mismatch"
-    assert w2_scale.dim() == 1 or w2_scale.shape[1] == 1 or w2_scale.shape[
-        1] == w2.shape[2], "W2 scale shape mismatch"
     assert w1.shape[0] == w2.shape[0], "Weights expert number mismatch"
-    assert w1.shape[0] == w1_scale.shape[
-        0], "w1 scales expert number mismatch"
-    assert w1.shape[0] == w2_scale.shape[
-        0], "w2 scales expert number mismatch"
-    assert a2_scale is None or a1q_scale is None or a2_scale.shape == a1q_scale.shape, "Intermediate scale shape mismatch"  # noqa: E501
-    assert self.ab_strides1.shape[0] == w1.shape[
-        0], "AB Strides 1 expert number mismatch"
-    assert self.c_strides1.shape[0] == w1.shape[
-        0], "C Strides 1 expert number mismatch"
-    assert self.ab_strides2.shape[0] == w2.shape[
-        0], "AB Strides 2 expert number mismatch"
-    assert self.c_strides2.shape[0] == w2.shape[
-        0], "C Strides 2 expert number mismatch"
-    assert self.out_dtype in [torch.half,
-                              torch.bfloat16], "Invalid output dtype"
+    assert w1.shape[0] == w1_scale.shape[0], "w1 scales expert number mismatch"
+    assert w1.shape[0] == w2_scale.shape[0], "w2 scales expert number mismatch"
+    assert a2_scale is None or a2_scale.dim(
+    ) == 0 or a2_scale.shape[0] == 1 or a2_scale.shape[0] == a1q.shape[
+        0], "Intermediate scale shape mismatch"
+    assert out_dtype in [torch.half, torch.bfloat16], "Invalid output dtype"
+    if expert_map is not None:
+        assert expert_num_tokens is None

-    M = a1q.shape[0]
-    _, N, K = w2.shape  # because w1 + w2 are transposed
+    # We have two modes: PPLX and non-PPLX. We differentiate them by checking
+    # if expert_num_tokens is None (expert_num_tokens is a tensor which PPLX
+    # uses to track the number of tokens per expert).
+    # In the non-PPLX mode, the input tokens are not padded: thus, the shape
+    # of the input is [total_num_tokens, hidden_size]. The input and output
+    # require shuffling by a_map and c_map such that the tokens assigned to
+    # each expert are contiguous.
+    # In the PPLX mode, the input tokens are padded per expert to ensure that
+    # the PPLX dispatch and combine functions work correctly: thus, the shape
+    # of the input is [num_experts, max_num_tokens_per_expert, hidden_size].
+    # The PPLX input and output require no shuffling by a_map and c_map since
+    # their tokens are already contiguous for each expert as a result of
+    # the dispatch function.
+    is_pplx = expert_num_tokens is not None
+
+    M = a1q.shape[0]  # no pplx
+    padded_M = a1q.shape[1]  # pplx
+    _, K, N = w2.shape
     device = a1q.device

-    assert w1.shape[1] == K
+    assert w1.shape[2] == K
     assert global_num_experts != -1
     assert a1q_scale is not None

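An illustrative restatement of the two input layouts described in the comment block above; the tensors and sizes here are made up and only show the shapes involved:

import torch

hidden = 128

# Non-PPLX: flat [total_num_tokens, hidden]; tokens are later shuffled by
# a_map / c_map so each expert's rows become contiguous.
a_non_pplx = torch.randn(10, hidden)
expert_num_tokens = None  # is_pplx == False

# PPLX: dispatch already grouped and padded tokens per local expert, so the
# input is [num_local_experts, max_num_tokens_per_expert, hidden] and the real
# per-expert token counts ride along in expert_num_tokens.
a_pplx = torch.randn(4, 8, hidden)
expert_num_tokens = torch.tensor([5, 8, 2, 0], dtype=torch.int32)  # is_pplx == True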
@@ -113,10 +93,29 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
         local_topk_ids = topk_ids

     topk = local_topk_ids.shape[1]
+    local_E = w1.shape[0]

-    per_act_token = a1q_scale.numel() != 1 if a1q_scale is not None else (
-        a2_scale.numel() != 1 if a2_scale is not None else False)
+    if is_pplx:
+        expert_offsets = torch.empty((local_E),
+                                     dtype=torch.int32,
+                                     device=device)
+        problem_sizes1 = torch.empty((local_E, 3),
+                                     dtype=torch.int32,
+                                     device=device)
+        problem_sizes2 = torch.empty((local_E, 3),
+                                     dtype=torch.int32,
+                                     device=device)
+
+        ops.get_cutlass_pplx_moe_mm_data(expert_offsets, problem_sizes1,
+                                         problem_sizes2, expert_num_tokens,
+                                         local_E, padded_M, N, K)
+
+        w1_scale = w1_scale.reshape(w1_scale.shape[0], -1)
+        w2_scale = w2_scale.reshape(w2_scale.shape[0], -1)
+        a1q = a1q.reshape(-1, a1q.shape[2])
+        a1q_scale = a1q_scale.reshape(-1, a1q_scale.shape[2]).contiguous()
+
+    else:
         expert_offsets = torch.empty((global_num_experts + 1),
                                      dtype=torch.int32,
                                      device=device)
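For the non-PPLX branch, expert_offsets is sized global_num_experts + 1 and later consumed as expert_offsets[:-1], which is consistent with an exclusive prefix sum of per-expert token counts. A plain-PyTorch sketch of that bookkeeping (the commit computes it on the GPU with a custom op; treating it as a prefix sum here is an assumption for illustration only):

import torch

num_experts = 4
topk_ids = torch.tensor([[0, 2], [1, 2], [2, 3]])  # hypothetical routing

counts = torch.bincount(topk_ids.flatten(), minlength=num_experts)
expert_offsets = torch.zeros(num_experts + 1, dtype=torch.int32)
expert_offsets[1:] = torch.cumsum(counts, dim=0)
# expert_offsets[:-1] then gives the first row of each expert's block in the
# permuted activation tensor.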
@@ -149,16 +148,39 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):

         a1q = _fp8_perm(a1q, a_map)
         a1q_scale = a1q_scale[a_map] if per_act_token else a1q_scale
+        expert_offsets = expert_offsets[:-1]
+
+    ab_strides1 = torch.full((w1.shape[0], ),
+                             K,
+                             device=device,
+                             dtype=torch.int64)
+    c_strides1 = torch.full((w1.shape[0], ),
+                            2 * N,
+                            device=device,
+                            dtype=torch.int64)
+    ab_strides2 = torch.full((w1.shape[0], ),
+                             N,
+                             device=device,
+                             dtype=torch.int64)
+    c_strides2 = torch.full((w1.shape[0], ),
+                            K,
+                            device=device,
+                            dtype=torch.int64)
+
+    if is_pplx:
+        c1 = _resize_cache(workspace13, (local_E * padded_M, N * 2))
+        c2 = _resize_cache(workspace2, (local_E * padded_M, N))
+        c3 = _resize_cache(workspace13, (local_E * padded_M, K))
+    else:
         c1 = _resize_cache(workspace13, (M * topk, N * 2))
         c2 = _resize_cache(workspace2, (M * topk, N))
         c3 = _resize_cache(workspace13, (M * topk, K))

-    ops.cutlass_moe_mm(c1, a1q, w1, a1q_scale, w1_scale,
-                       expert_offsets[:-1], problem_sizes1,
-                       self.ab_strides1, self.ab_strides1, self.c_strides1)
+    ops.cutlass_moe_mm(c1, a1q, w1, a1q_scale, w1_scale, expert_offsets,
+                       problem_sizes1, ab_strides1, ab_strides1, c_strides1,
+                       per_act_token, per_out_ch)

-    self.activation(activation, c2, c1)
+    activation_callable(c2, c1)

     a2q, a2q_scale = ops.scaled_fp8_quant(
         c2, a2_scale, use_per_token_if_dynamic=per_act_token)
@@ -166,33 +188,90 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
     if expert_map is not None:
         c3.fill_(0)

-    ops.cutlass_moe_mm(c3, a2q, w2, a2q_scale, w2_scale,
-                       expert_offsets[:-1], problem_sizes2,
-                       self.ab_strides2, self.ab_strides2, self.c_strides2)
+    ops.cutlass_moe_mm(c3, a2q, w2, a2q_scale, w2_scale, expert_offsets,
+                       problem_sizes2, ab_strides2, ab_strides2, c_strides2,
+                       per_act_token, per_out_ch)

-    c3 = c3[c_map]
-    return c3
+    if is_pplx:
+        return c3.reshape(local_E, padded_M, K)
+    else:
+        return c3[c_map].view(M, topk, K)
+
+
+class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
+
+    def __init__(
+        self,
+        max_experts_per_worker: int,
+        out_dtype: torch.dtype,
+        per_act_token: bool,
+        per_out_ch: bool,
+    ):
+        super().__init__()
+        self.max_experts_per_worker = max_experts_per_worker
+        self.out_dtype = out_dtype
+        self.per_act_token = per_act_token
+        self.per_out_ch = per_out_ch
+
+    def workspace_shapes(
+        self,
+        a: torch.Tensor,
+        aq: torch.Tensor,
+        M: int,
+        N: int,
+        K: int,
+        topk: int,
+        num_experts: int,
+    ) -> tuple[int, int, torch.dtype]:
+        padded_M = aq.shape[1]
+        workspace1 = self.max_experts_per_worker * padded_M * max(N, K)
+        workspace2 = self.max_experts_per_worker * padded_M * (N // 2)
+        return (workspace1, workspace2, self.out_dtype)
+
+    def apply(
+        self,
+        hidden_states: torch.Tensor,
+        w1: torch.Tensor,
+        w2: torch.Tensor,
+        topk_ids: torch.Tensor,
+        activation: str,
+        global_num_experts: int,
+        expert_map: Optional[torch.Tensor],
+        w1_scale: Optional[torch.Tensor],
+        w2_scale: Optional[torch.Tensor],
+        w1_zp: Optional[torch.Tensor],
+        w2_zp: Optional[torch.Tensor],
+        a1q_scale: Optional[torch.Tensor],
+        a2_scale: Optional[torch.Tensor],
+        workspace13: torch.Tensor,
+        workspace2: torch.Tensor,
+        expert_num_tokens: Optional[torch.Tensor],
+    ) -> torch.Tensor:
+        assert w1_zp is None, "w1_zp is not supported in CUTLASS MoE"
+        assert w2_zp is None, "w2_zp is not supported in CUTLASS MoE"
+        activation_callable = lambda i, o: self.activation(activation, i, o)
+        return run_cutlass_moe_fp8(hidden_states, w1, w2, topk_ids,
+                                   activation_callable, global_num_experts,
+                                   expert_map, w1_scale, w2_scale, a1q_scale,
+                                   a2_scale, workspace13, workspace2,
+                                   expert_num_tokens, self.out_dtype,
+                                   self.per_act_token, self.per_out_ch)


-#TODO make the grouped gemm kernel consistent with scaled gemm kernel
 def cutlass_moe_fp8(
     a: torch.Tensor,
     w1_q: torch.Tensor,
     w2_q: torch.Tensor,
-    w1_scale: torch.Tensor,
-    w2_scale: torch.Tensor,
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
-    ab_strides1: torch.Tensor,
-    c_strides1: torch.Tensor,
-    ab_strides2: torch.Tensor,
-    c_strides2: torch.Tensor,
+    w1_scale: torch.Tensor,
+    w2_scale: torch.Tensor,
+    activation: str = "silu",
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
-    out_dtype: torch.dtype = torch.half,
     expert_map: Optional[torch.Tensor] = None,
     apply_router_weight_on_input: bool = False,
+    global_num_experts: int = -1,
 ) -> torch.Tensor:
     """
     This function computes a a8w8-quantized Mixture of Experts (MoE) layer
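A worked example of the workspace sizing used by the new CutlassExpertsFp8.workspace_shapes above. It assumes, as elsewhere in the modular-kernel interface, that N is the first GEMM's output width (twice the intermediate size) and K is the hidden size; the concrete numbers are hypothetical:

max_experts_per_worker = 8
padded_M = 64        # aq.shape[1], the per-expert padded token count
N, K = 1024, 512     # N = 2 * intermediate_size, K = hidden_size

workspace1 = max_experts_per_worker * padded_M * max(N, K)  # 8 * 64 * 1024 elements
workspace2 = max_experts_per_worker * padded_M * (N // 2)   # 8 * 64 * 512 elements
# workspace13 backs c1 and c3; workspace2 backs c2 (see run_cutlass_moe_fp8 above).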
@@ -207,25 +286,17 @@ def cutlass_moe_fp8(
     Shape: [num_experts, K, 2N] (the weights are passed transposed)
     - w2_q (torch.Tensor): The second set of fp8-quantized expert weights.
     Shape: [num_experts, N, K] (the weights are passed transposed)
+    - topk_weights (torch.Tensor): The weights of each token->expert mapping.
+    - topk_ids (torch.Tensor): The token->expert mappings.
     - w1_scale (torch.Tensor): The fp32 scale to dequantize w1_q.
     Shape: [num_experts] or [num_experts, 2N]
     - w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
     Shape: [num_experts] or [num_experts, K]
-    - gating_output (torch.Tensor): The output of the gating operation
-    (before softmax).
-    - topk_weights (torch.Tensor): The weights of each token->expert mapping.
-    - ab_strides1 (torch.Tensor): The input and weights strides of the first
-    grouped gemm.
-    - c_strides1 (torch.Tensor): The output strides of the first grouped gemm.
-    - ab_strides2 (torch.Tensor): The input and weights strides of the second
-    grouped gemm.
-    - c_strides2 (torch.Tensor): The output strides of the second grouped gemm.
     - a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a.
     Shape: scalar or [M]
     - a2_scale (Optional[torch.Tensor]): The optional fp32 scale to
     quantize the intermediate result between the gemms.
     Shape: scalar or [M]
-    - out_dtype (torch.dtype): The output tensor type.
     - expert_map (Optional[torch.Tensor]): In the case of Expert parallel,
     every Rank is responsible for a subset of experts. expert_map is a
     mapping from global expert-id to local expert-id. When expert_map[i]
@@ -233,24 +304,27 @@ def cutlass_moe_fp8(
     expert-id i.
     - apply_router_weight_on_input (bool): When true, the topk weights are
     applied directly on the inputs. This is only applicable when topk is 1.
+    - global_num_experts (int): The total number of experts.

     Returns:
     - torch.Tensor: The fp16 output tensor after applying the MoE layer.
     """
     per_act_token = a1_scale.numel() != 1 if a1_scale is not None else (
         a2_scale.numel() != 1 if a2_scale is not None else False)
+    per_out_ch = w1_scale.numel() != w1_q.shape[0]
+
+    out_dtype = a.dtype

     fn = mk.FusedMoEModularKernel(
         MoEPrepareAndFinalizeNoEP(
-            per_channel_quant=per_act_token,
             quant_dtype=torch.float8_e4m3fn,
+            per_channel_quant=per_act_token,
         ),
         CutlassExpertsFp8(
-            ab_strides1,
-            c_strides1,
-            ab_strides2,
-            c_strides2,
-            out_dtype,
+            max_experts_per_worker=global_num_experts,
+            out_dtype=out_dtype,
+            per_act_token=per_act_token,
+            per_out_ch=per_out_ch,
        ),
     )

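The two granularity flags added above fall straight out of the scale shapes; a tiny self-contained example mirroring that logic (the tensors themselves are hypothetical):

import torch

num_experts, m, two_n = 8, 16, 1024
a1_scale = torch.rand(m, 1)                # per-token scales  -> numel != 1
w1_scale = torch.rand(num_experts, two_n)  # per-channel scales -> numel != num_experts

per_act_token = a1_scale.numel() != 1
per_out_ch = w1_scale.numel() != num_experts
assert per_act_token and per_out_ch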
@@ -260,9 +334,12 @@ def cutlass_moe_fp8(
         w2_q,
         topk_weights,
         topk_ids,
-        expert_map=expert_map,
-        w1_scale=w1_scale,
-        w2_scale=w2_scale,
+        False,
+        activation,
+        global_num_experts if global_num_experts != -1 else w1_q.size(0),
+        expert_map,
+        w1_scale,
+        w2_scale,
         a1_scale=a1_scale,
         a2_scale=a2_scale,
         apply_router_weight_on_input=apply_router_weight_on_input,
@@ -73,6 +73,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
     def workspace_shapes(
         self,
         a: torch.Tensor,
+        aq: torch.Tensor,
         M: int,
         N: int,
         K: int,
@@ -521,6 +521,7 @@ class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
     def workspace_shapes(
         self,
         a: torch.Tensor,
+        aq: torch.Tensor,
         M: int,
         N: int,
         K: int,
@@ -632,6 +633,7 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
     def workspace_shapes(
         self,
         a: torch.Tensor,
+        aq: torch.Tensor,
         M: int,
         N: int,
         K: int,
@@ -1545,6 +1545,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
     def workspace_shapes(
         self,
         a: torch.Tensor,
+        aq: torch.Tensor,
         M: int,
         N: int,
         K: int,
@@ -9,6 +9,9 @@ from typing import Callable, Optional, Union

 import torch
 import torch.nn.functional as F
+from compressed_tensors.quantization import (QuantizationArgs,
+                                             QuantizationStrategy,
+                                             QuantizationType)
 from torch.nn.parameter import UninitializedParameter

 import vllm.envs as envs
@@ -210,6 +213,7 @@ class MoEConfig:
     moe_parallel_config: FusedMoEParallelConfig

     in_dtype: torch.dtype  # The activation type.
+    quant_dtype: torch.dtype = None

     # TODO: add more quantization params, blocked, per-token, etc.
     block_size: int = 128
@@ -264,8 +268,22 @@ class FusedMoeWeightScaleSupported(Enum):
     BLOCK = "block"


+def get_quant_config_input_activations(
+        quant_config: Optional[QuantizationConfig]
+) -> Optional[QuantizationArgs]:
+    if (quant_config is not None and hasattr(quant_config, 'target_scheme_map')
+            and "Linear" in quant_config.target_scheme_map and
+            "input_activations" in quant_config.target_scheme_map["Linear"]):
+        return quant_config.target_scheme_map["Linear"].get(
+            "input_activations")
+    else:
+        return None
+
+
 class FusedMoEMethodBase(QuantizeMethodBase):

+    moe: MoEConfig
+
     @abstractmethod
     def create_weights(self, layer: torch.nn.Module, num_experts: int,
                        hidden_size: int, intermediate_size_per_partition: int,
@@ -277,6 +295,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
         all2all_manager = get_ep_group().device_communicator.all2all_manager
         assert all2all_manager is not None

+        self.moe = moe
         quant_dtype = None
         act_quant_block_size = None
         from vllm.model_executor.layers.quantization.fp8 import Fp8Config
@@ -297,12 +316,13 @@ class FusedMoEMethodBase(QuantizeMethodBase):
                 # dp_size actually means tp_size, bug in pplx kernels
                 dp_size=all2all_manager.tp_group.world_size,
                 hidden_dim=moe.hidden_dim,
-                hidden_dim_bytes=moe.hidden_dim * moe.in_dtype.itemsize,
+                hidden_dim_bytes=moe.hidden_dim * moe.quant_dtype.itemsize,
                 # For blocked per token: set to
                 # ceil_div(hidden_dim, block_size) * sizeof(float32)
                 # For per-token: set to sizeof(float32)
-                hidden_dim_scale_bytes=(0 if moe.in_dtype.itemsize != 1 else (
-                    (moe.hidden_dim + moe.block_size - 1) // moe.block_size *
+                hidden_dim_scale_bytes=(
+                    0 if moe.quant_dtype.itemsize != 1 else
+                    ((moe.hidden_dim + moe.block_size - 1) // moe.block_size *
                      torch.float32.itemsize)),
             )

@@ -313,6 +333,9 @@ class FusedMoEMethodBase(QuantizeMethodBase):

             handle = all2all_manager.get_handle(all_to_all_args)

+            input_activations = get_quant_config_input_activations(
+                quant_config)
+
             prepare_finalize = PplxPrepareAndFinalize(
                 handle,
                 max_num_tokens=moe.max_num_tokens,
@@ -320,7 +343,10 @@ class FusedMoEMethodBase(QuantizeMethodBase):
                 rank=all2all_manager.rank,
                 # dp_size actually means tp_size, bug in pplx kernels
                 dp_size=all2all_manager.tp_group.world_size,
-                quant_dtype=moe.in_dtype,
+                quant_dtype=moe.quant_dtype,
+                per_act_token=(input_activations.strategy
+                               == QuantizationStrategy.TOKEN
+                               if input_activations is not None else False),
             )
         elif moe.use_deepep_ht_kernels:
             assert moe.dp_size == all2all_manager.dp_world_size
@@ -365,15 +391,15 @@ class FusedMoEMethodBase(QuantizeMethodBase):
         self.topk_indices_dtype = None
         if prepare_finalize is not None:
             self.topk_indices_dtype = prepare_finalize.topk_indices_dtype()
-            experts = self.select_gemm_impl(prepare_finalize)
+            experts = self.select_gemm_impl(prepare_finalize, moe)
             self.fused_experts = FusedMoEModularKernel(
                 prepare_finalize,
                 experts,
             )

     def select_gemm_impl(
-            self, prepare_finalize: FusedMoEPrepareAndFinalize
-    ) -> FusedMoEPermuteExpertsUnpermute:
+            self, prepare_finalize: FusedMoEPrepareAndFinalize,
+            moe: Optional[MoEConfig]) -> FusedMoEPermuteExpertsUnpermute:
         # based on the all2all implementation, select the appropriate
         # gemm implementation
         raise NotImplementedError(
@@ -419,7 +445,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         else:
             self.rocm_aiter_fused_experts = None  # type: ignore

-    def select_gemm_impl(self, prepare_finalize: FusedMoEPrepareAndFinalize):
+    def select_gemm_impl(self, prepare_finalize: FusedMoEPrepareAndFinalize,
+                         moe: Optional[MoEConfig]):

         assert self.fused_experts == fused_experts

@@ -809,7 +836,6 @@ class FusedMoE(torch.nn.Module):
         activation: str = "silu",
     ):
         super().__init__()
-
         if params_dtype is None:
             params_dtype = torch.get_default_dtype()
         self.params_dtype = params_dtype
@@ -869,14 +895,24 @@ class FusedMoE(torch.nn.Module):
             from vllm_hpu_extension.ops import DynamicFusedMOE
             self.hpu_fused_moe = DynamicFusedMOE(self.global_num_experts)

+        # Only support float8 for now.
+        quant_dtype = params_dtype
+        if quant_config is not None:
+            input_activations = get_quant_config_input_activations(
+                quant_config)
+            if (input_activations is not None
+                    and input_activations.num_bits == 8
+                    and input_activations.type == QuantizationType.FLOAT):
+                quant_dtype = torch.float8_e4m3fn
+
         moe = MoEConfig(
             num_experts=self.global_num_experts,
             experts_per_token=top_k,
             hidden_dim=hidden_size,
             num_local_experts=self.local_num_experts,
             moe_parallel_config=self.moe_parallel_config,
-            # TODO (bnell): this needs to be fixed for quantized types.
             in_dtype=params_dtype,
+            quant_dtype=quant_dtype,
             max_num_tokens=MOE_DP_CHUNK_SIZE,
         )
         self.moe_config = moe
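What the new get_quant_config_input_activations helper looks for, exercised against a stand-in config object (SimpleNamespace here is only a mock; a real vLLM QuantizationConfig is not constructed this way):

from types import SimpleNamespace

input_args = SimpleNamespace(num_bits=8, type="float", strategy="token")
mock_quant_config = SimpleNamespace(
    target_scheme_map={"Linear": {"input_activations": input_args}})

# Equivalent to get_quant_config_input_activations(mock_quant_config):
scheme = None
if (hasattr(mock_quant_config, "target_scheme_map")
        and "Linear" in mock_quant_config.target_scheme_map
        and "input_activations" in mock_quant_config.target_scheme_map["Linear"]):
    scheme = mock_quant_config.target_scheme_map["Linear"].get("input_activations")

# FusedMoE then uses such a scheme (8-bit float input activations) to pick
# torch.float8_e4m3fn as MoEConfig.quant_dtype.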
@@ -175,6 +175,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
     def workspace_shapes(
         self,
         a: torch.Tensor,
+        aq: torch.Tensor,
         M: int,
         N: int,
         K: int,
@@ -309,7 +310,7 @@ class FusedMoEModularKernel(torch.nn.Module):

         # Use a1 here to decipher the correct workspace datatype
         workspace13_shape, workspace2_shape, workspace_dtype = (
-            self.fused_experts.workspace_shapes(a1, M, N, K, top_k,
+            self.fused_experts.workspace_shapes(a1, a1q, M, N, K, top_k,
                                                 global_num_experts))

         # We can reuse the memory between cache1 and cache3 because by the time
@@ -21,7 +21,8 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
                  rank: int,
                  dp_size: int,
                  quant_dtype: Optional[torch.dtype] = None,
-                 block_shape: Optional[list[int]] = None):
+                 block_shape: Optional[list[int]] = None,
+                 per_act_token: bool = False):
         super().__init__()
         assert max_num_tokens > 0
         self.a2a = a2a
@@ -31,6 +32,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
         self.rank = rank
         self.dp_size = dp_size
         self.quant_dtype = quant_dtype
+        self.per_act_token = per_act_token

     def max_num_tokens_per_rank(self) -> Optional[int]:
         return self.max_num_tokens
@@ -66,13 +68,14 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
                 "apply_router_weight_on_input is only implemented for topk=1")
             a1 = a1 * rank_topk_weights.to(a1.dtype)

-        per_act_token = a1_scale.numel() != 1 if a1_scale is not None else (
-            a2_scale.numel() != 1 if a2_scale is not None else False)
-
-        a1q, a1q_scale = moe_kernel_quantize_input(a1, a1_scale,
-                                                   self.quant_dtype,
-                                                   per_act_token,
-                                                   self.block_shape)
+        repeat_cols = 4
+        repeat_rows = 1 if self.per_act_token else a1.shape[0]
+        a1q, a1q_scale = moe_kernel_quantize_input(
+            a1, (None if self.per_act_token else a1_scale), self.quant_dtype,
+            self.per_act_token, self.block_shape)
+
+        if a1q_scale is not None:
+            a1q_scale = a1q_scale.repeat(repeat_rows, repeat_cols)

         # rem_experts need to be 0 for pplx to work properly.
         rem_experts = num_experts % self.world_size
@@ -100,7 +103,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
                        else 1) * float32_size
             expert_x_scale = torch.empty(
                 (
-                    num_experts,
+                    num_local_experts,
                     expert_x.size(1),
                     (expert_x.size(2) + block_size - 1) // block_size,
                 ),
@@ -121,6 +124,8 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
             indices=rank_topk_ids,
             bound_m=bound_m,
         )
+        if expert_x_scale is not None:
+            expert_x_scale = expert_x_scale[:, :, 0:1]

         return expert_x, expert_x_scale, expert_num_tokens, None, None

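A shape sketch for the scale replication added to prepare() above: per-tensor scales are broadcast to one row per token, and in both cases the column is replicated to four entries to match the repeat_cols constant; why exactly four is not explained in the diff, so no reason is assumed here. Sizes are made up:

import torch

num_tokens = 6

# Per-tensor quantization: a single scale becomes one row per token.
scale = torch.rand(1, 1)
scale = scale.repeat(num_tokens, 4)   # -> shape [6, 4]

# Per-token quantization: one scale per token, columns still replicated.
scale = torch.rand(num_tokens, 1)
scale = scale.repeat(1, 4)            # -> shape [6, 4]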
@@ -37,6 +37,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
     def workspace_shapes(
         self,
         a: torch.Tensor,
+        aq: torch.Tensor,
         M: int,
         N: int,
         K: int,
@@ -49,9 +50,9 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
         if self.allow_deep_gemm and _valid_deep_gemm_shape(M, N, K):
             assert self.deep_gemm_expert is not None
             return self.deep_gemm_expert.workspace_shapes(
-                a, M, N, K, topk, num_experts)
+                a, aq, M, N, K, topk, num_experts)
         else:
-            return self.triton_expert.workspace_shapes(a, M, N, K, topk,
+            return self.triton_expert.workspace_shapes(a, aq, M, N, K, topk,
                                                        num_experts)

     def apply(
@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 import enum
+import importlib
 from enum import Enum
 from typing import Callable, Optional

@@ -11,7 +12,6 @@ from compressed_tensors.quantization import (ActivationOrdering,
                                              QuantizationStrategy)

 import vllm.envs as envs
-import vllm.model_executor.layers.fused_moe  # noqa
 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
@@ -30,6 +30,15 @@ from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
 from vllm.scalar_type import scalar_types

+has_pplx = importlib.util.find_spec("pplx_kernels") is not None
+
+if current_platform.is_cuda_alike():
+    from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
+        BatchedPrepareAndFinalize)
+    if has_pplx:
+        from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
+            PplxPrepareAndFinalize)
+
 logger = init_logger(__name__)


@@ -77,8 +86,7 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
             else:
                 logger.info_once("Using CompressedTensorsWNA16MarlinMoEMethod")
                 return CompressedTensorsWNA16MarlinMoEMethod(quant_config)
-        elif (quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant)
-              and layer.activation == "silu"):
+        elif quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant):
             return CompressedTensorsW8A8Fp8MoECutlassMethod(quant_config)
         elif quant_config._is_fp8_w8a8(weight_quant, input_quant):
             return CompressedTensorsW8A8Fp8MoEMethod(quant_config)
@@ -421,6 +429,11 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
                 "For FP8 Fused MoE layer, we require either per tensor or "
                 "channelwise, dynamic per token quantization.")

+        from vllm.model_executor.layers.fused_moe.cutlass_moe import (
+            cutlass_moe_fp8)
+        self.fused_experts = cutlass_moe_fp8  # type: ignore
+        self.disable_expert_map = False
+
     def create_weights(self, layer: torch.nn.Module, num_experts: int,
                        hidden_size: int, intermediate_size_per_partition: int,
                        params_dtype: torch.dtype, **extra_weight_attrs):
@@ -499,25 +512,6 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
             layer.w13_input_scale = None
             layer.w2_input_scale = None

-        device = w13_weight.device
-        # TODO strides can be shared across multiple layers
-        self.ab_strides1 = torch.full((num_experts, ),
-                                      hidden_size,
-                                      device=device,
-                                      dtype=torch.int64)
-        self.c_strides1 = torch.full((num_experts, ),
-                                     2 * intermediate_size_per_partition,
-                                     device=device,
-                                     dtype=torch.int64)
-        self.ab_strides2 = torch.full((num_experts, ),
-                                      intermediate_size_per_partition,
-                                      device=device,
-                                      dtype=torch.int64)
-        self.c_strides2 = torch.full((num_experts, ),
-                                     hidden_size,
-                                     device=device,
-                                     dtype=torch.int64)
-
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         # Fp8 moe kernels require a single activation scale.
         # We take the max of all the scales in case they differ.
@@ -558,6 +552,27 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
         layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
                                                     requires_grad=False)

+    def select_gemm_impl(self, prepare_finalize, moe):
+        from vllm.model_executor.layers.fused_moe.cutlass_moe import (
+            CutlassExpertsFp8)
+
+        assert moe is not None
+
+        max_experts_per_worker = (
+            (moe.num_experts + prepare_finalize.world_size - 1) //
+            prepare_finalize.world_size)
+        experts = CutlassExpertsFp8(
+            max_experts_per_worker, moe.in_dtype,
+            self.input_quant.strategy == QuantizationStrategy.TOKEN,
+            self.weight_quant.strategy == QuantizationStrategy.CHANNEL)
+
+        if has_pplx and isinstance(
+                prepare_finalize,
+                (BatchedPrepareAndFinalize, PplxPrepareAndFinalize)):
+            # no expert_map support in this case
+            self.disable_expert_map = True
+        return experts
+
     def apply(
         self,
         layer: torch.nn.Module,
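Worked example of the ceil division used above to size CutlassExpertsFp8 per expert-parallel worker (numbers are hypothetical):

num_experts, world_size = 60, 8
max_experts_per_worker = (num_experts + world_size - 1) // world_size
assert max_experts_per_worker == 8   # 60 experts over 8 ranks: at most 8 per rank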
@@ -577,9 +592,6 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
         activation: str = "silu",
     ) -> torch.Tensor:

-        assert activation == "silu", (
-            f"{activation} not supported for Cutlass MoE.")
-
         topk_weights, topk_ids = FusedMoE.select_experts(
             hidden_states=x,
             router_logits=router_logits,
@@ -590,27 +602,22 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
             scoring_func=scoring_func,
-            e_score_correction_bias=e_score_correction_bias)
+            e_score_correction_bias=e_score_correction_bias,
+            indices_type=torch.uint32)

-        from vllm.model_executor.layers.fused_moe import cutlass_moe_fp8
-
-        return cutlass_moe_fp8(
+        return self.fused_experts(
             x,
-            layer.w13_weight.transpose(1, 2),
-            layer.w2_weight.transpose(1, 2),
-            layer.w13_weight_scale,
-            layer.w2_weight_scale,
+            layer.w13_weight,
+            layer.w2_weight,
             topk_weights,
             topk_ids,
-            self.ab_strides1,
-            self.c_strides1,
-            self.ab_strides2,
-            self.c_strides2,
+            activation=activation,
+            global_num_experts=global_num_experts,
+            expert_map=None if self.disable_expert_map else expert_map,
+            w1_scale=layer.w13_weight_scale,
+            w2_scale=layer.w2_weight_scale,
             a1_scale=layer.w13_input_scale,
             a2_scale=layer.w2_input_scale,
-            out_dtype=x.dtype,
-            expert_map=expert_map,
-            apply_router_weight_on_input=apply_router_weight_on_input,
         )


@@ -769,7 +769,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         del layer.w13_input_scale
         del layer.w2_input_scale

-    def select_gemm_impl(self, prepare_finalize):
+    def select_gemm_impl(self, prepare_finalize, moe):

         from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import (  # noqa: E501
             BatchedTritonOrDeepGemmExperts)