#ifndef CPU_ATTN_VEC_HPP
#define CPU_ATTN_VEC_HPP

#include "cpu_attn_impl.hpp"

namespace cpu_attention {
namespace {

// 8-2-16 pattern, 8 regs for A, 2 regs for B, 16 regs for C, [8, K] @ [K, 32]
template <typename kv_cache_t>
class TileGemm82 {
 public:
  FORCE_INLINE static void gemm(const int32_t m_size,
                                float* __restrict__ a_tile,
                                kv_cache_t* __restrict__ b_tile,
                                float* __restrict__ c_tile, const int64_t lda,
                                const int64_t ldb, const int64_t ldc,
                                const int32_t block_size,
                                const int32_t dynamic_k_size,
                                const bool accum_c) {
    // Dispatch the runtime row count to a micro-kernel with a compile-time
    // M so the row loop can be fully unrolled. Odd row counts fall through
    // to the next even micro-kernel.
    switch (m_size) {
      case 1:
        gemm_micro<1>(a_tile, b_tile, c_tile, lda, ldb, ldc, block_size,
                      dynamic_k_size, accum_c);
        break;
      case 2:
        gemm_micro<2>(a_tile, b_tile, c_tile, lda, ldb, ldc, block_size,
                      dynamic_k_size, accum_c);
        break;
      case 3:
      case 4:
        gemm_micro<4>(a_tile, b_tile, c_tile, lda, ldb, ldc, block_size,
                      dynamic_k_size, accum_c);
        break;
      case 5:
      case 6:
        gemm_micro<6>(a_tile, b_tile, c_tile, lda, ldb, ldc, block_size,
                      dynamic_k_size, accum_c);
        break;
      case 7:
      case 8:
        gemm_micro<8>(a_tile, b_tile, c_tile, lda, ldb, ldc, block_size,
                      dynamic_k_size, accum_c);
        break;
    }
  }

  template <int32_t M>
  static void gemm_micro(float* __restrict__ a_tile,
                         kv_cache_t* __restrict__ b_tile,
                         float* __restrict__ c_tile, const int64_t lda,
                         const int64_t ldb, const int64_t ldc,
                         const int32_t block_size,
                         const int32_t dynamic_k_size, const bool accum_c) {
    static_assert(0 < M && M <= 8);
    using load_vec_t = typename VecTypeTrait<kv_cache_t>::vec_t;
    // Two 16-wide B columns cover the 32 output columns of the tile.
    kv_cache_t* __restrict__ curr_b_0 = b_tile;
    kv_cache_t* __restrict__ curr_b_1 = b_tile + 16;
    float* __restrict__ curr_c_0 = c_tile;
    float* __restrict__ curr_c_1 = c_tile + 16;
    // Two FP32 accumulators per row of A (columns 0-15 and 16-31);
    // default construction zero-initializes them.
    vec_op::FP32Vec16 c_regs[M * 2];

    if (accum_c) {
      // Accumulating into the existing C tile: preload it into registers.
      float* __restrict__ curr_m_c_0 = curr_c_0;
      float* __restrict__ curr_m_c_1 = curr_c_1;
      vec_op::unroll_loop<int32_t, M>([&](int32_t i) {
        c_regs[i * 2] = vec_op::FP32Vec16(curr_m_c_0);
        c_regs[i * 2 + 1] = vec_op::FP32Vec16(curr_m_c_1);
        // update
        curr_m_c_0 += ldc;
        curr_m_c_1 += ldc;
      });
    }

    float* __restrict__ curr_a = a_tile;
    for (int32_t k = 0; k < dynamic_k_size; ++k) {
      // Load one K row of B (32 elements) and widen it to FP32.
      load_vec_t b_0_reg(curr_b_0);
      vec_op::FP32Vec16 fp32_b_0_reg(b_0_reg);
      load_vec_t b_1_reg(curr_b_1);
      vec_op::FP32Vec16 fp32_b_1_reg(b_1_reg);
      float* __restrict__ curr_m_a = curr_a;
      // Broadcast A[i][k] and accumulate into both C registers of row i.
      vec_op::unroll_loop<int32_t, M>([&](int32_t i) {
        float v = *curr_m_a;
        vec_op::FP32Vec16 a_reg(v);
        c_regs[i * 2] = c_regs[i * 2] + a_reg * fp32_b_0_reg;
        c_regs[i * 2 + 1] = c_regs[i * 2 + 1] + a_reg * fp32_b_1_reg;
        // update
        curr_m_a += lda;
      });
      // update
      curr_a += 1;
      curr_b_0 += ldb;
      curr_b_1 += ldb;
    }

    // Write the accumulators back to the C tile.
    vec_op::unroll_loop<int32_t, M>([&](int32_t i) {
      c_regs[i * 2].save(curr_c_0);
      c_regs[i * 2 + 1].save(curr_c_1);
      // update
      curr_c_0 += ldc;
      curr_c_1 += ldc;
    });
  }
};

}  // namespace

// This is a general but naive implementation based on vector instructions
template <typename scalar_t, int64_t head_dim>
class AttentionImpl {
 public:
  using query_t = scalar_t;
  using q_buffer_t = float;
  using kv_cache_t = scalar_t;
  using logits_buffer_t = float;
  using partial_output_buffer_t = float;
  using prob_buffer_t = float;

  // KV token num unit of the QK and PV phases
  constexpr static int64_t BlockSizeAlignment = 32;
  // headdim num unit of the PV phase
  constexpr static int64_t HeadDimAlignment = 32;
  constexpr static int64_t MaxQHeadNumPerIteration = 8;
  constexpr static int64_t HeadDim = head_dim;
  constexpr static ISA ISAType = ISA::VEC;
  // false: apply the softmax scale on q_buffer rather than on the logits
  constexpr static bool scale_on_logits = false;

 public:
  template