[CPU] Refactor CPU fused MOE (#30531)

Signed-off-by: jiang1.li <jiang1.li@intel.com>
Li, Jiang 2025-12-18 14:36:49 +08:00 committed by GitHub
parent fc2ae6d617
commit e3ab93c896
23 changed files with 1388 additions and 200 deletions


@@ -50,6 +50,7 @@ function cpu_tests() {
docker exec cpu-test-"$NUMA_NODE" bash -c "
set -e
pytest -x -v -s tests/kernels/attention/test_cpu_attn.py
+pytest -x -v -s tests/kernels/moe/test_cpu_fused_moe.py
pytest -x -v -s tests/kernels/test_onednn.py"
# Run basic model test


@@ -330,7 +330,7 @@ if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR (ASIMD_FOUND AND NOT APPLE_SILICON
PUBLIC ${oneDNN_BINARY_DIR}/include
PRIVATE ${oneDNN_SOURCE_DIR}/src
)
-target_link_libraries(dnnl_ext dnnl)
+target_link_libraries(dnnl_ext dnnl torch)
target_compile_options(dnnl_ext PRIVATE ${CXX_COMPILE_FLAGS} -fPIC)
list(APPEND LIBS dnnl_ext)
set(USE_ONEDNN ON)
@@ -358,13 +358,13 @@ set(VLLM_EXT_SRC
"csrc/cpu/pos_encoding.cpp"
"csrc/moe/dynamic_4bit_int_moe_cpu.cpp"
"csrc/cpu/cpu_attn.cpp"
-"csrc/cpu/scratchpad_manager.cpp"
"csrc/cpu/torch_bindings.cpp")
if (AVX512_FOUND AND NOT AVX512_DISABLED)
set(VLLM_EXT_SRC
"csrc/cpu/shm.cpp"
"csrc/cpu/cpu_wna16.cpp"
+"csrc/cpu/cpu_fused_moe.cpp"
${VLLM_EXT_SRC})
if (ENABLE_AVX512BF16 AND ENABLE_AVX512VNNI)
set(VLLM_EXT_SRC


@@ -1,5 +1,5 @@
-#ifndef CPU_ATTN_MACROS_H
-#define CPU_ATTN_MACROS_H
+#ifndef CPU_ARCH_MACROS_H
+#define CPU_ARCH_MACROS_H
// x86_64
#ifdef __x86_64__
@@ -26,7 +26,7 @@
_mm512_castsi512_ps(_mm512_set1_epi32(0x42b17218)); \
const __m512i vec_127 = _mm512_set1_epi32(0x0000007f); \
const int n_mantissa_bits = 23; \
-auto fast_exp = [&](vec_op::FP32Vec16& vec) __attribute__(( \
+auto fast_exp = [&](const vec_op::FP32Vec16& vec) __attribute__(( \
always_inline)) { \
__m512 values = vec.reg; \
auto less_ln_flt_min_mask = \
@@ -98,7 +98,7 @@
poly = vbslq_f32(hi_mask, inf, poly); \
return vbslq_f32(lo_mask, zero, poly); \
}; \
-auto fast_exp = [&](vec_op::FP32Vec16& vec) \
+auto fast_exp = [&](const vec_op::FP32Vec16& vec) \
__attribute__((always_inline)) { \
float32x4x4_t result; \
result.val[0] = neon_expf(vec.reg.val[0]); \
@@ -110,4 +110,4 @@
#endif // __aarch64__
#endif


@@ -8,10 +8,8 @@
#include <sys/sysctl.h>
#endif
-#include "cpu_types.hpp"
-#include "scratchpad_manager.h"
-#include "cpu_attn_macros.h"
-#include "utils.hpp"
+#include "cpu/cpu_arch_macros.h"
+#include "cpu/utils.hpp"
namespace cpu_attention {
enum class ISA { AMX, VEC, VEC16, NEON };
@@ -378,12 +376,13 @@ class AttentionScheduler {
static constexpr int32_t MaxQTileIterNum = 128;
-AttentionScheduler() : available_cache_size_(get_available_l2_size()) {}
+AttentionScheduler()
+    : available_cache_size_(cpu_utils::get_available_l2_size()) {}
torch::Tensor schedule(const ScheduleInput& input) const {
const bool casual = input.casual;
const int32_t thread_num = omp_get_max_threads();
-const int64_t cache_size = get_available_l2_size();
+const int64_t cache_size = cpu_utils::get_available_l2_size();
const int32_t max_num_q_per_iter = input.max_num_q_per_iter;
const int32_t kv_len_alignment = input.kv_block_alignment;
int32_t q_head_per_kv = input.num_heads_q / input.num_heads_kv;
@@ -659,7 +658,7 @@ class AttentionScheduler {
metadata_ptr->thread_num +
metadata_ptr->reduction_scratchpad_size_per_kv_head *
(use_gqa ? input.num_heads_kv : input.num_heads_q);
-DNNLScratchPadManager::get_dnnl_scratchpad_manager()->realloc(
+cpu_utils::ScratchPadManager::get_scratchpad_manager()->realloc(
scratchpad_size);
// metadata_ptr->print();
@@ -667,7 +666,7 @@ class AttentionScheduler {
// test out of boundary access
// {
// float* cache_ptr =
-// DNNLScratchPadManager::get_dnnl_scratchpad_manager()->get_data<float>();
+// cpu_utils::ScratchPadManager::getl_scratchpad_manager()->get_data<float>();
// for (int64_t i = 0; i < scratchpad_size / sizeof(float); ++i) {
// cache_ptr[i] = std::numeric_limits<float>::quiet_NaN();
// }
@@ -749,27 +748,6 @@ class AttentionScheduler {
return std::max(rounded_tile_size, round_size);
}
-static int64_t get_available_l2_size() {
-static int64_t size = []() {
-#if defined(__APPLE__)
-// macOS doesn't have _SC_LEVEL2_CACHE_SIZE. Use sysctlbyname.
-int64_t l2_cache_size = 0;
-size_t len = sizeof(l2_cache_size);
-if (sysctlbyname("hw.l2cachesize", &l2_cache_size, &len, NULL, 0) == 0 &&
-l2_cache_size > 0) {
-return l2_cache_size >> 1; // use 50% of L2 cache
-}
-// Fallback if sysctlbyname fails
-return 128LL * 1024 >> 1; // use 50% of 128KB
-#else
-long l2_cache_size = sysconf(_SC_LEVEL2_CACHE_SIZE);
-TORCH_CHECK_NE(l2_cache_size, -1);
-return l2_cache_size >> 1; // use 50% of L2 cache
-#endif
-}();
-return size;
-}
private:
int64_t available_cache_size_;
};
@@ -1402,7 +1380,7 @@ class AttentionMainLoop {
// init buffers
void* scratchpad_ptr =
-DNNLScratchPadManager::get_dnnl_scratchpad_manager()
+cpu_utils::ScratchPadManager::get_scratchpad_manager()
->get_data<void>();
AttentionScratchPad buffer_manager(thread_id, metadata, scratchpad_ptr);
@@ -1422,8 +1400,7 @@
}
}
-const int64_t available_cache_size =
-    AttentionScheduler::get_available_l2_size();
+const int64_t available_cache_size = cpu_utils::get_available_l2_size();
const int32_t default_tile_size =
AttentionScheduler::calcu_default_tile_size(
available_cache_size, head_dim, sizeof(kv_cache_t),

csrc/cpu/cpu_fused_moe.cpp (new file, 727 lines)

@@ -0,0 +1,727 @@
#include "cpu/cpu_types.hpp"
#include "cpu/utils.hpp"
#include "cpu/micro_gemm/cpu_micro_gemm_vec.hpp"
#include "cpu/cpu_arch_macros.h"
#ifdef CPU_CAPABILITY_AMXBF16
#include "cpu/micro_gemm/cpu_micro_gemm_amx.hpp"
#define AMX_DISPATCH(...) \
case cpu_utils::ISA::AMX: { \
using gemm_t = cpu_micro_gemm::MicroGemm<cpu_utils::ISA::AMX, scalar_t>; \
return __VA_ARGS__(); \
}
#else
#define AMX_DISPATCH(...) case cpu_utils::ISA::AMX:
#endif
#define CPU_ISA_DISPATCH_IMPL(ISA_TYPE, ...) \
[&] { \
switch (ISA_TYPE) { \
AMX_DISPATCH(__VA_ARGS__) \
case cpu_utils::ISA::VEC: { \
using gemm_t = \
cpu_micro_gemm::MicroGemm<cpu_utils::ISA::VEC, scalar_t>; \
return __VA_ARGS__(); \
} \
default: { \
TORCH_CHECK(false, "Invalid CPU ISA type."); \
} \
} \
}()
namespace {
enum class FusedMOEAct { SiluAndMul, SwigluOAIAndMul };
FusedMOEAct get_act_type(const std::string& act) {
if (act == "silu") {
return FusedMOEAct::SiluAndMul;
} else if (act == "swigluoai") {
return FusedMOEAct::SwigluOAIAndMul;
} else {
TORCH_CHECK(false, "Invalid act type: " + act);
}
}
template <typename scalar_t>
void swigluoai_and_mul(float* __restrict__ input, scalar_t* __restrict__ output,
const int32_t m_size, const int32_t n_size,
const int32_t input_stride,
const int32_t output_stride) {
using scalar_vec_t = typename cpu_utils::VecTypeTrait<scalar_t>::vec_t;
// For GPT-OSS interleaved gate-up weights
alignas(64) static int32_t index[16] = {0, 2, 4, 6, 8, 10, 12, 14,
16, 18, 20, 22, 24, 26, 28, 30};
vec_op::INT32Vec16 index_vec(index);
vec_op::FP32Vec16 gate_up_max_vec(7.0);
vec_op::FP32Vec16 up_min_vec(-7.0);
vec_op::FP32Vec16 alpha_vec(1.702);
vec_op::FP32Vec16 one_vec(1.0);
DEFINE_FAST_EXP
for (int32_t m = 0; m < m_size; ++m) {
for (int32_t n = 0; n < n_size; n += 32) {
vec_op::FP32Vec16 gate_vec(input + n, index_vec);
vec_op::FP32Vec16 up_vec(input + n + 1, index_vec);
gate_vec = gate_vec.min(gate_up_max_vec);
up_vec = up_vec.clamp(up_min_vec, gate_up_max_vec);
auto sigmoid_vec = one_vec / (one_vec + fast_exp(-gate_vec * alpha_vec));
auto glu = gate_vec * sigmoid_vec;
auto gated_output_fp32 = (one_vec + up_vec) * glu;
scalar_vec_t gated_output = scalar_vec_t(gated_output_fp32);
gated_output.save(output + n / 2);
}
input += input_stride;
output += output_stride;
}
}
template <typename scalar_t>
void silu_and_mul(float* __restrict__ input, scalar_t* __restrict__ output,
const int32_t m_size, const int32_t n_size,
const int32_t input_stride, const int32_t output_stride) {
using scalar_vec_t = typename cpu_utils::VecTypeTrait<scalar_t>::vec_t;
const int32_t dim = n_size / 2;
float* __restrict__ gate = input;
float* __restrict__ up = input + dim;
vec_op::FP32Vec16 one_vec(1.0);
DEFINE_FAST_EXP
for (int32_t m = 0; m < m_size; ++m) {
for (int32_t n = 0; n < dim; n += 16) {
vec_op::FP32Vec16 gate_vec(gate + n);
vec_op::FP32Vec16 up_vec(up + n);
auto sigmoid_vec = one_vec / (one_vec + fast_exp(-gate_vec));
auto silu = gate_vec * sigmoid_vec;
auto gated_output_fp32 = up_vec * silu;
scalar_vec_t gated_output = scalar_vec_t(gated_output_fp32);
gated_output.save(output + n);
}
gate += input_stride;
up += input_stride;
output += output_stride;
}
}
template <typename scalar_t>
FORCE_INLINE void apply_gated_act(const FusedMOEAct act,
float* __restrict__ input,
scalar_t* __restrict__ output,
const int32_t m, const int32_t n,
const int32_t input_stride,
const int32_t output_stride) {
switch (act) {
case FusedMOEAct::SwigluOAIAndMul:
swigluoai_and_mul(input, output, m, n, input_stride, output_stride);
return;
case FusedMOEAct::SiluAndMul:
silu_and_mul(input, output, m, n, input_stride, output_stride);
return;
default:
TORCH_CHECK(false, "Unsupported act type.");
}
}
template <typename scalar_t, typename gemm_t>
void prepack_moe_weight_impl(scalar_t* __restrict__ weight_ptr,
scalar_t* __restrict__ packed_weight_ptr,
const int32_t expert_num,
const int32_t output_size,
const int32_t input_size,
const int64_t expert_stride) {
#pragma omp parallel for
for (int32_t e_idx = 0; e_idx < expert_num; ++e_idx) {
gemm_t::pack_weight(weight_ptr + expert_stride * e_idx,
packed_weight_ptr + expert_stride * e_idx, output_size,
input_size);
}
}
template <typename scalar_t, typename w_t, typename gemm_t>
void fused_moe_impl(scalar_t* __restrict__ output, scalar_t* __restrict__ input,
w_t* __restrict__ w13, w_t* __restrict__ w2,
w_t* __restrict__ w13_bias, w_t* __restrict__ w2_bias,
float* __restrict__ topk_weights,
int32_t* __restrict__ topk_id, FusedMOEAct act_type,
const int32_t token_num, const int32_t expert_num,
const int32_t topk_num, const int32_t input_size_13,
const int32_t output_size_13, const int32_t input_size_2,
const int32_t output_size_2) {
using scalar_vec_t = typename cpu_utils::VecTypeTrait<scalar_t>::vec_t;
constexpr int32_t gemm_n_tile_size = gemm_t::NSize;
constexpr int32_t gemm_m_tile_size = gemm_t::MaxMSize;
constexpr int32_t min_w13_n_tile_size = 2 * gemm_n_tile_size;
static_assert(gemm_n_tile_size % 16 == 0);
TORCH_CHECK_EQ(output_size_13 % min_w13_n_tile_size, 0);
TORCH_CHECK_EQ(output_size_2 % gemm_n_tile_size, 0);
TORCH_CHECK_EQ(output_size_13 / 2, input_size_2);
const int32_t thread_num = omp_get_max_threads();
const int32_t w13_input_buffer_size = cpu_utils::round_up<64>(
gemm_m_tile_size * input_size_13 * sizeof(scalar_t));
const int32_t w13_n_tile_size = [&]() {
const int64_t cache_size = cpu_utils::get_available_l2_size();
// input buffer + output buffer + weight
const int32_t n_size_cache_limit =
(cache_size - w13_input_buffer_size) /
(gemm_m_tile_size * sizeof(float) + input_size_13 * sizeof(scalar_t));
const int32_t n_size_thread_limit =
output_size_13 / std::max(1, thread_num / topk_num);
const int32_t n_size = cpu_utils::round_down<min_w13_n_tile_size>(
std::min(n_size_cache_limit, n_size_thread_limit));
return std::max(n_size, min_w13_n_tile_size);
}();
const int32_t w2_input_tile_size = cpu_utils::round_up<64>(
gemm_m_tile_size * input_size_2 * sizeof(scalar_t));
const int32_t w2_n_tile_size = [&]() {
const int64_t cache_size = cpu_utils::get_available_l2_size();
// input tile + weight
const int32_t n_size_cache_limit =
(cache_size - w2_input_tile_size) / (input_size_2 * sizeof(scalar_t));
const int32_t n_size_thread_limit =
output_size_2 / std::max(1, thread_num / topk_num);
const int32_t n_size = cpu_utils::round_down<gemm_n_tile_size>(
std::min(n_size_cache_limit, n_size_thread_limit));
return std::max(n_size, gemm_n_tile_size);
}();
// allocate buffers
int32_t common_buffer_offset = 0;
int32_t w13_thread_buffer_offset = 0;
int32_t ws_thread_buffer_offset = 0;
// common buffers
const int32_t token_num_per_group_buffer_size =
cpu_utils::round_up<64>(expert_num * sizeof(int32_t));
const int32_t token_num_per_group_buffer_offset = common_buffer_offset;
common_buffer_offset += token_num_per_group_buffer_size;
const int32_t cu_token_num_per_group_buffer_size =
cpu_utils::round_up<64>((expert_num + 1) * sizeof(int32_t));
const int32_t cu_token_num_per_group_buffer_offset = common_buffer_offset;
common_buffer_offset += cu_token_num_per_group_buffer_size;
const int32_t expand_token_id_buffer_size =
cpu_utils::round_up<64>(token_num * topk_num * sizeof(int32_t));
const int32_t expand_token_id_buffer_offset = common_buffer_offset;
common_buffer_offset += expand_token_id_buffer_size;
const int32_t expand_token_id_index_buffer_size =
cpu_utils::round_up<64>(token_num * topk_num * sizeof(int32_t));
const int32_t expand_token_id_index_buffer_offset = common_buffer_offset;
common_buffer_offset += expand_token_id_index_buffer_size;
const int32_t w13_gemm_output_buffer_size = cpu_utils::round_up<64>(
token_num * topk_num * (output_size_13 / 2) * sizeof(scalar_t));
const int32_t w13_gemm_output_buffer_offset = common_buffer_offset;
common_buffer_offset += w13_gemm_output_buffer_size;
const int32_t w2_gemm_output_buffer_size = cpu_utils::round_up<64>(
token_num * topk_num * output_size_2 * sizeof(float));
const int32_t w2_gemm_output_buffer_offset = common_buffer_offset;
common_buffer_offset += w2_gemm_output_buffer_size;
// w13 GEMM thread buffers
const int32_t w13_input_buffer_offset = w13_thread_buffer_offset;
w13_thread_buffer_offset += w13_input_buffer_size;
const int32_t w13_output_buffer_size = cpu_utils::round_up<64>(
gemm_m_tile_size * w13_n_tile_size * sizeof(float));
const int32_t w13_output_buffer_offset = w13_thread_buffer_offset;
w13_thread_buffer_offset += w13_output_buffer_size;
// Weighted sum thread buffer
const int32_t ws_output_buffer_size =
cpu_utils::round_up<64>(output_size_2 * sizeof(float));
const int32_t ws_output_buffer_offset = ws_thread_buffer_offset;
ws_thread_buffer_offset += ws_output_buffer_size;
const int32_t buffer_size =
common_buffer_offset +
std::max(w13_thread_buffer_offset, ws_thread_buffer_offset) * thread_num;
cpu_utils::ScratchPadManager::get_scratchpad_manager()->realloc(buffer_size);
uint8_t* common_buffer_start =
cpu_utils::ScratchPadManager::get_scratchpad_manager()
->get_data<uint8_t>();
uint8_t* thread_buffer_start = common_buffer_start + common_buffer_offset;
int32_t* __restrict__ token_num_per_group_buffer = reinterpret_cast<int32_t*>(
common_buffer_start + token_num_per_group_buffer_offset);
int32_t* __restrict__ cu_token_num_per_group_buffer =
reinterpret_cast<int32_t*>(common_buffer_start +
cu_token_num_per_group_buffer_offset);
int32_t* __restrict__ expand_token_id_buffer = reinterpret_cast<int32_t*>(
common_buffer_start + expand_token_id_buffer_offset);
int32_t* __restrict__ expand_token_id_index_buffer =
reinterpret_cast<int32_t*>(common_buffer_start +
expand_token_id_index_buffer_offset);
// prepare token-expert mappings
{
std::memset(token_num_per_group_buffer, 0, expert_num * sizeof(int32_t));
for (int32_t i = 0; i < token_num * topk_num; ++i) {
int32_t curr_expert_id = topk_id[i];
++token_num_per_group_buffer[curr_expert_id];
}
int32_t token_num_sum = 0;
cu_token_num_per_group_buffer[0] = 0;
int32_t* token_index_buffer = cu_token_num_per_group_buffer + 1;
for (int32_t i = 0; i < expert_num; ++i) {
token_index_buffer[i] = token_num_sum;
token_num_sum += token_num_per_group_buffer[i];
}
for (int32_t i = 0; i < token_num; ++i) {
int32_t* curr_topk_id = topk_id + i * topk_num;
int32_t* curr_index_buffer = expand_token_id_index_buffer + i * topk_num;
for (int32_t j = 0; j < topk_num; ++j) {
int32_t curr_expert_id = curr_topk_id[j];
int32_t curr_index = token_index_buffer[curr_expert_id];
++token_index_buffer[curr_expert_id];
expand_token_id_buffer[curr_index] = i;
curr_index_buffer[j] = curr_index;
}
}
}
// w13 GEMM + act
{
alignas(64) cpu_utils::Counter counter;
cpu_utils::Counter* counter_ptr = &counter;
#pragma omp parallel for schedule(static, 1)
for (int32_t thread_id = 0; thread_id < thread_num; ++thread_id) {
const int32_t task_num_per_expert =
(output_size_13 + w13_n_tile_size - 1) / w13_n_tile_size;
const int32_t task_num = task_num_per_expert * expert_num;
uint8_t* __restrict__ thread_buffer =
thread_buffer_start + thread_id * w13_thread_buffer_offset;
scalar_t* __restrict__ w13_input_buffer =
reinterpret_cast<scalar_t*>(thread_buffer + w13_input_buffer_offset);
float* __restrict__ w13_output_buffer =
reinterpret_cast<float*>(thread_buffer + w13_output_buffer_offset);
scalar_t* __restrict__ w13_gemm_output_buffer =
reinterpret_cast<scalar_t*>(common_buffer_start +
w13_gemm_output_buffer_offset);
gemm_t gemm;
const int32_t input_size_13_bytes = input_size_13 * sizeof(scalar_t);
const int32_t w13_n_group_stride = 16 * input_size_13;
const int32_t w13_n_tile_stride = gemm_n_tile_size * input_size_13;
for (;;) {
int32_t task_id = counter_ptr->acquire_counter();
if (task_id >= task_num) {
break;
}
const int32_t curr_expert_id = task_id / task_num_per_expert;
const int32_t curr_output_group_id = task_id % task_num_per_expert;
const int32_t curr_token_num =
token_num_per_group_buffer[curr_expert_id];
if (curr_token_num == 0) {
continue;
}
const int32_t actual_n_tile_size =
std::min(w13_n_tile_size,
output_size_13 - curr_output_group_id * w13_n_tile_size);
const int32_t* __restrict__ curr_expand_token_id_buffer =
expand_token_id_buffer +
cu_token_num_per_group_buffer[curr_expert_id];
scalar_t* __restrict__ curr_w13_gemm_output_buffer =
w13_gemm_output_buffer +
cu_token_num_per_group_buffer[curr_expert_id] *
(output_size_13 / 2) +
curr_output_group_id * w13_n_tile_size / 2;
w_t* __restrict__ w13_weight_ptr_0 = nullptr;
w_t* __restrict__ w13_weight_ptr_1 = nullptr;
w_t* __restrict__ w13_bias_ptr_0 = nullptr;
w_t* __restrict__ w13_bias_ptr_1 = nullptr;
if (act_type == FusedMOEAct::SwigluOAIAndMul) {
// For SwigluOAIAndMul, up and down weights are interleaved
w13_weight_ptr_0 =
w13 + curr_expert_id * input_size_13 * output_size_13 +
curr_output_group_id * w13_n_tile_size * input_size_13;
w13_weight_ptr_1 =
w13_weight_ptr_0 + actual_n_tile_size / 2 * input_size_13;
if (w13_bias != nullptr) {
w13_bias_ptr_0 = w13_bias + curr_expert_id * output_size_13 +
curr_output_group_id * w13_n_tile_size;
w13_bias_ptr_1 = w13_bias_ptr_0 + actual_n_tile_size / 2;
}
} else {
w13_weight_ptr_0 =
w13 + curr_expert_id * input_size_13 * output_size_13 +
curr_output_group_id * (w13_n_tile_size / 2) * input_size_13;
w13_weight_ptr_1 =
w13_weight_ptr_0 + output_size_13 / 2 * input_size_13;
if (w13_bias != nullptr) {
w13_bias_ptr_0 = w13_bias + curr_expert_id * output_size_13 +
curr_output_group_id * (w13_n_tile_size / 2);
w13_bias_ptr_1 = w13_bias_ptr_0 + output_size_13 / 2;
}
}
scalar_t* __restrict__ curr_w13_input_buffer = w13_input_buffer;
for (int32_t token_idx = 0; token_idx < curr_token_num;
token_idx += gemm_m_tile_size) {
const int32_t actual_token_num =
std::min(gemm_m_tile_size, curr_token_num - token_idx);
// copy inputs
{
scalar_t* __restrict__ curr_w13_input_buffer_iter =
curr_w13_input_buffer;
for (int32_t i = 0; i < actual_token_num; ++i) {
const int32_t curr_token_id = curr_expand_token_id_buffer[i];
int8_t* __restrict__ curr_input_iter = reinterpret_cast<int8_t*>(
input + curr_token_id * input_size_13);
int8_t* __restrict__ curr_output_iter =
reinterpret_cast<int8_t*>(curr_w13_input_buffer_iter);
int32_t j = 0;
for (; j < input_size_13_bytes - 64; j += 64) {
vec_op::INT8Vec64 vec(curr_input_iter);
vec.save(curr_output_iter);
curr_input_iter += 64;
curr_output_iter += 64;
}
vec_op::INT8Vec64 vec(curr_input_iter);
vec.save(curr_output_iter, input_size_13_bytes - j);
// update
curr_w13_input_buffer_iter += input_size_13;
}
// update
curr_expand_token_id_buffer += actual_token_num;
}
// gemm + act
{
scalar_t* __restrict__ w13_weight_ptr_0_iter = w13_weight_ptr_0;
scalar_t* __restrict__ w13_weight_ptr_1_iter = w13_weight_ptr_1;
scalar_t* __restrict__ w13_bias_ptr_0_iter = w13_bias_ptr_0;
scalar_t* __restrict__ w13_bias_ptr_1_iter = w13_bias_ptr_1;
scalar_t* __restrict__ curr_w13_input_buffer_iter =
curr_w13_input_buffer;
float* __restrict__ w13_output_buffer_0_iter = w13_output_buffer;
float* __restrict__ w13_output_buffer_1_iter =
w13_output_buffer + actual_n_tile_size / 2;
for (int32_t i = 0; i < actual_n_tile_size;
i += min_w13_n_tile_size) {
gemm.gemm(curr_w13_input_buffer_iter, w13_weight_ptr_0_iter,
w13_output_buffer_0_iter, actual_token_num,
input_size_13, input_size_13, w13_n_group_stride,
actual_n_tile_size, false);
if (w13_bias != nullptr) {
cpu_micro_gemm::add_bias_epilogue<gemm_n_tile_size>(
w13_output_buffer_0_iter, w13_output_buffer_0_iter,
w13_bias_ptr_0_iter, actual_token_num, actual_n_tile_size,
actual_n_tile_size);
w13_bias_ptr_0_iter += gemm_n_tile_size;
}
gemm.gemm(curr_w13_input_buffer_iter, w13_weight_ptr_1_iter,
w13_output_buffer_1_iter, actual_token_num,
input_size_13, input_size_13, w13_n_group_stride,
actual_n_tile_size, false);
if (w13_bias != nullptr) {
cpu_micro_gemm::add_bias_epilogue<gemm_n_tile_size>(
w13_output_buffer_1_iter, w13_output_buffer_1_iter,
w13_bias_ptr_1_iter, actual_token_num, actual_n_tile_size,
actual_n_tile_size);
w13_bias_ptr_1_iter += gemm_n_tile_size;
}
// update
w13_weight_ptr_0_iter += w13_n_tile_stride;
w13_weight_ptr_1_iter += w13_n_tile_stride;
w13_output_buffer_0_iter += gemm_n_tile_size;
w13_output_buffer_1_iter += gemm_n_tile_size;
}
apply_gated_act(act_type, w13_output_buffer,
curr_w13_gemm_output_buffer, actual_token_num,
actual_n_tile_size, actual_n_tile_size,
output_size_13 / 2);
// update
curr_w13_gemm_output_buffer +=
gemm_m_tile_size * (output_size_13 / 2);
}
}
}
}
}
// w2 GEMM
{
alignas(64) cpu_utils::Counter counter;
cpu_utils::Counter* counter_ptr = &counter;
#pragma omp parallel for schedule(static, 1)
for (int32_t thread_id = 0; thread_id < thread_num; ++thread_id) {
const int32_t task_num_per_expert =
(output_size_2 + w2_n_tile_size - 1) / w2_n_tile_size;
const int32_t task_num = task_num_per_expert * expert_num;
scalar_t* __restrict__ w13_gemm_output_buffer =
reinterpret_cast<scalar_t*>(common_buffer_start +
w13_gemm_output_buffer_offset);
float* __restrict__ w2_gemm_output_buffer = reinterpret_cast<float*>(
common_buffer_start + w2_gemm_output_buffer_offset);
gemm_t gemm;
const int32_t w2_n_tile_stride = gemm_n_tile_size * input_size_2;
const int32_t w2_n_group_stride = 16 * input_size_2;
for (;;) {
int32_t task_id = counter_ptr->acquire_counter();
if (task_id >= task_num) {
break;
}
const int32_t curr_expert_id = task_id / task_num_per_expert;
const int32_t curr_output_group_id = task_id % task_num_per_expert;
const int32_t curr_token_num =
token_num_per_group_buffer[curr_expert_id];
if (curr_token_num == 0) {
continue;
}
const int32_t actual_n_tile_size =
std::min(w2_n_tile_size,
output_size_2 - curr_output_group_id * w2_n_tile_size);
scalar_t* __restrict__ curr_w13_gemm_output_buffer =
w13_gemm_output_buffer +
cu_token_num_per_group_buffer[curr_expert_id] * input_size_2;
float* __restrict__ curr_w2_gemm_output_buffer =
w2_gemm_output_buffer +
cu_token_num_per_group_buffer[curr_expert_id] * output_size_2 +
curr_output_group_id * w2_n_tile_size;
scalar_t* __restrict__ w2_weight_ptr =
w2 + curr_expert_id * output_size_2 * input_size_2 +
curr_output_group_id * w2_n_tile_size * input_size_2;
scalar_t* __restrict__ w2_bias_ptr = nullptr;
if (w2_bias != nullptr) {
w2_bias_ptr = w2_bias + curr_expert_id * output_size_2 +
curr_output_group_id * w2_n_tile_size;
}
for (int32_t token_idx = 0; token_idx < curr_token_num;
token_idx += gemm_m_tile_size) {
const int32_t actual_token_num =
std::min(gemm_m_tile_size, curr_token_num - token_idx);
scalar_t* __restrict__ w2_weight_ptr_iter = w2_weight_ptr;
scalar_t* __restrict__ w2_bias_ptr_iter = w2_bias_ptr;
float* __restrict__ curr_w2_gemm_output_buffer_iter =
curr_w2_gemm_output_buffer;
for (int32_t i = 0; i < actual_n_tile_size; i += gemm_n_tile_size) {
gemm.gemm(curr_w13_gemm_output_buffer, w2_weight_ptr_iter,
curr_w2_gemm_output_buffer_iter, actual_token_num,
input_size_2, input_size_2, w2_n_group_stride,
output_size_2, false);
if (w2_bias != nullptr) {
cpu_micro_gemm::add_bias_epilogue<gemm_n_tile_size>(
curr_w2_gemm_output_buffer_iter,
curr_w2_gemm_output_buffer_iter, w2_bias_ptr_iter,
actual_token_num, output_size_2, output_size_2);
w2_bias_ptr_iter += gemm_n_tile_size;
}
w2_weight_ptr_iter += w2_n_tile_stride;
curr_w2_gemm_output_buffer_iter += gemm_n_tile_size;
}
// update
curr_w13_gemm_output_buffer += gemm_m_tile_size * input_size_2;
curr_w2_gemm_output_buffer += gemm_m_tile_size * output_size_2;
}
}
}
}
// weighted sum
{
alignas(64) cpu_utils::Counter counter;
cpu_utils::Counter* counter_ptr = &counter;
#pragma omp parallel for schedule(static, 1)
for (int32_t thread_id = 0; thread_id < thread_num; ++thread_id) {
const int32_t task_num = token_num;
uint8_t* __restrict__ thread_buffer =
thread_buffer_start + thread_id * ws_thread_buffer_offset;
float* __restrict__ ws_output_buffer =
reinterpret_cast<float*>(thread_buffer + ws_output_buffer_offset);
float* __restrict__ w2_gemm_output_buffer = reinterpret_cast<float*>(
common_buffer_start + w2_gemm_output_buffer_offset);
for (;;) {
int32_t task_id = counter_ptr->acquire_counter();
if (task_id >= task_num) {
break;
}
int32_t token_id = task_id;
int32_t* __restrict__ curr_expand_token_id_index_buffer =
expand_token_id_index_buffer + token_id * topk_num;
float* __restrict__ curr_weight = topk_weights + token_id * topk_num;
scalar_t* __restrict__ curr_output_buffer =
output + token_id * output_size_2;
if (topk_num > 1) {
{
int32_t w2_output_idx = curr_expand_token_id_index_buffer[0];
float* __restrict__ w2_output_iter =
w2_gemm_output_buffer + w2_output_idx * output_size_2;
float* __restrict__ ws_output_buffer_iter = ws_output_buffer;
vec_op::FP32Vec16 weight_vec(curr_weight[0]);
for (int32_t i = 0; i < output_size_2; i += 16) {
vec_op::FP32Vec16 vec(w2_output_iter);
vec = vec * weight_vec;
vec.save(ws_output_buffer_iter);
// update
w2_output_iter += 16;
ws_output_buffer_iter += 16;
}
}
{
for (int32_t idx = 1; idx < topk_num - 1; ++idx) {
int32_t w2_output_idx = curr_expand_token_id_index_buffer[idx];
float* __restrict__ w2_output_iter =
w2_gemm_output_buffer + w2_output_idx * output_size_2;
float* __restrict__ ws_output_buffer_iter = ws_output_buffer;
vec_op::FP32Vec16 weight_vec(curr_weight[idx]);
for (int32_t i = 0; i < output_size_2; i += 16) {
vec_op::FP32Vec16 vec(w2_output_iter);
vec_op::FP32Vec16 sum(ws_output_buffer_iter);
sum = sum + vec * weight_vec;
sum.save(ws_output_buffer_iter);
// update
w2_output_iter += 16;
ws_output_buffer_iter += 16;
}
}
}
{
int32_t idx = topk_num - 1;
int32_t w2_output_idx = curr_expand_token_id_index_buffer[idx];
float* __restrict__ w2_output_iter =
w2_gemm_output_buffer + w2_output_idx * output_size_2;
float* __restrict__ ws_output_buffer_iter = ws_output_buffer;
scalar_t* __restrict__ curr_output_buffer_iter = curr_output_buffer;
vec_op::FP32Vec16 weight_vec(curr_weight[idx]);
for (int32_t i = 0; i < output_size_2; i += 16) {
vec_op::FP32Vec16 vec(w2_output_iter);
vec_op::FP32Vec16 sum(ws_output_buffer_iter);
sum = sum + vec * weight_vec;
scalar_vec_t out_vec(sum);
out_vec.save(curr_output_buffer_iter);
// update
w2_output_iter += 16;
ws_output_buffer_iter += 16;
curr_output_buffer_iter += 16;
}
}
} else {
int32_t w2_output_idx = curr_expand_token_id_index_buffer[0];
float* __restrict__ w2_output_iter =
w2_gemm_output_buffer + w2_output_idx * output_size_2;
scalar_t* __restrict__ curr_output_buffer_iter = curr_output_buffer;
vec_op::FP32Vec16 weight_vec(curr_weight[0]);
for (int32_t i = 0; i < output_size_2; i += 16) {
vec_op::FP32Vec16 vec(w2_output_iter);
vec = vec * weight_vec;
scalar_vec_t out_vec(vec);
out_vec.save(curr_output_buffer_iter);
// update
w2_output_iter += 16;
curr_output_buffer_iter += 16;
}
}
}
}
}
}
} // namespace
void prepack_moe_weight(
const torch::Tensor& weight, // [expert_num, output_size, input_size]
torch::Tensor& packed_weight, const std::string& isa) {
TORCH_CHECK(weight.is_contiguous());
const int32_t expert_num = weight.size(0);
const int32_t output_size = weight.size(1);
const int32_t input_size = weight.size(2);
TORCH_CHECK_EQ(output_size % 32, 0);
const int64_t expert_stride = weight.stride(0);
cpu_utils::ISA isa_type = cpu_utils::get_isa(isa);
VLLM_DISPATCH_FLOATING_TYPES(
weight.scalar_type(), "prepack_moe_weight", [&]() {
CPU_ISA_DISPATCH_IMPL(isa_type, [&]() {
scalar_t* weight_ptr = weight.data_ptr<scalar_t>();
scalar_t* packed_weight_ptr = packed_weight.data_ptr<scalar_t>();
prepack_moe_weight_impl<scalar_t, gemm_t>(
weight_ptr, packed_weight_ptr, expert_num, output_size,
input_size, expert_stride);
});
});
}
void cpu_fused_moe(
torch::Tensor& output, // [token_num, output_size_2]
const torch::Tensor& input, // [token_num, input_size_13]
const torch::Tensor&
w13, // [expert_num, output_size_13, input_size_13], packed
const torch::Tensor&
w2, // [expert_num, output_size_2, input_size_2], packed
const std::optional<torch::Tensor>&
w13_bias, // [expert_num, output_size_13]
const std::optional<torch::Tensor>& w2_bias, // [expert_num, output_size_2]
const torch::Tensor& topk_weights, // [token_num, k], float32
const torch::Tensor& topk_id, // [token_num, k], int32
const std::string& act, const std::string& isa) {
const int32_t token_num = input.size(0);
const int32_t input_size_13 = input.size(1);
const int64_t input_stride = input.stride(0);
TORCH_CHECK_EQ(input_stride, input_size_13);
const int32_t expert_num = w13.size(0);
const int32_t output_size_13 = w13.size(1);
const int32_t input_size_2 = w2.size(2);
const int32_t output_size_2 = w2.size(1);
const int32_t topk_num = topk_id.size(1);
const FusedMOEAct act_type = get_act_type(act);
cpu_utils::ISA isa_type = cpu_utils::get_isa(isa);
VLLM_DISPATCH_FLOATING_TYPES(w13.scalar_type(), "cpu_fused_moe", [&]() {
CPU_ISA_DISPATCH_IMPL(isa_type, [&]() {
fused_moe_impl<scalar_t, scalar_t, gemm_t>(
output.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
w13.data_ptr<scalar_t>(), w2.data_ptr<scalar_t>(),
w13_bias.has_value() ? w13_bias->data_ptr<scalar_t>() : nullptr,
w2_bias.has_value() ? w2_bias->data_ptr<scalar_t>() : nullptr,
topk_weights.data_ptr<float>(), topk_id.data_ptr<int32_t>(), act_type,
token_num, expert_num, topk_num, input_size_13, output_size_13,
input_size_2, output_size_2);
});
});
}
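
For orientation, a minimal Python sketch (not part of the commit; names are illustrative) of the "prepare token-expert mappings" stage above: tokens are bucketed per expert so each expert's GEMM reads a contiguous slice of slots, and an inverse index records where every (token, top-k slot) landed for the final weighted-sum stage.

import torch

def group_tokens_by_expert(topk_id: torch.Tensor, expert_num: int):
    # topk_id: [token_num, topk_num], int32 expert ids
    token_num, topk_num = topk_id.shape
    flat = topk_id.reshape(-1)
    # how many (token, slot) pairs are routed to each expert
    token_num_per_group = torch.bincount(flat, minlength=expert_num)
    # exclusive prefix sum: start offset of each expert's slot range
    cu_token_num_per_group = torch.zeros(expert_num + 1, dtype=torch.int64)
    cu_token_num_per_group[1:] = torch.cumsum(token_num_per_group, dim=0)
    expand_token_id = torch.empty(token_num * topk_num, dtype=torch.int64)
    expand_token_id_index = torch.empty(token_num, topk_num, dtype=torch.int64)
    cursor = cu_token_num_per_group[:-1].clone()
    for t in range(token_num):
        for j in range(topk_num):
            e = int(topk_id[t, j])
            slot = int(cursor[e])
            cursor[e] += 1
            expand_token_id[slot] = t           # which token occupies this expert slot
            expand_token_id_index[t, j] = slot  # inverse map used by the weighted sum
    return (token_num_per_group, cu_token_num_per_group,
            expand_token_id, expand_token_id_index)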


@@ -352,6 +352,10 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
explicit FP32Vec16(bool, void* ptr)
: reg((__m512)_mm512_stream_load_si512(ptr)) {}
+// strided load
+explicit FP32Vec16(const float* ptr, INT32Vec16 idx)
+    : reg(_mm512_i32gather_ps(idx.reg, ptr, 4)) {}
explicit FP32Vec16(__m512 data) : reg(data) {}
// de-pack 4 bit values
@@ -408,6 +412,10 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
return FP32Vec16(_mm512_sub_ps(reg, b.reg));
}
+FP32Vec16 operator-() const {
+    return FP32Vec16(_mm512_xor_ps(reg, _mm512_set1_ps(-0.0f)));
+}
FP32Vec16 operator/(const FP32Vec16& b) const {
return FP32Vec16(_mm512_div_ps(reg, b.reg));
}
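
As an aside, the new strided-load constructor and unary minus above are what swigluoai_and_mul in csrc/cpu/cpu_fused_moe.cpp relies on: with the index table {0, 2, ..., 30}, one gather starting at input + n reads the gate values and a second gather at input + n + 1 reads the up values from the interleaved gate/up layout. A tiny Python sketch (illustration only, not from the commit):

import torch

x = torch.arange(32, dtype=torch.float32)  # interleaved: gate0, up0, gate1, up1, ...
idx = torch.arange(0, 32, 2)               # same offsets as the index table in swigluoai_and_mul
gate = x[idx]      # what FP32Vec16(input + n, index_vec) gathers
up = x[idx + 1]    # what FP32Vec16(input + n + 1, index_vec) gathers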


@@ -1,6 +1,5 @@
-#include "cpu_types.hpp"
-#include "scratchpad_manager.h"
-#include "utils.hpp"
+#include "cpu/cpu_types.hpp"
+#include "cpu/utils.hpp"
#ifdef CPU_CAPABILITY_AMXBF16
#include "cpu/micro_gemm/cpu_micro_gemm_amx.hpp"
@@ -158,7 +157,7 @@ void cpu_gemm_wna16_impl(
// a simple schedule policy, just to hold more B tiles in L2 and make sure
// each thread has tasks
const int32_t n_partition_size = [&]() {
-const int64_t cache_size = cpu_utils::get_l2_size();
+const int64_t cache_size = cpu_utils::get_available_l2_size();
int64_t ps_cache_limit = cache_size / (k_size * sizeof(scalar_t));
int64_t ps_thread_limit = n_size / thread_num;
ps_cache_limit =
@@ -179,8 +178,8 @@ void cpu_gemm_wna16_impl(
const int64_t b_buffer_offset = 0;
const int64_t c_buffer_offset = b_buffer_size;
const int64_t buffer_size = b_buffer_size + c_buffer_size;
-DNNLScratchPadManager::get_dnnl_scratchpad_manager()->realloc(buffer_size *
+cpu_utils::ScratchPadManager::get_scratchpad_manager()->realloc(buffer_size *
thread_num);
alignas(64) cpu_utils::Counter counter;
cpu_utils::Counter* counter_ptr = &counter;
@@ -190,9 +189,10 @@ void cpu_gemm_wna16_impl(
scalar_t* __restrict__ b_buffer = nullptr;
float* __restrict__ c_buffer = nullptr;
{
-uint8_t* buffer_ptr = DNNLScratchPadManager::get_dnnl_scratchpad_manager()
-    ->get_data<uint8_t>() +
-    thread_id * buffer_size;
+uint8_t* buffer_ptr =
+    cpu_utils::ScratchPadManager::get_scratchpad_manager()
+        ->get_data<uint8_t>() +
+    thread_id * buffer_size;
b_buffer = reinterpret_cast<scalar_t*>(buffer_ptr + b_buffer_offset);
c_buffer = reinterpret_cast<float*>(buffer_ptr + c_buffer_offset);
}


@@ -4,8 +4,8 @@
#include "common/memory_desc.hpp"
#include "common/memory.hpp"
-#include "dnnl_helper.h"
-#include "scratchpad_manager.h"
+#include "cpu/utils.hpp"
+#include "cpu/dnnl_helper.h"
static dnnl::engine& default_engine() {
static dnnl::engine engine(dnnl::engine::kind::cpu, 0);
@@ -274,7 +274,7 @@ void W8A8MatMulPrimitiveHandler::execute(ExecArgs& args) {
auto&& [scratchpad_storage, scratchpad_mem_desc] = get_runtime_memory_ptr(5);
scratchpad_storage->set_data_handle(
-DNNLScratchPadManager::get_dnnl_scratchpad_manager()->get_data<void>());
+cpu_utils::ScratchPadManager::get_scratchpad_manager()->get_data<void>());
matmul.execute(default_stream(), memory_cache_);
default_stream().wait();
@@ -294,7 +294,7 @@ dnnl::matmul W8A8MatMulPrimitiveHandler::get_matmul_cache(
return m_size_cache_->get_or_create(key, [&]() {
dnnl::matmul::primitive_desc desc = this->create_primitive_desc(key, false);
-auto manager = DNNLScratchPadManager::get_dnnl_scratchpad_manager();
+auto manager = cpu_utils::ScratchPadManager::get_scratchpad_manager();
manager->realloc(desc.scratchpad_desc().get_size());
return dnnl::matmul(desc);
});
@@ -470,7 +470,7 @@ void MatMulPrimitiveHandler::execute(ExecArgs& args) {
auto&& [scratchpad_storage, scratchpad_mem_desc] = get_runtime_memory_ptr(3);
scratchpad_storage->set_data_handle(
-DNNLScratchPadManager::get_dnnl_scratchpad_manager()->get_data<void>());
+cpu_utils::ScratchPadManager::get_scratchpad_manager()->get_data<void>());
matmul.execute(default_stream(), memory_cache_);
default_stream().wait();
@@ -486,7 +486,7 @@ dnnl::matmul MatMulPrimitiveHandler::get_matmul_cache(
}
return m_size_cache_->get_or_create(key, [&]() {
dnnl::matmul::primitive_desc desc = this->create_primitive_desc(key, false);
-auto manager = DNNLScratchPadManager::get_dnnl_scratchpad_manager();
+auto manager = cpu_utils::ScratchPadManager::get_scratchpad_manager();
manager->realloc(desc.scratchpad_desc().get_size());
return dnnl::matmul(desc);
});


@@ -235,6 +235,39 @@ class MicroGemm<cpu_utils::ISA::AMX, scalar_t> {
}
}
static void pack_weight(const scalar_t* __restrict__ weight,
scalar_t* __restrict__ packed_weight,
const int32_t output_size, const int32_t input_size) {
constexpr int32_t elem_num_per_group = 4 / sizeof(scalar_t);
TORCH_CHECK_EQ(output_size % 16, 0);
TORCH_CHECK_EQ(input_size % (16 * elem_num_per_group), 0);
const int32_t output_group_num = output_size / 16;
const int32_t input_32b_num = input_size / elem_num_per_group;
for (int32_t output_group_idx = 0; output_group_idx < output_group_num;
++output_group_idx) {
const int32_t* __restrict__ weight_32b =
reinterpret_cast<const int32_t*>(weight);
int32_t* __restrict__ packed_weight_32b =
reinterpret_cast<int32_t*>(packed_weight);
for (int32_t output_idx = 0; output_idx < 16; ++output_idx) {
for (int32_t weight_offset = 0, packed_offset = 0;
weight_offset < input_32b_num;
++weight_offset, packed_offset += 16) {
packed_weight_32b[packed_offset] = weight_32b[weight_offset];
}
// update
weight_32b += input_32b_num;
packed_weight_32b += 1;
}
// update
weight += 16 * input_size;
packed_weight += 16 * input_size;
}
}
private:
alignas(64) __tilecfg amx_tile_config_;
int32_t curr_m_;
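
A hedged torch sketch (not part of the commit) of the layout this AMX pack_weight produces for 16-bit dtypes such as bf16: inside each block of 16 output rows, adjacent pairs of values along the input dimension stay together as one 32-bit group and the block is stored K-major, i.e. a [16, input_size] block becomes [input_size / 2, 16, 2].

import torch

def ref_amx_pack(weight: torch.Tensor) -> torch.Tensor:
    # weight: [output_size, input_size] bf16, contiguous;
    # output_size % 16 == 0 and input_size % 32 == 0, as checked by pack_weight
    O, I = weight.shape
    blocks = weight.reshape(O // 16, 16, I // 2, 2)   # pair adjacent bf16 values (one 32-bit group)
    packed = blocks.permute(0, 2, 1, 3).contiguous()  # each block stored K-major: [I/2, 16, 2]
    return packed.reshape(O, I)                       # same storage footprint as the input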


@@ -13,6 +13,9 @@ namespace cpu_micro_gemm {
#define CPU_MICRO_GEMM_PARAMS \
a_ptr, b_ptr, c_ptr, m, k, lda, b_n_group_stride, ldc, accum_c
+// Note: weights for MicroGemm should be packed as (output_size / 16) contiguous
+// blocks, means the logical shape of blocks is [16, input_size]. And the actual
+// layout of blocks can be ISA-specific.
template <cpu_utils::ISA isa, typename scalar_t>
class MicroGemm {
public:
@@ -86,6 +89,41 @@ FORCE_INLINE void bias_epilogue(float* __restrict__ c_ptr,
curr_d += ldd;
}
}
template <int32_t n_size, typename scalar_t>
FORCE_INLINE void add_bias_epilogue(float* c_ptr, float* d_ptr,
scalar_t* __restrict__ bias_ptr,
const int32_t m, const int64_t ldc,
const int64_t ldd) {
using scalar_vec_t = typename cpu_utils::VecTypeTrait<scalar_t>::vec_t;
static_assert(n_size % 16 == 0);
constexpr int32_t n_group_num = n_size / 16;
static_assert(n_group_num <= 16);
vec_op::FP32Vec16 bias_vecs[n_group_num];
scalar_t* __restrict__ curr_bias = bias_ptr;
vec_op::unroll_loop<int32_t, n_group_num>([&](int32_t i) {
scalar_vec_t vec(curr_bias);
bias_vecs[i] = vec_op::FP32Vec16(vec);
curr_bias += 16;
});
float* curr_c = c_ptr;
float* curr_d = d_ptr;
for (int32_t i = 0; i < m; ++i) {
float* curr_c_iter = curr_c;
float* curr_d_iter = curr_d;
vec_op::unroll_loop<int32_t, n_group_num>([&](int32_t n_g_idx) {
vec_op::FP32Vec16 c_vec_fp32(curr_c_iter);
c_vec_fp32 = c_vec_fp32 + bias_vecs[n_g_idx];
c_vec_fp32.save(curr_d_iter);
curr_c_iter += 16;
curr_d_iter += 16;
});
curr_c += ldc;
curr_d += ldd;
}
}
} // namespace cpu_micro_gemm
#endif
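
For clarity, add_bias_epilogue amounts to a row-broadcast bias add over an fp32 tile (the bias row is widened to fp32 once and reused for every row); a trivial Python equivalent (illustration only):

import torch

def ref_add_bias_epilogue(c: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
    # c: [m, n_size] float32 tile, bias: [n_size] (e.g. bf16), n_size a multiple of 16
    return c + bias.float()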


@@ -109,6 +109,25 @@ class MicroGemm<cpu_utils::ISA::VEC, scalar_t> {
void gemm(DEFINE_CPU_MICRO_GEMM_PARAMS) {
TileGemm82<scalar_t>::gemm(CPU_MICRO_GEMM_PARAMS);
}
// Note: pack contiguous weight [output_size, input_size] as contiguous
// packed weight [output_size / 16, input_size, 16]
static void pack_weight(const scalar_t* __restrict__ weight,
scalar_t* __restrict__ packed_weight,
const int32_t output_size, const int32_t input_size) {
TORCH_CHECK_EQ(output_size % 16, 0);
for (int32_t o_idx = 0; o_idx < output_size; ++o_idx) {
const scalar_t* __restrict__ curr_weight = weight + o_idx * input_size;
scalar_t* __restrict__ curr_packed_weight =
packed_weight + (o_idx / 16) * (16 * input_size) + o_idx % 16;
for (int32_t i_idx = 0; i_idx < input_size; ++i_idx) {
*curr_packed_weight = *curr_weight;
curr_packed_weight += 16;
++curr_weight;
}
}
}
};
} // namespace cpu_micro_gemm
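
A matching torch sketch (illustration only) of the VEC packing described in the note above, i.e. [output_size, input_size] rearranged into [output_size / 16, input_size, 16] so the 16 output lanes of a block are contiguous for each input element:

import torch

def ref_vec_pack(weight: torch.Tensor) -> torch.Tensor:
    # weight: [output_size, input_size], output_size % 16 == 0
    O, I = weight.shape
    packed = weight.reshape(O // 16, 16, I).permute(0, 2, 1).contiguous()  # [O/16, I, 16]
    return packed.reshape(O, I)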


@ -1,23 +0,0 @@
#include <cstdlib>
#include "scratchpad_manager.h"
DNNLScratchPadManager::DNNLScratchPadManager() : size_(0), ptr_(nullptr) {
this->realloc(allocation_unit * 128);
}
void DNNLScratchPadManager::realloc(size_t new_size) {
new_size = round(new_size);
if (new_size > size_) {
if (ptr_ != nullptr) {
std::free(ptr_);
}
ptr_ = std::aligned_alloc(64, new_size);
size_ = new_size;
}
}
DNNLScratchPadManager* DNNLScratchPadManager::get_dnnl_scratchpad_manager() {
static DNNLScratchPadManager manager;
return &manager;
}


@ -1,31 +0,0 @@
#ifndef SCRATCHPAD_MANAGER_H
#define SCRATCHPAD_MANAGER_H
#include <cstddef>
#include <cstdio>
class DNNLScratchPadManager {
public:
static constexpr size_t allocation_unit = 4 * 1024; // 4KB
static DNNLScratchPadManager* get_dnnl_scratchpad_manager();
DNNLScratchPadManager();
template <typename T>
T* get_data() {
return reinterpret_cast<T*>(ptr_);
}
static size_t round(size_t size) {
return ((size + allocation_unit - 1) / allocation_unit) * allocation_unit;
}
void realloc(size_t new_size);
private:
size_t size_;
void* ptr_;
};
#endif


@@ -110,6 +110,17 @@ void cpu_gemm_wna16(const torch::Tensor& input, const torch::Tensor& q_weight,
const std::optional<torch::Tensor>& bias,
const int64_t pack_factor, const std::string& isa_hint);
void prepack_moe_weight(const torch::Tensor& weight,
torch::Tensor& packed_weight, const std::string& isa);
void cpu_fused_moe(torch::Tensor& output, const torch::Tensor& input,
const torch::Tensor& w13, const torch::Tensor& w2,
const std::optional<torch::Tensor>& w13_bias,
const std::optional<torch::Tensor>& w2_bias,
const torch::Tensor& topk_weights,
const torch::Tensor& topk_id, const std::string& act,
const std::string& isa);
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// vLLM custom ops
@@ -296,6 +307,19 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"pack_factor, str isa_hint) -> ()");
ops.impl("cpu_gemm_wna16", torch::kCPU, &cpu_gemm_wna16);
#endif
// fused moe
#if defined(__AVX512F__)
ops.def(
"prepack_moe_weight(Tensor weight, Tensor(a1!) packed_weight, str isa) "
"-> ()");
ops.impl("prepack_moe_weight", torch::kCPU, &prepack_moe_weight);
ops.def(
"cpu_fused_moe(Tensor(a0!) output, Tensor input, Tensor w13, Tensor w2, "
"Tensor? w13_bias, Tensor? w2_bias, Tensor topk_weights, Tensor topk_id, "
"str act, str isa) -> ()");
ops.impl("cpu_fused_moe", torch::kCPU, &cpu_fused_moe);
#endif
}
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _utils), utils) {


@@ -10,7 +10,7 @@
#define gettid() syscall(SYS_gettid)
#endif
-#include "cpu_types.hpp"
+#include "cpu/utils.hpp"
#ifdef VLLM_NUMA_DISABLED
std::string init_cpu_threads_env(const std::string& cpu_ids) {
@@ -138,4 +138,26 @@
return ss.str();
}
namespace cpu_utils {
ScratchPadManager::ScratchPadManager() : size_(0), ptr_(nullptr) {
this->realloc(allocation_unit * 128);
}
void ScratchPadManager::realloc(size_t new_size) {
new_size = round(new_size);
if (new_size > size_) {
if (ptr_ != nullptr) {
std::free(ptr_);
}
ptr_ = std::aligned_alloc(64, new_size);
size_ = new_size;
}
}
ScratchPadManager* ScratchPadManager::get_scratchpad_manager() {
static ScratchPadManager manager;
return &manager;
}
} // namespace cpu_utils
#endif


@@ -2,19 +2,24 @@
#define UTILS_HPP
#include <atomic>
-#include <cassert>
-#include <cstdint>
#include <unistd.h>
-#if defined(__APPLE__)
-#include <sys/sysctl.h>
-#endif
-#include "cpu_types.hpp"
+#include <ATen/cpu/Utils.h>
+#include "cpu/cpu_types.hpp"
namespace cpu_utils {
enum class ISA { AMX, VEC };
+inline ISA get_isa(const std::string& isa) {
+if (isa == "amx") {
+return ISA::AMX;
+} else if (isa == "vec") {
+return ISA::VEC;
+} else {
+TORCH_CHECK(false, "Invalid isa type: " + isa);
+}
+}
template <typename T>
struct VecTypeTrait {
using vec_t = void;
@@ -48,26 +53,66 @@ struct Counter {
int64_t acquire_counter() { return counter++; }
};
-inline int64_t get_l2_size() {
+inline int64_t get_available_l2_size() {
static int64_t size = []() {
-#if defined(__APPLE__)
-// macOS doesn't have _SC_LEVEL2_CACHE_SIZE. Use sysctlbyname.
-int64_t l2_cache_size = 0;
-size_t len = sizeof(l2_cache_size);
-if (sysctlbyname("hw.l2cachesize", &l2_cache_size, &len, NULL, 0) == 0 &&
-l2_cache_size > 0) {
-return l2_cache_size >> 1; // use 50% of L2 cache
-}
-// Fallback if sysctlbyname fails
-return 128LL * 1024 >> 1; // use 50% of 128KB
-#else
-long l2_cache_size = sysconf(_SC_LEVEL2_CACHE_SIZE);
-assert(l2_cache_size != -1);
+const uint32_t l2_cache_size = at::cpu::L2_cache_size();
return l2_cache_size >> 1; // use 50% of L2 cache
-#endif
}();
return size;
}
template <int32_t alignment_v, typename T>
inline T round_up(T size) {
T alignment = alignment_v;
return (((size + alignment - 1) / alignment) * alignment);
}
template <int32_t alignment_v, typename T>
inline T round_down(T size) {
T alignment = alignment_v;
return (size / alignment) * alignment;
}
template <typename T>
inline void print_logits(const char* name, T* ptr, int32_t row, int32_t col,
int32_t stride) {
std::stringstream ss;
ss << std::fixed << std::setprecision(5) << name << ": [\n";
auto* curr_logits_buffer = ptr;
for (int32_t m = 0; m < row; ++m) {
for (int32_t n = 0; n < col; ++n) {
ss << curr_logits_buffer[n] << ", ";
}
ss << "\n";
curr_logits_buffer += stride;
}
ss << "]\n";
std::printf("%s", ss.str().c_str());
}
class ScratchPadManager {
public:
static constexpr size_t allocation_unit = 4 * 1024; // 4KB
static ScratchPadManager* get_scratchpad_manager();
ScratchPadManager();
template <typename T>
T* get_data() {
return reinterpret_cast<T*>(ptr_);
}
static size_t round(size_t size) {
return ((size + allocation_unit - 1) / allocation_unit) * allocation_unit;
}
void realloc(size_t new_size);
private:
size_t size_;
void* ptr_;
};
} // namespace cpu_utils
#endif
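
For reference, a short Python transcription (not part of the commit) of the small helpers above: round_up/round_down align sizes to a compile-time alignment, the scratchpad grows in 4 KiB allocation units, and get_available_l2_size budgets half of the per-core L2 cache reported by ATen.

ALLOCATION_UNIT = 4 * 1024  # ScratchPadManager::allocation_unit

def round_up(size: int, alignment: int) -> int:
    return ((size + alignment - 1) // alignment) * alignment

def round_down(size: int, alignment: int) -> int:
    return (size // alignment) * alignment

def scratchpad_rounded(size: int) -> int:
    # ScratchPadManager::round: grow in whole 4 KiB units
    return round_up(size, ALLOCATION_UNIT)

def available_l2_budget(l2_cache_size: int) -> int:
    # cpu_utils::get_available_l2_size: use 50% of L2
    return l2_cache_size >> 1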


@@ -147,7 +147,9 @@ WORKDIR /workspace/vllm
RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
--mount=type=cache,target=/var/lib/apt,sharing=locked \
-apt-get install -y --no-install-recommends vim numactl xz-utils
+apt-get install -y --no-install-recommends vim numactl xz-utils make clangd-14
+RUN ln -s /usr/bin/clangd-14 /usr/bin/clangd
# install development dependencies (for testing)
RUN --mount=type=cache,target=/root/.cache/uv \


@@ -1,7 +1,7 @@
cmake>=3.26.1
ninja
packaging>=24.2
-setuptools>=77.0.3,<81.0.0
+setuptools==77.0.3 # this version can reuse CMake build dir
setuptools-scm>=8
torch==2.9.1+cpu; platform_machine == "x86_64" or platform_machine == "s390x"
torch==2.9.1; platform_system == "Darwin" or platform_machine == "ppc64le" or platform_machine == "aarch64"


@@ -1,6 +1,8 @@
# Common dependencies
-r common.txt
+setuptools==77.0.3 # this version can reuse CMake build dir
numba == 0.61.2; platform_machine != "s390x" # Required for N-gram speculative decoding
# Dependencies for CPUs


tests/kernels/moe/test_cpu_fused_moe.py (new file, 172 lines)
@@ -0,0 +1,172 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
from tests.kernels.allclose_default import get_default_atol, get_default_rtol
from vllm._custom_ops import cpu_fused_moe, cpu_prepack_moe_weight
from vllm.model_executor.layers.activation import SiluAndMul, SwigluOAIAndMul
from vllm.platforms import current_platform
if not current_platform.is_cpu():
pytest.skip("skipping CPU-only tests", allow_module_level=True)
EXPERT_NUM = [
8,
]
HIDDEN_DIM = [128, 2880]
INTERMEDIATE_DIM = [128, 2880]
BATCH_SIZE = [1, 64, 256]
ACT = ["silu", "swigluoai"]
USE_BIAS = [True, False]
ISA = ["amx", "vec"] if torch._C._cpu._is_amx_tile_supported() else ["vec"]
DTYPE = [torch.bfloat16]
_CPU_MOE_ACT = {
"silu": SiluAndMul(),
"swigluoai": SwigluOAIAndMul(),
}
def ref_fused_moe(
input: torch.Tensor,
w13: torch.Tensor,
w2: torch.Tensor,
w13_bias: torch.Tensor | None,
w2_bias: torch.Tensor | None,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
) -> torch.Tensor:
len_experts = w13.size(0)
cnts = topk_ids.new_zeros((topk_ids.shape[0], len_experts))
cnts.scatter_(1, topk_ids.to(torch.int64), 1)
tokens_per_expert = cnts.sum(dim=0)
idxs = topk_ids.view(-1).argsort()
sorted_tokens = input[idxs // topk_ids.shape[1]]
tokens_per_expert = tokens_per_expert.cpu().numpy()
outputs = []
start_idx = 0
for i, num_tokens in enumerate(tokens_per_expert):
end_idx = start_idx + num_tokens
if num_tokens == 0:
continue
tokens_for_this_expert = sorted_tokens[start_idx:end_idx].float()
curr_w13 = w13[i].float()
curr_w2 = w2[i].float()
curr_w13_bias = None
if w13_bias is not None:
curr_w13_bias = w13_bias[i].float()
curr_w2_bias = None
if w2_bias is not None:
curr_w2_bias = w2_bias[i].float()
gate_up = torch.nn.functional.linear(
tokens_for_this_expert, curr_w13, curr_w13_bias
)
# Note: to simulate the kernel implementation
gate_up = (
_CPU_MOE_ACT[activation]
.forward_native(gate_up)
.to(dtype=input.dtype)
.float()
)
expert_out = torch.nn.functional.linear(gate_up, curr_w2, curr_w2_bias)
outputs.append(expert_out)
start_idx = end_idx
outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)
new_x = torch.empty_like(outs)
new_x[idxs] = outs
final_out = (
new_x.view(*topk_ids.shape, -1)
.mul_(topk_weights.unsqueeze(dim=-1))
.sum(dim=1)
.type(input.dtype)
)
return final_out
@pytest.mark.parametrize("batch_size", BATCH_SIZE)
@pytest.mark.parametrize("expert_num", EXPERT_NUM)
@pytest.mark.parametrize("hidden_size", HIDDEN_DIM)
@pytest.mark.parametrize("intermediate_size", INTERMEDIATE_DIM)
@pytest.mark.parametrize("use_bias", USE_BIAS)
@pytest.mark.parametrize("dtype", DTYPE)
@pytest.mark.parametrize("act", ACT)
@pytest.mark.parametrize("isa", ISA)
def test_cpu_fused_moe(
batch_size: int,
expert_num: int,
hidden_size: int,
intermediate_size: int,
use_bias: bool,
dtype: torch.dtype,
act: str,
isa: str,
):
current_platform.seed_everything(0)
topk_num = max(expert_num // 2, 1)
up_dim = 2 * intermediate_size
input = torch.randn((batch_size, hidden_size), dtype=dtype) / (
0.5 * hidden_size**0.5
)
w13 = torch.randn((expert_num, up_dim, hidden_size), dtype=dtype) / (
0.5 * hidden_size**0.5
)
w2 = torch.randn((expert_num, hidden_size, intermediate_size), dtype=dtype) / (
0.5 * intermediate_size**0.5
)
router_logits = torch.randn((batch_size, expert_num), dtype=dtype)
w13_bias = None
w2_bias = None
if use_bias:
w13_bias = torch.randn((expert_num, up_dim), dtype=dtype) / (0.5 * up_dim**0.5)
w2_bias = torch.randn((expert_num, hidden_size), dtype=dtype) / (
0.5 * hidden_size**0.5
)
score = torch.softmax(router_logits, dim=-1, dtype=torch.float32)
topk_weight, topk_ids = torch.topk(score, topk_num)
topk_ids = topk_ids.to(torch.int32)
ref_output = ref_fused_moe(
input,
w13,
w2,
w13_bias,
w2_bias,
topk_weight,
topk_ids,
act,
)
packed_w13 = cpu_prepack_moe_weight(w13, isa)
packed_w2 = cpu_prepack_moe_weight(w2, isa)
output = cpu_fused_moe(
input,
packed_w13,
packed_w2,
w13_bias,
w2_bias,
topk_weight,
topk_ids,
act,
isa,
)
atol, rtol = get_default_atol(output), get_default_rtol(output)
(
torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol),
f"{torch.max(torch.abs(output - ref_output))}",
)


@@ -2919,6 +2919,42 @@ def cpu_gemm_wna16(
return output
def cpu_prepack_moe_weight(
weight: torch.Tensor,
isa: str,
) -> torch.Tensor:
output = torch.empty_like(weight)
torch.ops._C.prepack_moe_weight(weight, output, isa)
return output
def cpu_fused_moe(
input: torch.Tensor,
w13: torch.Tensor,
w2: torch.Tensor,
w13_bias: torch.Tensor | None,
w2_bias: torch.Tensor | None,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
act: str,
isa: str,
) -> torch.Tensor:
output = torch.empty_like(input)
torch.ops._C.cpu_fused_moe(
output,
input,
w13,
w2,
w13_bias,
w2_bias,
topk_weights,
topk_ids,
act,
isa,
)
return output
if hasattr(torch.ops._qutlass_C, "matmul_mxf4_bf16_tn"):
@register_fake("_qutlass_C::matmul_mxf4_bf16_tn")


@ -1,12 +1,22 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import weakref
from collections.abc import Callable

import torch
from torch.nn import functional as F

from vllm import _custom_ops as ops
from vllm._custom_ops import cpu_fused_moe, cpu_prepack_moe_weight
from vllm.model_executor.layers.activation import SiluAndMul, SwigluOAIAndMul
from vllm.model_executor.layers.quantization.utils.layer_utils import replace_parameter
from vllm.utils.torch_utils import direct_register_custom_op
_CPU_MOE_LAYER_CACHE = {}
_CPU_MOE_ACT = {
"silu": SiluAndMul(),
"swigluoai": SwigluOAIAndMul(),
}
def grouped_topk(
@ -174,8 +184,105 @@ class SGLFusedMOE:
class CPUFusedMOE:
    def __init__(self, layer: torch.nn.Module) -> None:
        use_grouped_gemm, isa = self.check_grouped_gemm(layer)
self.isa = isa
if use_grouped_gemm:
self.forward_method = self.forward_grouped_gemm
self.init_moe_grouped_gemm(layer=layer)
else:
self.forward_method = self.forward_torch
self.init_moe_torch(layer=layer)
def __call__(
self,
layer: torch.nn.Module,
x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
router_logits: torch.Tensor,
renormalize: bool,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
) -> torch.Tensor:
assert activation in _CPU_MOE_ACT, f"{activation} is not supported."
assert not apply_router_weight_on_input
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
)
return self.forward_method(
layer,
x,
topk_weights,
topk_ids,
activation,
global_num_experts,
)
def check_grouped_gemm(
self,
layer: torch.nn.Module,
) -> tuple[bool, str]:
if not hasattr(torch.ops._C, "prepack_moe_weight"):
return False, "none"
dtype = layer.w13_weight.dtype
w13_input_size = layer.w13_weight.size(2)
w13_output_size = layer.w13_weight.size(1)
w2_input_size = layer.w2_weight.size(2)
w2_output_size = layer.w2_weight.size(1)
if not (w13_output_size % 32 == 0 and w2_output_size % 32 == 0):
return False, "none"
supports_amx = torch._C._cpu._is_amx_tile_supported()
if (
supports_amx
and dtype == torch.bfloat16
and w13_input_size % 32 == 0
and w2_input_size % 32 == 0
):
return True, "amx"
if supports_amx:
return False, "none"
return True, "vec"
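    # Dispatch summary: the grouped-GEMM path requires the _C prepack op and
    # 32-aligned w13/w2 output dims. AMX-capable CPUs take it only for bf16
    # weights with 32-aligned input dims (isa="amx") and otherwise fall back
    # to the torch path; non-AMX CPUs use the generic vectorized kernel
    # (isa="vec").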
def init_moe_grouped_gemm(
self,
layer: torch.nn.Module,
) -> None:
new_w13 = cpu_prepack_moe_weight(layer.w13_weight, self.isa)
replace_parameter(layer, "w13_weight", new_w13)
new_w2 = cpu_prepack_moe_weight(layer.w2_weight, self.isa)
replace_parameter(layer, "w2_weight", new_w2)
def init_moe_torch(
self,
layer: torch.nn.Module,
) -> None:
use_onednn_mm = ops._supports_onednn and ops.is_onednn_acl_supported()
        num_experts = layer.w13_weight.size(0)
        has_w13_bias = hasattr(layer, "w13_bias")
        has_w2_bias = hasattr(layer, "w2_bias")
@ -208,85 +315,112 @@ class CPUFusedMOE:
            layer.down_linear.append(
                lambda x, w=layer_w2_weight, b=layer_w2_bias: F.linear(x, w, b)
            )
        if use_onednn_mm:  # remove weight
            layer.w13_weight = torch.nn.Parameter(torch.empty(0), requires_grad=False)
            layer.w2_weight = torch.nn.Parameter(torch.empty(0), requires_grad=False)
        _CPU_MOE_LAYER_CACHE[id(layer)] = weakref.ref(layer)
    def forward_grouped_gemm(
        self,
        layer: torch.nn.Module,
        input: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        activation: str,
        global_num_experts: int = -1,
    ) -> torch.Tensor:
        output = cpu_fused_moe(
            input,
            layer.w13_weight,
            layer.w2_weight,
            getattr(layer, "w13_bias", None),
            getattr(layer, "w2_bias", None),
            topk_weights,
            topk_ids,
            activation,
            self.isa,
        )
        return output

    def forward_torch(
        self,
        layer: torch.nn.Module,
        input: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        activation: str,
        global_num_experts: int = -1,
    ) -> torch.Tensor:
        output = torch.empty_like(input)
        layer_id = id(layer)
        torch.ops.vllm.cpu_fused_moe_torch(
            layer_id,
            output,
            input,
            topk_weights,
            topk_ids,
            activation,
            global_num_experts,
        )
        return output


def cpu_fused_moe_torch(
    layer_id: int,
    output: torch.Tensor,
    input: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    activation: str,
    global_num_experts: int = -1,
) -> None:
    layer = _CPU_MOE_LAYER_CACHE[layer_id]()
    # Ref code from https://github.com/sgl-project/sglang/blob/716e682721397df103f347d22da8bd46c6016dab/python/sglang/srt/layers/moe/fused_moe_native.py#L53
    len_experts = global_num_experts

    cnts = topk_ids.new_zeros((topk_ids.shape[0], len_experts))
    cnts.scatter_(1, topk_ids.to(torch.int64), 1)
    tokens_per_expert = cnts.sum(dim=0)
    idxs = topk_ids.view(-1).argsort()

    sorted_tokens = input[idxs // topk_ids.shape[1]]
    tokens_per_expert = tokens_per_expert.cpu().numpy()

    outputs = []
    start_idx = 0
    for i, num_tokens in enumerate(tokens_per_expert):
        end_idx = start_idx + num_tokens
        if num_tokens == 0:
            continue
        tokens_for_this_expert = sorted_tokens[start_idx:end_idx]

        gate_up = layer.gate_up_linear[i](tokens_for_this_expert)  # type: ignore
        gate_up = _CPU_MOE_ACT[activation].forward_native(gate_up)
        expert_out = layer.down_linear[i](gate_up)  # type: ignore
outputs.append(expert_out)
start_idx = end_idx
outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)
new_x = torch.empty_like(outs)
new_x[idxs] = outs
final_out = (
new_x.view(*topk_ids.shape, -1)
.type(topk_weights.dtype)
.mul_(topk_weights.unsqueeze(dim=-1))
.sum(dim=1)
.type(new_x.dtype)
)
output.copy_(final_out)
direct_register_custom_op(
op_name="cpu_fused_moe_torch",
op_func=cpu_fused_moe_torch,
mutates_args=["output"],
)
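# Registering the eager fallback as a custom op keeps the per-expert Python
# loop behind a single op boundary (e.g. under torch.compile); only `output`
# is declared as mutated, and the layer is recovered inside the op from the
# weakref cache via its id.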

View File

@ -1726,9 +1726,10 @@ class FusedMoE(CustomOp):
            return states

        if self.shared_experts is None:
            if current_platform.is_tpu() or current_platform.is_cpu():
                # TODO: Once the OOM issue for the TPU backend is resolved, we
                # will switch to using the moe_forward custom op.
                # Note: CPU doesn't require wrapped forward_impl.
                fused_output = self.forward_impl(hidden_states, router_logits)
                assert not isinstance(fused_output, tuple)
            else:
@ -1744,9 +1745,10 @@ class FusedMoE(CustomOp):
            else:
                return reduce_output(fused_output)[..., :og_hidden_states]
        else:
            if current_platform.is_tpu() or current_platform.is_cpu():
                # TODO: Once the OOM issue for the TPU backend is resolved, we
                # will switch to using the moe_forward custom op.
                # Note: CPU doesn't require wrapped forward_impl.
                shared_output, fused_output = self.forward_impl(
                    hidden_states, router_logits
                )