diff --git a/.buildkite/scripts/hardware_ci/run-cpu-test.sh b/.buildkite/scripts/hardware_ci/run-cpu-test.sh index 438fe522c8702..471c8616df85c 100644 --- a/.buildkite/scripts/hardware_ci/run-cpu-test.sh +++ b/.buildkite/scripts/hardware_ci/run-cpu-test.sh @@ -50,6 +50,7 @@ function cpu_tests() { docker exec cpu-test-"$NUMA_NODE" bash -c " set -e pytest -x -v -s tests/kernels/attention/test_cpu_attn.py + pytest -x -v -s tests/kernels/moe/test_cpu_fused_moe.py pytest -x -v -s tests/kernels/test_onednn.py" # Run basic model test diff --git a/cmake/cpu_extension.cmake b/cmake/cpu_extension.cmake index 85b286f8d8d0a..0af87fd7f0b53 100644 --- a/cmake/cpu_extension.cmake +++ b/cmake/cpu_extension.cmake @@ -330,7 +330,7 @@ if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR (ASIMD_FOUND AND NOT APPLE_SILICON PUBLIC ${oneDNN_BINARY_DIR}/include PRIVATE ${oneDNN_SOURCE_DIR}/src ) - target_link_libraries(dnnl_ext dnnl) + target_link_libraries(dnnl_ext dnnl torch) target_compile_options(dnnl_ext PRIVATE ${CXX_COMPILE_FLAGS} -fPIC) list(APPEND LIBS dnnl_ext) set(USE_ONEDNN ON) @@ -358,13 +358,13 @@ set(VLLM_EXT_SRC "csrc/cpu/pos_encoding.cpp" "csrc/moe/dynamic_4bit_int_moe_cpu.cpp" "csrc/cpu/cpu_attn.cpp" - "csrc/cpu/scratchpad_manager.cpp" "csrc/cpu/torch_bindings.cpp") if (AVX512_FOUND AND NOT AVX512_DISABLED) set(VLLM_EXT_SRC "csrc/cpu/shm.cpp" "csrc/cpu/cpu_wna16.cpp" + "csrc/cpu/cpu_fused_moe.cpp" ${VLLM_EXT_SRC}) if (ENABLE_AVX512BF16 AND ENABLE_AVX512VNNI) set(VLLM_EXT_SRC diff --git a/csrc/cpu/cpu_attn_macros.h b/csrc/cpu/cpu_arch_macros.h similarity index 97% rename from csrc/cpu/cpu_attn_macros.h rename to csrc/cpu/cpu_arch_macros.h index 35716a0790ab3..c73b62ecdec90 100644 --- a/csrc/cpu/cpu_attn_macros.h +++ b/csrc/cpu/cpu_arch_macros.h @@ -1,5 +1,5 @@ -#ifndef CPU_ATTN_MACROS_H -#define CPU_ATTN_MACROS_H +#ifndef CPU_ARCH_MACROS_H +#define CPU_ARCH_MACROS_H // x86_64 #ifdef __x86_64__ @@ -26,7 +26,7 @@ _mm512_castsi512_ps(_mm512_set1_epi32(0x42b17218)); \ const __m512i vec_127 = _mm512_set1_epi32(0x0000007f); \ const int n_mantissa_bits = 23; \ - auto fast_exp = [&](vec_op::FP32Vec16& vec) __attribute__(( \ + auto fast_exp = [&](const vec_op::FP32Vec16& vec) __attribute__(( \ always_inline)) { \ __m512 values = vec.reg; \ auto less_ln_flt_min_mask = \ @@ -98,7 +98,7 @@ poly = vbslq_f32(hi_mask, inf, poly); \ return vbslq_f32(lo_mask, zero, poly); \ }; \ - auto fast_exp = [&](vec_op::FP32Vec16& vec) \ + auto fast_exp = [&](const vec_op::FP32Vec16& vec) \ __attribute__((always_inline)) { \ float32x4x4_t result; \ result.val[0] = neon_expf(vec.reg.val[0]); \ @@ -110,4 +110,4 @@ #endif // __aarch64__ -#endif \ No newline at end of file +#endif diff --git a/csrc/cpu/cpu_attn_impl.hpp b/csrc/cpu/cpu_attn_impl.hpp index e3e077b845f4f..08d208e05a62c 100644 --- a/csrc/cpu/cpu_attn_impl.hpp +++ b/csrc/cpu/cpu_attn_impl.hpp @@ -8,10 +8,8 @@ #include #endif -#include "cpu_types.hpp" -#include "scratchpad_manager.h" -#include "cpu_attn_macros.h" -#include "utils.hpp" +#include "cpu/cpu_arch_macros.h" +#include "cpu/utils.hpp" namespace cpu_attention { enum class ISA { AMX, VEC, VEC16, NEON }; @@ -378,12 +376,13 @@ class AttentionScheduler { static constexpr int32_t MaxQTileIterNum = 128; - AttentionScheduler() : available_cache_size_(get_available_l2_size()) {} + AttentionScheduler() + : available_cache_size_(cpu_utils::get_available_l2_size()) {} torch::Tensor schedule(const ScheduleInput& input) const { const bool casual = input.casual; const int32_t thread_num = omp_get_max_threads(); - const int64_t 
cache_size = get_available_l2_size(); + const int64_t cache_size = cpu_utils::get_available_l2_size(); const int32_t max_num_q_per_iter = input.max_num_q_per_iter; const int32_t kv_len_alignment = input.kv_block_alignment; int32_t q_head_per_kv = input.num_heads_q / input.num_heads_kv; @@ -659,7 +658,7 @@ class AttentionScheduler { metadata_ptr->thread_num + metadata_ptr->reduction_scratchpad_size_per_kv_head * (use_gqa ? input.num_heads_kv : input.num_heads_q); - DNNLScratchPadManager::get_dnnl_scratchpad_manager()->realloc( + cpu_utils::ScratchPadManager::get_scratchpad_manager()->realloc( scratchpad_size); // metadata_ptr->print(); @@ -667,7 +666,7 @@ class AttentionScheduler { // test out of boundary access // { // float* cache_ptr = - // DNNLScratchPadManager::get_dnnl_scratchpad_manager()->get_data(); + // cpu_utils::ScratchPadManager::getl_scratchpad_manager()->get_data(); // for (int64_t i = 0; i < scratchpad_size / sizeof(float); ++i) { // cache_ptr[i] = std::numeric_limits::quiet_NaN(); // } @@ -749,27 +748,6 @@ class AttentionScheduler { return std::max(rounded_tile_size, round_size); } - static int64_t get_available_l2_size() { - static int64_t size = []() { -#if defined(__APPLE__) - // macOS doesn't have _SC_LEVEL2_CACHE_SIZE. Use sysctlbyname. - int64_t l2_cache_size = 0; - size_t len = sizeof(l2_cache_size); - if (sysctlbyname("hw.l2cachesize", &l2_cache_size, &len, NULL, 0) == 0 && - l2_cache_size > 0) { - return l2_cache_size >> 1; // use 50% of L2 cache - } - // Fallback if sysctlbyname fails - return 128LL * 1024 >> 1; // use 50% of 128KB -#else - long l2_cache_size = sysconf(_SC_LEVEL2_CACHE_SIZE); - TORCH_CHECK_NE(l2_cache_size, -1); - return l2_cache_size >> 1; // use 50% of L2 cache -#endif - }(); - return size; - } - private: int64_t available_cache_size_; }; @@ -1402,7 +1380,7 @@ class AttentionMainLoop { // init buffers void* scratchpad_ptr = - DNNLScratchPadManager::get_dnnl_scratchpad_manager() + cpu_utils::ScratchPadManager::get_scratchpad_manager() ->get_data(); AttentionScratchPad buffer_manager(thread_id, metadata, scratchpad_ptr); @@ -1422,8 +1400,7 @@ class AttentionMainLoop { } } - const int64_t available_cache_size = - AttentionScheduler::get_available_l2_size(); + const int64_t available_cache_size = cpu_utils::get_available_l2_size(); const int32_t default_tile_size = AttentionScheduler::calcu_default_tile_size( available_cache_size, head_dim, sizeof(kv_cache_t), diff --git a/csrc/cpu/cpu_fused_moe.cpp b/csrc/cpu/cpu_fused_moe.cpp new file mode 100644 index 0000000000000..090e2d4cd4b56 --- /dev/null +++ b/csrc/cpu/cpu_fused_moe.cpp @@ -0,0 +1,727 @@ +#include "cpu/cpu_types.hpp" +#include "cpu/utils.hpp" +#include "cpu/micro_gemm/cpu_micro_gemm_vec.hpp" +#include "cpu/cpu_arch_macros.h" + +#ifdef CPU_CAPABILITY_AMXBF16 + #include "cpu/micro_gemm/cpu_micro_gemm_amx.hpp" + #define AMX_DISPATCH(...) \ + case cpu_utils::ISA::AMX: { \ + using gemm_t = cpu_micro_gemm::MicroGemm; \ + return __VA_ARGS__(); \ + } +#else + #define AMX_DISPATCH(...) case cpu_utils::ISA::AMX: +#endif + +#define CPU_ISA_DISPATCH_IMPL(ISA_TYPE, ...) 
\ + [&] { \ + switch (ISA_TYPE) { \ + AMX_DISPATCH(__VA_ARGS__) \ + case cpu_utils::ISA::VEC: { \ + using gemm_t = \ + cpu_micro_gemm::MicroGemm; \ + return __VA_ARGS__(); \ + } \ + default: { \ + TORCH_CHECK(false, "Invalid CPU ISA type."); \ + } \ + } \ + }() + +namespace { +enum class FusedMOEAct { SiluAndMul, SwigluOAIAndMul }; + +FusedMOEAct get_act_type(const std::string& act) { + if (act == "silu") { + return FusedMOEAct::SiluAndMul; + } else if (act == "swigluoai") { + return FusedMOEAct::SwigluOAIAndMul; + } else { + TORCH_CHECK(false, "Invalid act type: " + act); + } +} + +template +void swigluoai_and_mul(float* __restrict__ input, scalar_t* __restrict__ output, + const int32_t m_size, const int32_t n_size, + const int32_t input_stride, + const int32_t output_stride) { + using scalar_vec_t = typename cpu_utils::VecTypeTrait::vec_t; + // For GPT-OSS interleaved gate-up weights + alignas(64) static int32_t index[16] = {0, 2, 4, 6, 8, 10, 12, 14, + 16, 18, 20, 22, 24, 26, 28, 30}; + vec_op::INT32Vec16 index_vec(index); + vec_op::FP32Vec16 gate_up_max_vec(7.0); + vec_op::FP32Vec16 up_min_vec(-7.0); + vec_op::FP32Vec16 alpha_vec(1.702); + vec_op::FP32Vec16 one_vec(1.0); + + DEFINE_FAST_EXP + + for (int32_t m = 0; m < m_size; ++m) { + for (int32_t n = 0; n < n_size; n += 32) { + vec_op::FP32Vec16 gate_vec(input + n, index_vec); + vec_op::FP32Vec16 up_vec(input + n + 1, index_vec); + gate_vec = gate_vec.min(gate_up_max_vec); + up_vec = up_vec.clamp(up_min_vec, gate_up_max_vec); + auto sigmoid_vec = one_vec / (one_vec + fast_exp(-gate_vec * alpha_vec)); + auto glu = gate_vec * sigmoid_vec; + auto gated_output_fp32 = (one_vec + up_vec) * glu; + scalar_vec_t gated_output = scalar_vec_t(gated_output_fp32); + gated_output.save(output + n / 2); + } + input += input_stride; + output += output_stride; + } +} + +template +void silu_and_mul(float* __restrict__ input, scalar_t* __restrict__ output, + const int32_t m_size, const int32_t n_size, + const int32_t input_stride, const int32_t output_stride) { + using scalar_vec_t = typename cpu_utils::VecTypeTrait::vec_t; + const int32_t dim = n_size / 2; + float* __restrict__ gate = input; + float* __restrict__ up = input + dim; + vec_op::FP32Vec16 one_vec(1.0); + + DEFINE_FAST_EXP + + for (int32_t m = 0; m < m_size; ++m) { + for (int32_t n = 0; n < dim; n += 16) { + vec_op::FP32Vec16 gate_vec(gate + n); + vec_op::FP32Vec16 up_vec(up + n); + auto sigmoid_vec = one_vec / (one_vec + fast_exp(-gate_vec)); + auto silu = gate_vec * sigmoid_vec; + auto gated_output_fp32 = up_vec * silu; + scalar_vec_t gated_output = scalar_vec_t(gated_output_fp32); + gated_output.save(output + n); + } + gate += input_stride; + up += input_stride; + output += output_stride; + } +} + +template +FORCE_INLINE void apply_gated_act(const FusedMOEAct act, + float* __restrict__ input, + scalar_t* __restrict__ output, + const int32_t m, const int32_t n, + const int32_t input_stride, + const int32_t output_stride) { + switch (act) { + case FusedMOEAct::SwigluOAIAndMul: + swigluoai_and_mul(input, output, m, n, input_stride, output_stride); + return; + case FusedMOEAct::SiluAndMul: + silu_and_mul(input, output, m, n, input_stride, output_stride); + return; + default: + TORCH_CHECK(false, "Unsupported act type."); + } +} + +template +void prepack_moe_weight_impl(scalar_t* __restrict__ weight_ptr, + scalar_t* __restrict__ packed_weight_ptr, + const int32_t expert_num, + const int32_t output_size, + const int32_t input_size, + const int64_t expert_stride) { +#pragma omp parallel for + for 
(int32_t e_idx = 0; e_idx < expert_num; ++e_idx) { + gemm_t::pack_weight(weight_ptr + expert_stride * e_idx, + packed_weight_ptr + expert_stride * e_idx, output_size, + input_size); + } +} + +template +void fused_moe_impl(scalar_t* __restrict__ output, scalar_t* __restrict__ input, + w_t* __restrict__ w13, w_t* __restrict__ w2, + w_t* __restrict__ w13_bias, w_t* __restrict__ w2_bias, + float* __restrict__ topk_weights, + int32_t* __restrict__ topk_id, FusedMOEAct act_type, + const int32_t token_num, const int32_t expert_num, + const int32_t topk_num, const int32_t input_size_13, + const int32_t output_size_13, const int32_t input_size_2, + const int32_t output_size_2) { + using scalar_vec_t = typename cpu_utils::VecTypeTrait::vec_t; + constexpr int32_t gemm_n_tile_size = gemm_t::NSize; + constexpr int32_t gemm_m_tile_size = gemm_t::MaxMSize; + constexpr int32_t min_w13_n_tile_size = 2 * gemm_n_tile_size; + static_assert(gemm_n_tile_size % 16 == 0); + + TORCH_CHECK_EQ(output_size_13 % min_w13_n_tile_size, 0); + TORCH_CHECK_EQ(output_size_2 % gemm_n_tile_size, 0); + TORCH_CHECK_EQ(output_size_13 / 2, input_size_2); + + const int32_t thread_num = omp_get_max_threads(); + + const int32_t w13_input_buffer_size = cpu_utils::round_up<64>( + gemm_m_tile_size * input_size_13 * sizeof(scalar_t)); + + const int32_t w13_n_tile_size = [&]() { + const int64_t cache_size = cpu_utils::get_available_l2_size(); + // input buffer + output buffer + weight + const int32_t n_size_cache_limit = + (cache_size - w13_input_buffer_size) / + (gemm_m_tile_size * sizeof(float) + input_size_13 * sizeof(scalar_t)); + const int32_t n_size_thread_limit = + output_size_13 / std::max(1, thread_num / topk_num); + const int32_t n_size = cpu_utils::round_down( + std::min(n_size_cache_limit, n_size_thread_limit)); + return std::max(n_size, min_w13_n_tile_size); + }(); + + const int32_t w2_input_tile_size = cpu_utils::round_up<64>( + gemm_m_tile_size * input_size_2 * sizeof(scalar_t)); + + const int32_t w2_n_tile_size = [&]() { + const int64_t cache_size = cpu_utils::get_available_l2_size(); + // input tile + weight + const int32_t n_size_cache_limit = + (cache_size - w2_input_tile_size) / (input_size_2 * sizeof(scalar_t)); + const int32_t n_size_thread_limit = + output_size_2 / std::max(1, thread_num / topk_num); + const int32_t n_size = cpu_utils::round_down( + std::min(n_size_cache_limit, n_size_thread_limit)); + return std::max(n_size, gemm_n_tile_size); + }(); + + // allocate buffers + int32_t common_buffer_offset = 0; + int32_t w13_thread_buffer_offset = 0; + int32_t ws_thread_buffer_offset = 0; + + // common buffers + const int32_t token_num_per_group_buffer_size = + cpu_utils::round_up<64>(expert_num * sizeof(int32_t)); + const int32_t token_num_per_group_buffer_offset = common_buffer_offset; + common_buffer_offset += token_num_per_group_buffer_size; + + const int32_t cu_token_num_per_group_buffer_size = + cpu_utils::round_up<64>((expert_num + 1) * sizeof(int32_t)); + const int32_t cu_token_num_per_group_buffer_offset = common_buffer_offset; + common_buffer_offset += cu_token_num_per_group_buffer_size; + + const int32_t expand_token_id_buffer_size = + cpu_utils::round_up<64>(token_num * topk_num * sizeof(int32_t)); + const int32_t expand_token_id_buffer_offset = common_buffer_offset; + common_buffer_offset += expand_token_id_buffer_size; + + const int32_t expand_token_id_index_buffer_size = + cpu_utils::round_up<64>(token_num * topk_num * sizeof(int32_t)); + const int32_t expand_token_id_index_buffer_offset = 
common_buffer_offset; + common_buffer_offset += expand_token_id_index_buffer_size; + + const int32_t w13_gemm_output_buffer_size = cpu_utils::round_up<64>( + token_num * topk_num * (output_size_13 / 2) * sizeof(scalar_t)); + const int32_t w13_gemm_output_buffer_offset = common_buffer_offset; + common_buffer_offset += w13_gemm_output_buffer_size; + + const int32_t w2_gemm_output_buffer_size = cpu_utils::round_up<64>( + token_num * topk_num * output_size_2 * sizeof(float)); + const int32_t w2_gemm_output_buffer_offset = common_buffer_offset; + common_buffer_offset += w2_gemm_output_buffer_size; + + // w13 GEMM thread buffers + const int32_t w13_input_buffer_offset = w13_thread_buffer_offset; + w13_thread_buffer_offset += w13_input_buffer_size; + + const int32_t w13_output_buffer_size = cpu_utils::round_up<64>( + gemm_m_tile_size * w13_n_tile_size * sizeof(float)); + const int32_t w13_output_buffer_offset = w13_thread_buffer_offset; + w13_thread_buffer_offset += w13_output_buffer_size; + + // Weighted sum thread buffer + const int32_t ws_output_buffer_size = + cpu_utils::round_up<64>(output_size_2 * sizeof(float)); + const int32_t ws_output_buffer_offset = ws_thread_buffer_offset; + ws_thread_buffer_offset += ws_output_buffer_size; + + const int32_t buffer_size = + common_buffer_offset + + std::max(w13_thread_buffer_offset, ws_thread_buffer_offset) * thread_num; + cpu_utils::ScratchPadManager::get_scratchpad_manager()->realloc(buffer_size); + uint8_t* common_buffer_start = + cpu_utils::ScratchPadManager::get_scratchpad_manager() + ->get_data(); + uint8_t* thread_buffer_start = common_buffer_start + common_buffer_offset; + + int32_t* __restrict__ token_num_per_group_buffer = reinterpret_cast( + common_buffer_start + token_num_per_group_buffer_offset); + int32_t* __restrict__ cu_token_num_per_group_buffer = + reinterpret_cast(common_buffer_start + + cu_token_num_per_group_buffer_offset); + int32_t* __restrict__ expand_token_id_buffer = reinterpret_cast( + common_buffer_start + expand_token_id_buffer_offset); + int32_t* __restrict__ expand_token_id_index_buffer = + reinterpret_cast(common_buffer_start + + expand_token_id_index_buffer_offset); + + // prepare token-expert mappings + { + std::memset(token_num_per_group_buffer, 0, expert_num * sizeof(int32_t)); + for (int32_t i = 0; i < token_num * topk_num; ++i) { + int32_t curr_expert_id = topk_id[i]; + ++token_num_per_group_buffer[curr_expert_id]; + } + + int32_t token_num_sum = 0; + cu_token_num_per_group_buffer[0] = 0; + int32_t* token_index_buffer = cu_token_num_per_group_buffer + 1; + for (int32_t i = 0; i < expert_num; ++i) { + token_index_buffer[i] = token_num_sum; + token_num_sum += token_num_per_group_buffer[i]; + } + + for (int32_t i = 0; i < token_num; ++i) { + int32_t* curr_topk_id = topk_id + i * topk_num; + int32_t* curr_index_buffer = expand_token_id_index_buffer + i * topk_num; + for (int32_t j = 0; j < topk_num; ++j) { + int32_t curr_expert_id = curr_topk_id[j]; + int32_t curr_index = token_index_buffer[curr_expert_id]; + ++token_index_buffer[curr_expert_id]; + expand_token_id_buffer[curr_index] = i; + curr_index_buffer[j] = curr_index; + } + } + } + + // w13 GEMM + act + { + alignas(64) cpu_utils::Counter counter; + cpu_utils::Counter* counter_ptr = &counter; + +#pragma omp parallel for schedule(static, 1) + for (int32_t thread_id = 0; thread_id < thread_num; ++thread_id) { + const int32_t task_num_per_expert = + (output_size_13 + w13_n_tile_size - 1) / w13_n_tile_size; + const int32_t task_num = task_num_per_expert * expert_num; 
+ + uint8_t* __restrict__ thread_buffer = + thread_buffer_start + thread_id * w13_thread_buffer_offset; + scalar_t* __restrict__ w13_input_buffer = + reinterpret_cast(thread_buffer + w13_input_buffer_offset); + float* __restrict__ w13_output_buffer = + reinterpret_cast(thread_buffer + w13_output_buffer_offset); + scalar_t* __restrict__ w13_gemm_output_buffer = + reinterpret_cast(common_buffer_start + + w13_gemm_output_buffer_offset); + + gemm_t gemm; + + const int32_t input_size_13_bytes = input_size_13 * sizeof(scalar_t); + const int32_t w13_n_group_stride = 16 * input_size_13; + const int32_t w13_n_tile_stride = gemm_n_tile_size * input_size_13; + + for (;;) { + int32_t task_id = counter_ptr->acquire_counter(); + if (task_id >= task_num) { + break; + } + + const int32_t curr_expert_id = task_id / task_num_per_expert; + const int32_t curr_output_group_id = task_id % task_num_per_expert; + const int32_t curr_token_num = + token_num_per_group_buffer[curr_expert_id]; + if (curr_token_num == 0) { + continue; + } + + const int32_t actual_n_tile_size = + std::min(w13_n_tile_size, + output_size_13 - curr_output_group_id * w13_n_tile_size); + const int32_t* __restrict__ curr_expand_token_id_buffer = + expand_token_id_buffer + + cu_token_num_per_group_buffer[curr_expert_id]; + scalar_t* __restrict__ curr_w13_gemm_output_buffer = + w13_gemm_output_buffer + + cu_token_num_per_group_buffer[curr_expert_id] * + (output_size_13 / 2) + + curr_output_group_id * w13_n_tile_size / 2; + + w_t* __restrict__ w13_weight_ptr_0 = nullptr; + w_t* __restrict__ w13_weight_ptr_1 = nullptr; + w_t* __restrict__ w13_bias_ptr_0 = nullptr; + w_t* __restrict__ w13_bias_ptr_1 = nullptr; + if (act_type == FusedMOEAct::SwigluOAIAndMul) { + // For SwigluOAIAndMul, up and down weights are interleaved + w13_weight_ptr_0 = + w13 + curr_expert_id * input_size_13 * output_size_13 + + curr_output_group_id * w13_n_tile_size * input_size_13; + w13_weight_ptr_1 = + w13_weight_ptr_0 + actual_n_tile_size / 2 * input_size_13; + if (w13_bias != nullptr) { + w13_bias_ptr_0 = w13_bias + curr_expert_id * output_size_13 + + curr_output_group_id * w13_n_tile_size; + w13_bias_ptr_1 = w13_bias_ptr_0 + actual_n_tile_size / 2; + } + } else { + w13_weight_ptr_0 = + w13 + curr_expert_id * input_size_13 * output_size_13 + + curr_output_group_id * (w13_n_tile_size / 2) * input_size_13; + w13_weight_ptr_1 = + w13_weight_ptr_0 + output_size_13 / 2 * input_size_13; + if (w13_bias != nullptr) { + w13_bias_ptr_0 = w13_bias + curr_expert_id * output_size_13 + + curr_output_group_id * (w13_n_tile_size / 2); + w13_bias_ptr_1 = w13_bias_ptr_0 + output_size_13 / 2; + } + } + + scalar_t* __restrict__ curr_w13_input_buffer = w13_input_buffer; + for (int32_t token_idx = 0; token_idx < curr_token_num; + token_idx += gemm_m_tile_size) { + const int32_t actual_token_num = + std::min(gemm_m_tile_size, curr_token_num - token_idx); + // copy inputs + { + scalar_t* __restrict__ curr_w13_input_buffer_iter = + curr_w13_input_buffer; + for (int32_t i = 0; i < actual_token_num; ++i) { + const int32_t curr_token_id = curr_expand_token_id_buffer[i]; + int8_t* __restrict__ curr_input_iter = reinterpret_cast( + input + curr_token_id * input_size_13); + int8_t* __restrict__ curr_output_iter = + reinterpret_cast(curr_w13_input_buffer_iter); + int32_t j = 0; + for (; j < input_size_13_bytes - 64; j += 64) { + vec_op::INT8Vec64 vec(curr_input_iter); + vec.save(curr_output_iter); + curr_input_iter += 64; + curr_output_iter += 64; + } + vec_op::INT8Vec64 vec(curr_input_iter); + 
vec.save(curr_output_iter, input_size_13_bytes - j); + + // update + curr_w13_input_buffer_iter += input_size_13; + } + // update + curr_expand_token_id_buffer += actual_token_num; + } + + // gemm + act + { + scalar_t* __restrict__ w13_weight_ptr_0_iter = w13_weight_ptr_0; + scalar_t* __restrict__ w13_weight_ptr_1_iter = w13_weight_ptr_1; + scalar_t* __restrict__ w13_bias_ptr_0_iter = w13_bias_ptr_0; + scalar_t* __restrict__ w13_bias_ptr_1_iter = w13_bias_ptr_1; + scalar_t* __restrict__ curr_w13_input_buffer_iter = + curr_w13_input_buffer; + float* __restrict__ w13_output_buffer_0_iter = w13_output_buffer; + float* __restrict__ w13_output_buffer_1_iter = + w13_output_buffer + actual_n_tile_size / 2; + for (int32_t i = 0; i < actual_n_tile_size; + i += min_w13_n_tile_size) { + gemm.gemm(curr_w13_input_buffer_iter, w13_weight_ptr_0_iter, + w13_output_buffer_0_iter, actual_token_num, + input_size_13, input_size_13, w13_n_group_stride, + actual_n_tile_size, false); + + if (w13_bias != nullptr) { + cpu_micro_gemm::add_bias_epilogue( + w13_output_buffer_0_iter, w13_output_buffer_0_iter, + w13_bias_ptr_0_iter, actual_token_num, actual_n_tile_size, + actual_n_tile_size); + w13_bias_ptr_0_iter += gemm_n_tile_size; + } + + gemm.gemm(curr_w13_input_buffer_iter, w13_weight_ptr_1_iter, + w13_output_buffer_1_iter, actual_token_num, + input_size_13, input_size_13, w13_n_group_stride, + actual_n_tile_size, false); + + if (w13_bias != nullptr) { + cpu_micro_gemm::add_bias_epilogue( + w13_output_buffer_1_iter, w13_output_buffer_1_iter, + w13_bias_ptr_1_iter, actual_token_num, actual_n_tile_size, + actual_n_tile_size); + w13_bias_ptr_1_iter += gemm_n_tile_size; + } + + // update + w13_weight_ptr_0_iter += w13_n_tile_stride; + w13_weight_ptr_1_iter += w13_n_tile_stride; + w13_output_buffer_0_iter += gemm_n_tile_size; + w13_output_buffer_1_iter += gemm_n_tile_size; + } + + apply_gated_act(act_type, w13_output_buffer, + curr_w13_gemm_output_buffer, actual_token_num, + actual_n_tile_size, actual_n_tile_size, + output_size_13 / 2); + + // update + curr_w13_gemm_output_buffer += + gemm_m_tile_size * (output_size_13 / 2); + } + } + } + } + } + + // w2 GEMM + { + alignas(64) cpu_utils::Counter counter; + cpu_utils::Counter* counter_ptr = &counter; + +#pragma omp parallel for schedule(static, 1) + for (int32_t thread_id = 0; thread_id < thread_num; ++thread_id) { + const int32_t task_num_per_expert = + (output_size_2 + w2_n_tile_size - 1) / w2_n_tile_size; + const int32_t task_num = task_num_per_expert * expert_num; + scalar_t* __restrict__ w13_gemm_output_buffer = + reinterpret_cast(common_buffer_start + + w13_gemm_output_buffer_offset); + float* __restrict__ w2_gemm_output_buffer = reinterpret_cast( + common_buffer_start + w2_gemm_output_buffer_offset); + + gemm_t gemm; + + const int32_t w2_n_tile_stride = gemm_n_tile_size * input_size_2; + const int32_t w2_n_group_stride = 16 * input_size_2; + + for (;;) { + int32_t task_id = counter_ptr->acquire_counter(); + if (task_id >= task_num) { + break; + } + + const int32_t curr_expert_id = task_id / task_num_per_expert; + const int32_t curr_output_group_id = task_id % task_num_per_expert; + const int32_t curr_token_num = + token_num_per_group_buffer[curr_expert_id]; + if (curr_token_num == 0) { + continue; + } + + const int32_t actual_n_tile_size = + std::min(w2_n_tile_size, + output_size_2 - curr_output_group_id * w2_n_tile_size); + scalar_t* __restrict__ curr_w13_gemm_output_buffer = + w13_gemm_output_buffer + + cu_token_num_per_group_buffer[curr_expert_id] * 
input_size_2; + float* __restrict__ curr_w2_gemm_output_buffer = + w2_gemm_output_buffer + + cu_token_num_per_group_buffer[curr_expert_id] * output_size_2 + + curr_output_group_id * w2_n_tile_size; + scalar_t* __restrict__ w2_weight_ptr = + w2 + curr_expert_id * output_size_2 * input_size_2 + + curr_output_group_id * w2_n_tile_size * input_size_2; + scalar_t* __restrict__ w2_bias_ptr = nullptr; + if (w2_bias != nullptr) { + w2_bias_ptr = w2_bias + curr_expert_id * output_size_2 + + curr_output_group_id * w2_n_tile_size; + } + + for (int32_t token_idx = 0; token_idx < curr_token_num; + token_idx += gemm_m_tile_size) { + const int32_t actual_token_num = + std::min(gemm_m_tile_size, curr_token_num - token_idx); + + scalar_t* __restrict__ w2_weight_ptr_iter = w2_weight_ptr; + scalar_t* __restrict__ w2_bias_ptr_iter = w2_bias_ptr; + float* __restrict__ curr_w2_gemm_output_buffer_iter = + curr_w2_gemm_output_buffer; + for (int32_t i = 0; i < actual_n_tile_size; i += gemm_n_tile_size) { + gemm.gemm(curr_w13_gemm_output_buffer, w2_weight_ptr_iter, + curr_w2_gemm_output_buffer_iter, actual_token_num, + input_size_2, input_size_2, w2_n_group_stride, + output_size_2, false); + + if (w2_bias != nullptr) { + cpu_micro_gemm::add_bias_epilogue( + curr_w2_gemm_output_buffer_iter, + curr_w2_gemm_output_buffer_iter, w2_bias_ptr_iter, + actual_token_num, output_size_2, output_size_2); + w2_bias_ptr_iter += gemm_n_tile_size; + } + + w2_weight_ptr_iter += w2_n_tile_stride; + curr_w2_gemm_output_buffer_iter += gemm_n_tile_size; + } + + // update + curr_w13_gemm_output_buffer += gemm_m_tile_size * input_size_2; + curr_w2_gemm_output_buffer += gemm_m_tile_size * output_size_2; + } + } + } + } + + // weighted sum + { + alignas(64) cpu_utils::Counter counter; + cpu_utils::Counter* counter_ptr = &counter; + +#pragma omp parallel for schedule(static, 1) + for (int32_t thread_id = 0; thread_id < thread_num; ++thread_id) { + const int32_t task_num = token_num; + uint8_t* __restrict__ thread_buffer = + thread_buffer_start + thread_id * ws_thread_buffer_offset; + float* __restrict__ ws_output_buffer = + reinterpret_cast(thread_buffer + ws_output_buffer_offset); + float* __restrict__ w2_gemm_output_buffer = reinterpret_cast( + common_buffer_start + w2_gemm_output_buffer_offset); + + for (;;) { + int32_t task_id = counter_ptr->acquire_counter(); + if (task_id >= task_num) { + break; + } + + int32_t token_id = task_id; + int32_t* __restrict__ curr_expand_token_id_index_buffer = + expand_token_id_index_buffer + token_id * topk_num; + float* __restrict__ curr_weight = topk_weights + token_id * topk_num; + scalar_t* __restrict__ curr_output_buffer = + output + token_id * output_size_2; + + if (topk_num > 1) { + { + int32_t w2_output_idx = curr_expand_token_id_index_buffer[0]; + float* __restrict__ w2_output_iter = + w2_gemm_output_buffer + w2_output_idx * output_size_2; + float* __restrict__ ws_output_buffer_iter = ws_output_buffer; + vec_op::FP32Vec16 weight_vec(curr_weight[0]); + for (int32_t i = 0; i < output_size_2; i += 16) { + vec_op::FP32Vec16 vec(w2_output_iter); + vec = vec * weight_vec; + vec.save(ws_output_buffer_iter); + + // update + w2_output_iter += 16; + ws_output_buffer_iter += 16; + } + } + + { + for (int32_t idx = 1; idx < topk_num - 1; ++idx) { + int32_t w2_output_idx = curr_expand_token_id_index_buffer[idx]; + float* __restrict__ w2_output_iter = + w2_gemm_output_buffer + w2_output_idx * output_size_2; + float* __restrict__ ws_output_buffer_iter = ws_output_buffer; + vec_op::FP32Vec16 
weight_vec(curr_weight[idx]); + for (int32_t i = 0; i < output_size_2; i += 16) { + vec_op::FP32Vec16 vec(w2_output_iter); + vec_op::FP32Vec16 sum(ws_output_buffer_iter); + sum = sum + vec * weight_vec; + sum.save(ws_output_buffer_iter); + + // update + w2_output_iter += 16; + ws_output_buffer_iter += 16; + } + } + } + + { + int32_t idx = topk_num - 1; + int32_t w2_output_idx = curr_expand_token_id_index_buffer[idx]; + float* __restrict__ w2_output_iter = + w2_gemm_output_buffer + w2_output_idx * output_size_2; + float* __restrict__ ws_output_buffer_iter = ws_output_buffer; + scalar_t* __restrict__ curr_output_buffer_iter = curr_output_buffer; + vec_op::FP32Vec16 weight_vec(curr_weight[idx]); + for (int32_t i = 0; i < output_size_2; i += 16) { + vec_op::FP32Vec16 vec(w2_output_iter); + vec_op::FP32Vec16 sum(ws_output_buffer_iter); + sum = sum + vec * weight_vec; + scalar_vec_t out_vec(sum); + out_vec.save(curr_output_buffer_iter); + + // update + w2_output_iter += 16; + ws_output_buffer_iter += 16; + curr_output_buffer_iter += 16; + } + } + } else { + int32_t w2_output_idx = curr_expand_token_id_index_buffer[0]; + float* __restrict__ w2_output_iter = + w2_gemm_output_buffer + w2_output_idx * output_size_2; + scalar_t* __restrict__ curr_output_buffer_iter = curr_output_buffer; + vec_op::FP32Vec16 weight_vec(curr_weight[0]); + for (int32_t i = 0; i < output_size_2; i += 16) { + vec_op::FP32Vec16 vec(w2_output_iter); + vec = vec * weight_vec; + scalar_vec_t out_vec(vec); + out_vec.save(curr_output_buffer_iter); + + // update + w2_output_iter += 16; + curr_output_buffer_iter += 16; + } + } + } + } + } +} +} // namespace + +void prepack_moe_weight( + const torch::Tensor& weight, // [expert_num, output_size, input_size] + torch::Tensor& packed_weight, const std::string& isa) { + TORCH_CHECK(weight.is_contiguous()); + const int32_t expert_num = weight.size(0); + const int32_t output_size = weight.size(1); + const int32_t input_size = weight.size(2); + TORCH_CHECK_EQ(output_size % 32, 0); + const int64_t expert_stride = weight.stride(0); + cpu_utils::ISA isa_type = cpu_utils::get_isa(isa); + + VLLM_DISPATCH_FLOATING_TYPES( + weight.scalar_type(), "prepack_moe_weight", [&]() { + CPU_ISA_DISPATCH_IMPL(isa_type, [&]() { + scalar_t* weight_ptr = weight.data_ptr(); + scalar_t* packed_weight_ptr = packed_weight.data_ptr(); + prepack_moe_weight_impl( + weight_ptr, packed_weight_ptr, expert_num, output_size, + input_size, expert_stride); + }); + }); +} + +void cpu_fused_moe( + torch::Tensor& output, // [token_num, output_size_2] + const torch::Tensor& input, // [token_num, input_size_13] + const torch::Tensor& + w13, // [expert_num, output_size_13, input_size_13], packed + const torch::Tensor& + w2, // [expert_num, output_size_2, input_size_2], packed + const std::optional& + w13_bias, // [expert_num, output_size_13] + const std::optional& w2_bias, // [expert_num, output_size_2] + const torch::Tensor& topk_weights, // [token_num, k], float32 + const torch::Tensor& topk_id, // [token_num, k], int32 + const std::string& act, const std::string& isa) { + const int32_t token_num = input.size(0); + const int32_t input_size_13 = input.size(1); + const int64_t input_stride = input.stride(0); + TORCH_CHECK_EQ(input_stride, input_size_13); + const int32_t expert_num = w13.size(0); + const int32_t output_size_13 = w13.size(1); + const int32_t input_size_2 = w2.size(2); + const int32_t output_size_2 = w2.size(1); + const int32_t topk_num = topk_id.size(1); + const FusedMOEAct act_type = get_act_type(act); + 
cpu_utils::ISA isa_type = cpu_utils::get_isa(isa); + + VLLM_DISPATCH_FLOATING_TYPES(w13.scalar_type(), "cpu_fused_moe", [&]() { + CPU_ISA_DISPATCH_IMPL(isa_type, [&]() { + fused_moe_impl( + output.data_ptr(), input.data_ptr(), + w13.data_ptr(), w2.data_ptr(), + w13_bias.has_value() ? w13_bias->data_ptr() : nullptr, + w2_bias.has_value() ? w2_bias->data_ptr() : nullptr, + topk_weights.data_ptr(), topk_id.data_ptr(), act_type, + token_num, expert_num, topk_num, input_size_13, output_size_13, + input_size_2, output_size_2); + }); + }); +} diff --git a/csrc/cpu/cpu_types_x86.hpp b/csrc/cpu/cpu_types_x86.hpp index 6f51277f78440..d94af338ac1c9 100644 --- a/csrc/cpu/cpu_types_x86.hpp +++ b/csrc/cpu/cpu_types_x86.hpp @@ -352,6 +352,10 @@ struct FP32Vec16 : public Vec { explicit FP32Vec16(bool, void* ptr) : reg((__m512)_mm512_stream_load_si512(ptr)) {} + // strided load + explicit FP32Vec16(const float* ptr, INT32Vec16 idx) + : reg(_mm512_i32gather_ps(idx.reg, ptr, 4)) {} + explicit FP32Vec16(__m512 data) : reg(data) {} // de-pack 4 bit values @@ -408,6 +412,10 @@ struct FP32Vec16 : public Vec { return FP32Vec16(_mm512_sub_ps(reg, b.reg)); } + FP32Vec16 operator-() const { + return FP32Vec16(_mm512_xor_ps(reg, _mm512_set1_ps(-0.0f))); + } + FP32Vec16 operator/(const FP32Vec16& b) const { return FP32Vec16(_mm512_div_ps(reg, b.reg)); } diff --git a/csrc/cpu/cpu_wna16.cpp b/csrc/cpu/cpu_wna16.cpp index 816d195506e52..88d48f3db8772 100644 --- a/csrc/cpu/cpu_wna16.cpp +++ b/csrc/cpu/cpu_wna16.cpp @@ -1,6 +1,5 @@ -#include "cpu_types.hpp" -#include "scratchpad_manager.h" -#include "utils.hpp" +#include "cpu/cpu_types.hpp" +#include "cpu/utils.hpp" #ifdef CPU_CAPABILITY_AMXBF16 #include "cpu/micro_gemm/cpu_micro_gemm_amx.hpp" @@ -158,7 +157,7 @@ void cpu_gemm_wna16_impl( // a simple schedule policy, just to hold more B tiles in L2 and make sure // each thread has tasks const int32_t n_partition_size = [&]() { - const int64_t cache_size = cpu_utils::get_l2_size(); + const int64_t cache_size = cpu_utils::get_available_l2_size(); int64_t ps_cache_limit = cache_size / (k_size * sizeof(scalar_t)); int64_t ps_thread_limit = n_size / thread_num; ps_cache_limit = @@ -179,8 +178,8 @@ void cpu_gemm_wna16_impl( const int64_t b_buffer_offset = 0; const int64_t c_buffer_offset = b_buffer_size; const int64_t buffer_size = b_buffer_size + c_buffer_size; - DNNLScratchPadManager::get_dnnl_scratchpad_manager()->realloc(buffer_size * - thread_num); + cpu_utils::ScratchPadManager::get_scratchpad_manager()->realloc(buffer_size * + thread_num); alignas(64) cpu_utils::Counter counter; cpu_utils::Counter* counter_ptr = &counter; @@ -190,9 +189,10 @@ void cpu_gemm_wna16_impl( scalar_t* __restrict__ b_buffer = nullptr; float* __restrict__ c_buffer = nullptr; { - uint8_t* buffer_ptr = DNNLScratchPadManager::get_dnnl_scratchpad_manager() - ->get_data() + - thread_id * buffer_size; + uint8_t* buffer_ptr = + cpu_utils::ScratchPadManager::get_scratchpad_manager() + ->get_data() + + thread_id * buffer_size; b_buffer = reinterpret_cast(buffer_ptr + b_buffer_offset); c_buffer = reinterpret_cast(buffer_ptr + c_buffer_offset); } diff --git a/csrc/cpu/dnnl_helper.cpp b/csrc/cpu/dnnl_helper.cpp index cfb6e78cba9a1..e337e10e1cf7b 100644 --- a/csrc/cpu/dnnl_helper.cpp +++ b/csrc/cpu/dnnl_helper.cpp @@ -4,8 +4,8 @@ #include "common/memory_desc.hpp" #include "common/memory.hpp" -#include "dnnl_helper.h" -#include "scratchpad_manager.h" +#include "cpu/utils.hpp" +#include "cpu/dnnl_helper.h" static dnnl::engine& default_engine() { static 
dnnl::engine engine(dnnl::engine::kind::cpu, 0); @@ -274,7 +274,7 @@ void W8A8MatMulPrimitiveHandler::execute(ExecArgs& args) { auto&& [scratchpad_storage, scratchpad_mem_desc] = get_runtime_memory_ptr(5); scratchpad_storage->set_data_handle( - DNNLScratchPadManager::get_dnnl_scratchpad_manager()->get_data()); + cpu_utils::ScratchPadManager::get_scratchpad_manager()->get_data()); matmul.execute(default_stream(), memory_cache_); default_stream().wait(); @@ -294,7 +294,7 @@ dnnl::matmul W8A8MatMulPrimitiveHandler::get_matmul_cache( return m_size_cache_->get_or_create(key, [&]() { dnnl::matmul::primitive_desc desc = this->create_primitive_desc(key, false); - auto manager = DNNLScratchPadManager::get_dnnl_scratchpad_manager(); + auto manager = cpu_utils::ScratchPadManager::get_scratchpad_manager(); manager->realloc(desc.scratchpad_desc().get_size()); return dnnl::matmul(desc); }); @@ -470,7 +470,7 @@ void MatMulPrimitiveHandler::execute(ExecArgs& args) { auto&& [scratchpad_storage, scratchpad_mem_desc] = get_runtime_memory_ptr(3); scratchpad_storage->set_data_handle( - DNNLScratchPadManager::get_dnnl_scratchpad_manager()->get_data()); + cpu_utils::ScratchPadManager::get_scratchpad_manager()->get_data()); matmul.execute(default_stream(), memory_cache_); default_stream().wait(); @@ -486,7 +486,7 @@ dnnl::matmul MatMulPrimitiveHandler::get_matmul_cache( } return m_size_cache_->get_or_create(key, [&]() { dnnl::matmul::primitive_desc desc = this->create_primitive_desc(key, false); - auto manager = DNNLScratchPadManager::get_dnnl_scratchpad_manager(); + auto manager = cpu_utils::ScratchPadManager::get_scratchpad_manager(); manager->realloc(desc.scratchpad_desc().get_size()); return dnnl::matmul(desc); }); diff --git a/csrc/cpu/micro_gemm/cpu_micro_gemm_amx.hpp b/csrc/cpu/micro_gemm/cpu_micro_gemm_amx.hpp index 87a019773a895..357c7cf1d7844 100644 --- a/csrc/cpu/micro_gemm/cpu_micro_gemm_amx.hpp +++ b/csrc/cpu/micro_gemm/cpu_micro_gemm_amx.hpp @@ -235,6 +235,39 @@ class MicroGemm { } } + static void pack_weight(const scalar_t* __restrict__ weight, + scalar_t* __restrict__ packed_weight, + const int32_t output_size, const int32_t input_size) { + constexpr int32_t elem_num_per_group = 4 / sizeof(scalar_t); + TORCH_CHECK_EQ(output_size % 16, 0); + TORCH_CHECK_EQ(input_size % (16 * elem_num_per_group), 0); + + const int32_t output_group_num = output_size / 16; + const int32_t input_32b_num = input_size / elem_num_per_group; + for (int32_t output_group_idx = 0; output_group_idx < output_group_num; + ++output_group_idx) { + const int32_t* __restrict__ weight_32b = + reinterpret_cast(weight); + int32_t* __restrict__ packed_weight_32b = + reinterpret_cast(packed_weight); + for (int32_t output_idx = 0; output_idx < 16; ++output_idx) { + for (int32_t weight_offset = 0, packed_offset = 0; + weight_offset < input_32b_num; + ++weight_offset, packed_offset += 16) { + packed_weight_32b[packed_offset] = weight_32b[weight_offset]; + } + + // update + weight_32b += input_32b_num; + packed_weight_32b += 1; + } + + // update + weight += 16 * input_size; + packed_weight += 16 * input_size; + } + } + private: alignas(64) __tilecfg amx_tile_config_; int32_t curr_m_; diff --git a/csrc/cpu/micro_gemm/cpu_micro_gemm_impl.hpp b/csrc/cpu/micro_gemm/cpu_micro_gemm_impl.hpp index 784da55a420e5..23e78a681b5fe 100644 --- a/csrc/cpu/micro_gemm/cpu_micro_gemm_impl.hpp +++ b/csrc/cpu/micro_gemm/cpu_micro_gemm_impl.hpp @@ -13,6 +13,9 @@ namespace cpu_micro_gemm { #define CPU_MICRO_GEMM_PARAMS \ a_ptr, b_ptr, c_ptr, m, k, lda, 
b_n_group_stride, ldc, accum_c +// Note: weights for MicroGemm should be packed as (output_size / 16) contiguous +// blocks, means the logical shape of blocks is [16, input_size]. And the actual +// layout of blocks can be ISA-specific. template class MicroGemm { public: @@ -86,6 +89,41 @@ FORCE_INLINE void bias_epilogue(float* __restrict__ c_ptr, curr_d += ldd; } } + +template +FORCE_INLINE void add_bias_epilogue(float* c_ptr, float* d_ptr, + scalar_t* __restrict__ bias_ptr, + const int32_t m, const int64_t ldc, + const int64_t ldd) { + using scalar_vec_t = typename cpu_utils::VecTypeTrait::vec_t; + static_assert(n_size % 16 == 0); + constexpr int32_t n_group_num = n_size / 16; + static_assert(n_group_num <= 16); + + vec_op::FP32Vec16 bias_vecs[n_group_num]; + scalar_t* __restrict__ curr_bias = bias_ptr; + vec_op::unroll_loop([&](int32_t i) { + scalar_vec_t vec(curr_bias); + bias_vecs[i] = vec_op::FP32Vec16(vec); + curr_bias += 16; + }); + + float* curr_c = c_ptr; + float* curr_d = d_ptr; + for (int32_t i = 0; i < m; ++i) { + float* curr_c_iter = curr_c; + float* curr_d_iter = curr_d; + vec_op::unroll_loop([&](int32_t n_g_idx) { + vec_op::FP32Vec16 c_vec_fp32(curr_c_iter); + c_vec_fp32 = c_vec_fp32 + bias_vecs[n_g_idx]; + c_vec_fp32.save(curr_d_iter); + curr_c_iter += 16; + curr_d_iter += 16; + }); + curr_c += ldc; + curr_d += ldd; + } +} } // namespace cpu_micro_gemm #endif diff --git a/csrc/cpu/micro_gemm/cpu_micro_gemm_vec.hpp b/csrc/cpu/micro_gemm/cpu_micro_gemm_vec.hpp index 3985c2f2e5fe4..bdd3e85a1c522 100644 --- a/csrc/cpu/micro_gemm/cpu_micro_gemm_vec.hpp +++ b/csrc/cpu/micro_gemm/cpu_micro_gemm_vec.hpp @@ -109,6 +109,25 @@ class MicroGemm { void gemm(DEFINE_CPU_MICRO_GEMM_PARAMS) { TileGemm82::gemm(CPU_MICRO_GEMM_PARAMS); } + + // Note: pack contiguous weight [output_size, input_size] as contiguous + // packed weight [output_size / 16, input_size, 16] + static void pack_weight(const scalar_t* __restrict__ weight, + scalar_t* __restrict__ packed_weight, + const int32_t output_size, const int32_t input_size) { + TORCH_CHECK_EQ(output_size % 16, 0); + for (int32_t o_idx = 0; o_idx < output_size; ++o_idx) { + const scalar_t* __restrict__ curr_weight = weight + o_idx * input_size; + scalar_t* __restrict__ curr_packed_weight = + packed_weight + (o_idx / 16) * (16 * input_size) + o_idx % 16; + for (int32_t i_idx = 0; i_idx < input_size; ++i_idx) { + *curr_packed_weight = *curr_weight; + + curr_packed_weight += 16; + ++curr_weight; + } + } + } }; } // namespace cpu_micro_gemm diff --git a/csrc/cpu/scratchpad_manager.cpp b/csrc/cpu/scratchpad_manager.cpp deleted file mode 100644 index 05cd435f34b7a..0000000000000 --- a/csrc/cpu/scratchpad_manager.cpp +++ /dev/null @@ -1,23 +0,0 @@ -#include - -#include "scratchpad_manager.h" - -DNNLScratchPadManager::DNNLScratchPadManager() : size_(0), ptr_(nullptr) { - this->realloc(allocation_unit * 128); -} - -void DNNLScratchPadManager::realloc(size_t new_size) { - new_size = round(new_size); - if (new_size > size_) { - if (ptr_ != nullptr) { - std::free(ptr_); - } - ptr_ = std::aligned_alloc(64, new_size); - size_ = new_size; - } -} - -DNNLScratchPadManager* DNNLScratchPadManager::get_dnnl_scratchpad_manager() { - static DNNLScratchPadManager manager; - return &manager; -} diff --git a/csrc/cpu/scratchpad_manager.h b/csrc/cpu/scratchpad_manager.h deleted file mode 100644 index 0ecf59192f845..0000000000000 --- a/csrc/cpu/scratchpad_manager.h +++ /dev/null @@ -1,31 +0,0 @@ -#ifndef SCRATCHPAD_MANAGER_H -#define SCRATCHPAD_MANAGER_H - -#include 
-#include - -class DNNLScratchPadManager { - public: - static constexpr size_t allocation_unit = 4 * 1024; // 4KB - - static DNNLScratchPadManager* get_dnnl_scratchpad_manager(); - - DNNLScratchPadManager(); - - template - T* get_data() { - return reinterpret_cast(ptr_); - } - - static size_t round(size_t size) { - return ((size + allocation_unit - 1) / allocation_unit) * allocation_unit; - } - - void realloc(size_t new_size); - - private: - size_t size_; - void* ptr_; -}; - -#endif diff --git a/csrc/cpu/torch_bindings.cpp b/csrc/cpu/torch_bindings.cpp index e0e3ef71b485f..dd419405c94b9 100644 --- a/csrc/cpu/torch_bindings.cpp +++ b/csrc/cpu/torch_bindings.cpp @@ -110,6 +110,17 @@ void cpu_gemm_wna16(const torch::Tensor& input, const torch::Tensor& q_weight, const std::optional& bias, const int64_t pack_factor, const std::string& isa_hint); +void prepack_moe_weight(const torch::Tensor& weight, + torch::Tensor& packed_weight, const std::string& isa); + +void cpu_fused_moe(torch::Tensor& output, const torch::Tensor& input, + const torch::Tensor& w13, const torch::Tensor& w2, + const std::optional& w13_bias, + const std::optional& w2_bias, + const torch::Tensor& topk_weights, + const torch::Tensor& topk_id, const std::string& act, + const std::string& isa); + TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // vLLM custom ops @@ -296,6 +307,19 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "pack_factor, str isa_hint) -> ()"); ops.impl("cpu_gemm_wna16", torch::kCPU, &cpu_gemm_wna16); #endif + + // fused moe +#if defined(__AVX512F__) + ops.def( + "prepack_moe_weight(Tensor weight, Tensor(a1!) packed_weight, str isa) " + "-> ()"); + ops.impl("prepack_moe_weight", torch::kCPU, &prepack_moe_weight); + ops.def( + "cpu_fused_moe(Tensor(a0!) output, Tensor input, Tensor w13, Tensor w2, " + "Tensor? w13_bias, Tensor? 
w2_bias, Tensor topk_weights, Tensor topk_id, " + "str act, str isa) -> ()"); + ops.impl("cpu_fused_moe", torch::kCPU, &cpu_fused_moe); +#endif } TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _utils), utils) { diff --git a/csrc/cpu/utils.cpp b/csrc/cpu/utils.cpp index 3dacfc7b2b7a3..fcd7534ab4c5d 100644 --- a/csrc/cpu/utils.cpp +++ b/csrc/cpu/utils.cpp @@ -10,7 +10,7 @@ #define gettid() syscall(SYS_gettid) #endif -#include "cpu_types.hpp" +#include "cpu/utils.hpp" #ifdef VLLM_NUMA_DISABLED std::string init_cpu_threads_env(const std::string& cpu_ids) { @@ -138,4 +138,26 @@ std::string init_cpu_threads_env(const std::string& cpu_ids) { return ss.str(); } + +namespace cpu_utils { +ScratchPadManager::ScratchPadManager() : size_(0), ptr_(nullptr) { + this->realloc(allocation_unit * 128); +} + +void ScratchPadManager::realloc(size_t new_size) { + new_size = round(new_size); + if (new_size > size_) { + if (ptr_ != nullptr) { + std::free(ptr_); + } + ptr_ = std::aligned_alloc(64, new_size); + size_ = new_size; + } +} + +ScratchPadManager* ScratchPadManager::get_scratchpad_manager() { + static ScratchPadManager manager; + return &manager; +} +} // namespace cpu_utils #endif diff --git a/csrc/cpu/utils.hpp b/csrc/cpu/utils.hpp index d3def306b8069..8ab0bb039c014 100644 --- a/csrc/cpu/utils.hpp +++ b/csrc/cpu/utils.hpp @@ -2,19 +2,24 @@ #define UTILS_HPP #include -#include -#include #include +#include -#if defined(__APPLE__) - #include -#endif - -#include "cpu_types.hpp" +#include "cpu/cpu_types.hpp" namespace cpu_utils { enum class ISA { AMX, VEC }; +inline ISA get_isa(const std::string& isa) { + if (isa == "amx") { + return ISA::AMX; + } else if (isa == "vec") { + return ISA::VEC; + } else { + TORCH_CHECK(false, "Invalid isa type: " + isa); + } +} + template struct VecTypeTrait { using vec_t = void; @@ -48,26 +53,66 @@ struct Counter { int64_t acquire_counter() { return counter++; } }; -inline int64_t get_l2_size() { +inline int64_t get_available_l2_size() { static int64_t size = []() { -#if defined(__APPLE__) - // macOS doesn't have _SC_LEVEL2_CACHE_SIZE. Use sysctlbyname. 
- int64_t l2_cache_size = 0; - size_t len = sizeof(l2_cache_size); - if (sysctlbyname("hw.l2cachesize", &l2_cache_size, &len, NULL, 0) == 0 && - l2_cache_size > 0) { - return l2_cache_size >> 1; // use 50% of L2 cache - } - // Fallback if sysctlbyname fails - return 128LL * 1024 >> 1; // use 50% of 128KB -#else - long l2_cache_size = sysconf(_SC_LEVEL2_CACHE_SIZE); - assert(l2_cache_size != -1); + const uint32_t l2_cache_size = at::cpu::L2_cache_size(); return l2_cache_size >> 1; // use 50% of L2 cache -#endif }(); return size; } + +template +inline T round_up(T size) { + T alignment = alignment_v; + return (((size + alignment - 1) / alignment) * alignment); +} + +template +inline T round_down(T size) { + T alignment = alignment_v; + return (size / alignment) * alignment; +} + +template +inline void print_logits(const char* name, T* ptr, int32_t row, int32_t col, + int32_t stride) { + std::stringstream ss; + ss << std::fixed << std::setprecision(5) << name << ": [\n"; + auto* curr_logits_buffer = ptr; + for (int32_t m = 0; m < row; ++m) { + for (int32_t n = 0; n < col; ++n) { + ss << curr_logits_buffer[n] << ", "; + } + ss << "\n"; + curr_logits_buffer += stride; + } + ss << "]\n"; + std::printf("%s", ss.str().c_str()); +} + +class ScratchPadManager { + public: + static constexpr size_t allocation_unit = 4 * 1024; // 4KB + + static ScratchPadManager* get_scratchpad_manager(); + + ScratchPadManager(); + + template + T* get_data() { + return reinterpret_cast(ptr_); + } + + static size_t round(size_t size) { + return ((size + allocation_unit - 1) / allocation_unit) * allocation_unit; + } + + void realloc(size_t new_size); + + private: + size_t size_; + void* ptr_; +}; } // namespace cpu_utils #endif diff --git a/docker/Dockerfile.cpu b/docker/Dockerfile.cpu index bd5bc43916eac..2caf1ad144178 100644 --- a/docker/Dockerfile.cpu +++ b/docker/Dockerfile.cpu @@ -147,7 +147,9 @@ WORKDIR /workspace/vllm RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \ --mount=type=cache,target=/var/lib/apt,sharing=locked \ - apt-get install -y --no-install-recommends vim numactl xz-utils + apt-get install -y --no-install-recommends vim numactl xz-utils make clangd-14 + +RUN ln -s /usr/bin/clangd-14 /usr/bin/clangd # install development dependencies (for testing) RUN --mount=type=cache,target=/root/.cache/uv \ diff --git a/requirements/cpu-build.txt b/requirements/cpu-build.txt index 1ea401a04a12c..a7bd3b17b6323 100644 --- a/requirements/cpu-build.txt +++ b/requirements/cpu-build.txt @@ -1,7 +1,7 @@ cmake>=3.26.1 ninja packaging>=24.2 -setuptools>=77.0.3,<81.0.0 +setuptools==77.0.3 # this version can reuse CMake build dir setuptools-scm>=8 torch==2.9.1+cpu; platform_machine == "x86_64" or platform_machine == "s390x" torch==2.9.1; platform_system == "Darwin" or platform_machine == "ppc64le" or platform_machine == "aarch64" diff --git a/requirements/cpu.txt b/requirements/cpu.txt index 7a670812e8943..111b8a5511562 100644 --- a/requirements/cpu.txt +++ b/requirements/cpu.txt @@ -1,6 +1,8 @@ # Common dependencies -r common.txt +setuptools==77.0.3 # this version can reuse CMake build dir + numba == 0.61.2; platform_machine != "s390x" # Required for N-gram speculative decoding # Dependencies for CPUs diff --git a/tests/kernels/moe/test_cpu_fused_moe.py b/tests/kernels/moe/test_cpu_fused_moe.py new file mode 100644 index 0000000000000..4dda45a6c7409 --- /dev/null +++ b/tests/kernels/moe/test_cpu_fused_moe.py @@ -0,0 +1,172 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright 
contributors to the vLLM project + +import pytest +import torch + +from tests.kernels.allclose_default import get_default_atol, get_default_rtol +from vllm._custom_ops import cpu_fused_moe, cpu_prepack_moe_weight +from vllm.model_executor.layers.activation import SiluAndMul, SwigluOAIAndMul +from vllm.platforms import current_platform + +if not current_platform.is_cpu(): + pytest.skip("skipping CPU-only tests", allow_module_level=True) + +EXPERT_NUM = [ + 8, +] +HIDDEN_DIM = [128, 2880] +INTERMEDIATE_DIM = [128, 2880] +BATCH_SIZE = [1, 64, 256] +ACT = ["silu", "swigluoai"] +USE_BIAS = [True, False] +ISA = ["amx", "vec"] if torch._C._cpu._is_amx_tile_supported() else ["vec"] +DTYPE = [torch.bfloat16] + +_CPU_MOE_ACT = { + "silu": SiluAndMul(), + "swigluoai": SwigluOAIAndMul(), +} + + +def ref_fused_moe( + input: torch.Tensor, + w13: torch.Tensor, + w2: torch.Tensor, + w13_bias: torch.Tensor | None, + w2_bias: torch.Tensor | None, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + activation: str, +) -> torch.Tensor: + len_experts = w13.size(0) + + cnts = topk_ids.new_zeros((topk_ids.shape[0], len_experts)) + cnts.scatter_(1, topk_ids.to(torch.int64), 1) + tokens_per_expert = cnts.sum(dim=0) + idxs = topk_ids.view(-1).argsort() + + sorted_tokens = input[idxs // topk_ids.shape[1]] + tokens_per_expert = tokens_per_expert.cpu().numpy() + + outputs = [] + start_idx = 0 + + for i, num_tokens in enumerate(tokens_per_expert): + end_idx = start_idx + num_tokens + if num_tokens == 0: + continue + tokens_for_this_expert = sorted_tokens[start_idx:end_idx].float() + curr_w13 = w13[i].float() + curr_w2 = w2[i].float() + + curr_w13_bias = None + if w13_bias is not None: + curr_w13_bias = w13_bias[i].float() + + curr_w2_bias = None + if w2_bias is not None: + curr_w2_bias = w2_bias[i].float() + + gate_up = torch.nn.functional.linear( + tokens_for_this_expert, curr_w13, curr_w13_bias + ) + # Note: to simulate the kernel implementation + gate_up = ( + _CPU_MOE_ACT[activation] + .forward_native(gate_up) + .to(dtype=input.dtype) + .float() + ) + expert_out = torch.nn.functional.linear(gate_up, curr_w2, curr_w2_bias) + + outputs.append(expert_out) + start_idx = end_idx + + outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0) + new_x = torch.empty_like(outs) + + new_x[idxs] = outs + final_out = ( + new_x.view(*topk_ids.shape, -1) + .mul_(topk_weights.unsqueeze(dim=-1)) + .sum(dim=1) + .type(input.dtype) + ) + return final_out + + +@pytest.mark.parametrize("batch_size", BATCH_SIZE) +@pytest.mark.parametrize("expert_num", EXPERT_NUM) +@pytest.mark.parametrize("hidden_size", HIDDEN_DIM) +@pytest.mark.parametrize("intermediate_size", INTERMEDIATE_DIM) +@pytest.mark.parametrize("use_bias", USE_BIAS) +@pytest.mark.parametrize("dtype", DTYPE) +@pytest.mark.parametrize("act", ACT) +@pytest.mark.parametrize("isa", ISA) +def test_cpu_fused_moe( + batch_size: int, + expert_num: int, + hidden_size: int, + intermediate_size: int, + use_bias: bool, + dtype: torch.dtype, + act: str, + isa: str, +): + current_platform.seed_everything(0) + + topk_num = max(expert_num // 2, 1) + up_dim = 2 * intermediate_size + + input = torch.randn((batch_size, hidden_size), dtype=dtype) / ( + 0.5 * hidden_size**0.5 + ) + w13 = torch.randn((expert_num, up_dim, hidden_size), dtype=dtype) / ( + 0.5 * hidden_size**0.5 + ) + w2 = torch.randn((expert_num, hidden_size, intermediate_size), dtype=dtype) / ( + 0.5 * intermediate_size**0.5 + ) + router_logits = torch.randn((batch_size, expert_num), dtype=dtype) + 
+    w13_bias = None
+    w2_bias = None
+    if use_bias:
+        w13_bias = torch.randn((expert_num, up_dim), dtype=dtype) / (0.5 * up_dim**0.5)
+        w2_bias = torch.randn((expert_num, hidden_size), dtype=dtype) / (
+            0.5 * hidden_size**0.5
+        )
+    score = torch.softmax(router_logits, dim=-1, dtype=torch.float32)
+    topk_weight, topk_ids = torch.topk(score, topk_num)
+    topk_ids = topk_ids.to(torch.int32)
+
+    ref_output = ref_fused_moe(
+        input,
+        w13,
+        w2,
+        w13_bias,
+        w2_bias,
+        topk_weight,
+        topk_ids,
+        act,
+    )
+
+    packed_w13 = cpu_prepack_moe_weight(w13, isa)
+    packed_w2 = cpu_prepack_moe_weight(w2, isa)
+    output = cpu_fused_moe(
+        input,
+        packed_w13,
+        packed_w2,
+        w13_bias,
+        w2_bias,
+        topk_weight,
+        topk_ids,
+        act,
+        isa,
+    )
+
+    atol, rtol = get_default_atol(output), get_default_rtol(output)
+    (
+        torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol),
+        f"{torch.max(torch.abs(output - ref_output))}",
+    )
diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index 2319655008c50..cf7f17a033be3 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -2919,6 +2919,42 @@ def cpu_gemm_wna16(
     return output
 
 
+def cpu_prepack_moe_weight(
+    weight: torch.Tensor,
+    isa: str,
+) -> torch.Tensor:
+    output = torch.empty_like(weight)
+    torch.ops._C.prepack_moe_weight(weight, output, isa)
+    return output
+
+
+def cpu_fused_moe(
+    input: torch.Tensor,
+    w13: torch.Tensor,
+    w2: torch.Tensor,
+    w13_bias: torch.Tensor | None,
+    w2_bias: torch.Tensor | None,
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    act: str,
+    isa: str,
+) -> torch.Tensor:
+    output = torch.empty_like(input)
+    torch.ops._C.cpu_fused_moe(
+        output,
+        input,
+        w13,
+        w2,
+        w13_bias,
+        w2_bias,
+        topk_weights,
+        topk_ids,
+        act,
+        isa,
+    )
+    return output
+
+
 if hasattr(torch.ops._qutlass_C, "matmul_mxf4_bf16_tn"):
 
     @register_fake("_qutlass_C::matmul_mxf4_bf16_tn")
diff --git a/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py b/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py
index 659a2d4ee5b39..cf7a4313de24c 100644
--- a/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py
@@ -1,12 +1,22 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import weakref
 from collections.abc import Callable
 
 import torch
 from torch.nn import functional as F
 
 from vllm import _custom_ops as ops
+from vllm._custom_ops import cpu_fused_moe, cpu_prepack_moe_weight
 from vllm.model_executor.layers.activation import SiluAndMul, SwigluOAIAndMul
+from vllm.model_executor.layers.quantization.utils.layer_utils import replace_parameter
+from vllm.utils.torch_utils import direct_register_custom_op
+
+_CPU_MOE_LAYER_CACHE = {}
+_CPU_MOE_ACT = {
+    "silu": SiluAndMul(),
+    "swigluoai": SwigluOAIAndMul(),
+}
 
 
 def grouped_topk(
@@ -174,8 +184,105 @@ class SGLFusedMOE:
 
 
 class CPUFusedMOE:
     def __init__(self, layer: torch.nn.Module) -> None:
-        use_onednn_mm = ops._supports_onednn and ops.is_onednn_acl_supported()
+        use_grouped_gemm, isa = self.check_grouped_gemm(layer)
+        self.isa = isa
+        if use_grouped_gemm:
+            self.forward_method = self.forward_grouped_gemm
+            self.init_moe_grouped_gemm(layer=layer)
+        else:
+            self.forward_method = self.forward_torch
+            self.init_moe_torch(layer=layer)
+
+    def __call__(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        use_grouped_topk: bool,
+        top_k: int,
+        router_logits: torch.Tensor,
+        renormalize: bool,
+        topk_group: int | None = None,
+        num_expert_group: int | None = None,
+        global_num_experts: int = -1,
+        expert_map: torch.Tensor | None = None,
+        custom_routing_function: Callable | None = None,
+        scoring_func: str = "softmax",
+        routed_scaling_factor: float = 1.0,
+        e_score_correction_bias: torch.Tensor | None = None,
+        apply_router_weight_on_input: bool = False,
+        activation: str = "silu",
+    ) -> torch.Tensor:
+        assert activation in _CPU_MOE_ACT, f"{activation} is not supported."
+        assert not apply_router_weight_on_input
+
+        topk_weights, topk_ids = select_experts(
+            hidden_states=x,
+            router_logits=router_logits,
+            use_grouped_topk=use_grouped_topk,
+            top_k=top_k,
+            renormalize=renormalize,
+            topk_group=topk_group,
+            num_expert_group=num_expert_group,
+            custom_routing_function=custom_routing_function,
+            scoring_func=scoring_func,
+            routed_scaling_factor=routed_scaling_factor,
+            e_score_correction_bias=e_score_correction_bias,
+        )
+
+        return self.forward_method(
+            layer,
+            x,
+            topk_weights,
+            topk_ids,
+            activation,
+            global_num_experts,
+        )
+
+    def check_grouped_gemm(
+        self,
+        layer: torch.nn.Module,
+    ) -> tuple[bool, str]:
+        if not hasattr(torch.ops._C, "prepack_moe_weight"):
+            return False, "none"
+
+        dtype = layer.w13_weight.dtype
+        w13_input_size = layer.w13_weight.size(2)
+        w13_output_size = layer.w13_weight.size(1)
+        w2_input_size = layer.w2_weight.size(2)
+        w2_output_size = layer.w2_weight.size(1)
+
+        if not (w13_output_size % 32 == 0 and w2_output_size % 32 == 0):
+            return False, "none"
+
+        supports_amx = torch._C._cpu._is_amx_tile_supported()
+
+        if (
+            supports_amx
+            and dtype == torch.bfloat16
+            and w13_input_size % 32 == 0
+            and w2_input_size % 32 == 0
+        ):
+            return True, "amx"
+
+        if supports_amx:
+            return False, "none"
+
+        return True, "vec"
+
+    def init_moe_grouped_gemm(
+        self,
+        layer: torch.nn.Module,
+    ) -> None:
+        new_w13 = cpu_prepack_moe_weight(layer.w13_weight, self.isa)
+        replace_parameter(layer, "w13_weight", new_w13)
+        new_w2 = cpu_prepack_moe_weight(layer.w2_weight, self.isa)
+        replace_parameter(layer, "w2_weight", new_w2)
+
+    def init_moe_torch(
+        self,
+        layer: torch.nn.Module,
+    ) -> None:
+        use_onednn_mm = ops._supports_onednn and ops.is_onednn_acl_supported()
         num_experts = layer.w13_weight.size(0)
         has_w13_bias = hasattr(layer, "w13_bias")
         has_w2_bias = hasattr(layer, "w2_bias")
@@ -208,85 +315,112 @@ class CPUFusedMOE:
             layer.down_linear.append(
                 lambda x, w=layer_w2_weight, b=layer_w2_bias: F.linear(x, w, b)
             )
+
         if use_onednn_mm:  # remove weight
             layer.w13_weight = torch.nn.Parameter(torch.empty(0), requires_grad=False)
             layer.w2_weight = torch.nn.Parameter(torch.empty(0), requires_grad=False)
 
-        self.act_to_impl = {
-            "silu": SiluAndMul(),
-            "swigluoai": SwigluOAIAndMul(),
-        }
+        _CPU_MOE_LAYER_CACHE[id(layer)] = weakref.ref(layer)
 
-    def __call__(
+    def forward_grouped_gemm(
         self,
         layer: torch.nn.Module,
-        x: torch.Tensor,
-        use_grouped_topk: bool,
-        top_k: int,
-        router_logits: torch.Tensor,
-        renormalize: bool,
-        topk_group: int | None = None,
-        num_expert_group: int | None = None,
+        input: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        activation: str,
         global_num_experts: int = -1,
-        expert_map: torch.Tensor | None = None,
-        custom_routing_function: Callable | None = None,
-        scoring_func: str = "softmax",
-        routed_scaling_factor: float = 1.0,
-        e_score_correction_bias: torch.Tensor | None = None,
-        apply_router_weight_on_input: bool = False,
-        activation: str = "silu",
     ) -> torch.Tensor:
-        assert activation in self.act_to_impl, f"{activation} is not supported."
-        assert not apply_router_weight_on_input
-        topk_weights, topk_ids = select_experts(
-            hidden_states=x,
-            router_logits=router_logits,
-            use_grouped_topk=use_grouped_topk,
-            top_k=top_k,
-            renormalize=renormalize,
-            topk_group=topk_group,
-            num_expert_group=num_expert_group,
-            custom_routing_function=custom_routing_function,
-            scoring_func=scoring_func,
-            routed_scaling_factor=routed_scaling_factor,
-            e_score_correction_bias=e_score_correction_bias,
+        output = cpu_fused_moe(
+            input,
+            layer.w13_weight,
+            layer.w2_weight,
+            getattr(layer, "w13_bias", None),
+            getattr(layer, "w2_bias", None),
+            topk_weights,
+            topk_ids,
+            activation,
+            self.isa,
+        )
+        return output
+
+    def forward_torch(
+        self,
+        layer: torch.nn.Module,
+        input: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        activation: str,
+        global_num_experts: int = -1,
+    ) -> torch.Tensor:
+        output = torch.empty_like(input)
+        layer_id = id(layer)
+        torch.ops.vllm.cpu_fused_moe_torch(
+            layer_id,
+            output,
+            input,
+            topk_weights,
+            topk_ids,
+            activation,
+            global_num_experts,
         )
-        # Ref code from https://github.com/sgl-project/sglang/blob/716e682721397df103f347d22da8bd46c6016dab/python/sglang/srt/layers/moe/fused_moe_native.py#L53
-        len_experts = global_num_experts
+        return output
 
-        cnts = topk_ids.new_zeros((topk_ids.shape[0], len_experts))
-        cnts.scatter_(1, topk_ids.to(torch.int64), 1)
-        tokens_per_expert = cnts.sum(dim=0)
-        idxs = topk_ids.view(-1).argsort()
-        sorted_tokens = x[idxs // topk_ids.shape[1]]
-        tokens_per_expert = tokens_per_expert.cpu().numpy()
+
+def cpu_fused_moe_torch(
+    layer_id: int,
+    output: torch.Tensor,
+    input: torch.Tensor,
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    activation: str,
+    global_num_experts: int = -1,
+) -> None:
+    layer = _CPU_MOE_LAYER_CACHE[layer_id]()
 
-        outputs = []
-        start_idx = 0
+    # Ref code from https://github.com/sgl-project/sglang/blob/716e682721397df103f347d22da8bd46c6016dab/python/sglang/srt/layers/moe/fused_moe_native.py#L53
+    len_experts = global_num_experts
 
-        for i, num_tokens in enumerate(tokens_per_expert):
-            end_idx = start_idx + num_tokens
-            if num_tokens == 0:
-                continue
-            tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
+    cnts = topk_ids.new_zeros((topk_ids.shape[0], len_experts))
+    cnts.scatter_(1, topk_ids.to(torch.int64), 1)
+    tokens_per_expert = cnts.sum(dim=0)
+    idxs = topk_ids.view(-1).argsort()
 
-            gate_up = layer.gate_up_linear[i](tokens_for_this_expert)
-            gate_up = self.act_to_impl[activation].forward_native(gate_up)
-            expert_out = layer.down_linear[i](gate_up)
-            outputs.append(expert_out)
-            start_idx = end_idx
+    sorted_tokens = input[idxs // topk_ids.shape[1]]
+    tokens_per_expert = tokens_per_expert.cpu().numpy()
 
-        outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)
-        new_x = torch.empty_like(outs)
+    outputs = []
+    start_idx = 0
 
-        new_x[idxs] = outs
-        final_out = (
-            new_x.view(*topk_ids.shape, -1)
-            .type(topk_weights.dtype)
-            .mul_(topk_weights.unsqueeze(dim=-1))
-            .sum(dim=1)
-            .type(new_x.dtype)
-        )
-        return final_out
+    for i, num_tokens in enumerate(tokens_per_expert):
+        end_idx = start_idx + num_tokens
+        if num_tokens == 0:
+            continue
+        tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
+
+        gate_up = layer.gate_up_linear[i](tokens_for_this_expert)  # type: ignore
+        gate_up = _CPU_MOE_ACT[activation].forward_native(gate_up)
+        expert_out = layer.down_linear[i](gate_up)  # type: ignore
+        outputs.append(expert_out)
+        start_idx = end_idx
+
+    outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)
+    new_x = torch.empty_like(outs)
+
+    new_x[idxs] = outs
+    final_out = (
+        new_x.view(*topk_ids.shape, -1)
+        .type(topk_weights.dtype)
+        .mul_(topk_weights.unsqueeze(dim=-1))
+        .sum(dim=1)
+        .type(new_x.dtype)
+    )
+    output.copy_(final_out)
+
+
+direct_register_custom_op(
+    op_name="cpu_fused_moe_torch",
+    op_func=cpu_fused_moe_torch,
+    mutates_args=["output"],
+)
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index db97d6eb88ea5..6a65b06014bca 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -1726,9 +1726,10 @@ class FusedMoE(CustomOp):
             return states
 
         if self.shared_experts is None:
-            if current_platform.is_tpu():
+            if current_platform.is_tpu() or current_platform.is_cpu():
                 # TODO: Once the OOM issue for the TPU backend is resolved, we
                 # will switch to using the moe_forward custom op.
+                # Note: CPU doesn't require wrapped forward_impl.
                 fused_output = self.forward_impl(hidden_states, router_logits)
                 assert not isinstance(fused_output, tuple)
             else:
@@ -1744,9 +1745,10 @@ class FusedMoE(CustomOp):
             else:
                 return reduce_output(fused_output)[..., :og_hidden_states]
         else:
-            if current_platform.is_tpu():
+            if current_platform.is_tpu() or current_platform.is_cpu():
                 # TODO: Once the OOM issue for the TPU backend is resolved, we
                 # will switch to using the moe_forward custom op.
+                # Note: CPU doesn't require wrapped forward_impl.
                 shared_output, fused_output = self.forward_impl(
                     hidden_states, router_logits
                 )
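Illustrative usage sketch (not part of the diff): the new Python entry points added in vllm/_custom_ops.py can be driven directly, mirroring tests/kernels/moe/test_cpu_fused_moe.py above. All tensor sizes below are hypothetical, and the ISA choice is simplified; only cpu_prepack_moe_weight, cpu_fused_moe, and torch._C._cpu._is_amx_tile_supported() come from this patch, and the sketch assumes a CPU build that registers the new _C ops.

# Hypothetical sizes; weights use the (expert, out_features, in_features) layout
# implied by check_grouped_gemm (output dim = size(1), input dim = size(2)).
import torch

from vllm._custom_ops import cpu_fused_moe, cpu_prepack_moe_weight

num_tokens, hidden_size, up_dim, expert_num, topk_num = 16, 128, 256, 8, 2
dtype = torch.bfloat16

x = torch.randn((num_tokens, hidden_size), dtype=dtype)
w13 = torch.randn((expert_num, up_dim, hidden_size), dtype=dtype) / hidden_size**0.5
w2 = torch.randn((expert_num, hidden_size, up_dim // 2), dtype=dtype) / up_dim**0.5
router_logits = torch.randn((num_tokens, expert_num), dtype=torch.float32)

# Routing, as in the test: softmax scores, then top-k expert ids as int32.
score = torch.softmax(router_logits, dim=-1, dtype=torch.float32)
topk_weight, topk_ids = torch.topk(score, topk_num)
topk_ids = topk_ids.to(torch.int32)

# Simplified ISA pick; CPUFusedMOE.check_grouped_gemm also verifies dtype and
# 32-element alignment before choosing "amx", "vec", or falling back to torch.
isa = "amx" if torch._C._cpu._is_amx_tile_supported() else "vec"

packed_w13 = cpu_prepack_moe_weight(w13, isa)  # one-time weight repacking
packed_w2 = cpu_prepack_moe_weight(w2, isa)
out = cpu_fused_moe(
    x, packed_w13, packed_w2, None, None, topk_weight, topk_ids, "silu", isa
)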