#ifndef CPU_ATTN_VEC16_HPP
#define CPU_ATTN_VEC16_HPP

#include "cpu_attn_vec.hpp"

namespace cpu_attention {
namespace {
// 16-1-16 pattern, 16 regs for A, 1 reg for B, 16 regs for C, [16, K] @ [K, 16]
template <typename kv_cache_t>
class TileGemm161 {
 public:
  template <int32_t k_size>
  FORCE_INLINE static void gemm(const int32_t m_size,
                                float* __restrict__ a_tile,
                                kv_cache_t* __restrict__ b_tile,
                                float* __restrict__ c_tile, const int64_t lda,
                                const int64_t ldb, const int64_t ldc,
                                const int32_t block_size,
                                const int32_t dynamic_k_size,
                                const bool accum_c) {
    // Dispatch to a micro-kernel specialized for the (rounded-up) row count.
    switch (m_size) {
      case 1:
        gemm_micro<1>(a_tile, b_tile, c_tile, lda, ldb, ldc, block_size,
                      dynamic_k_size, accum_c);
        break;
      case 2:
        gemm_micro<2>(a_tile, b_tile, c_tile, lda, ldb, ldc, block_size,
                      dynamic_k_size, accum_c);
        break;
      case 3:
      case 4:
        gemm_micro<4>(a_tile, b_tile, c_tile, lda, ldb, ldc, block_size,
                      dynamic_k_size, accum_c);
        break;
      case 5:
      case 6:
        gemm_micro<6>(a_tile, b_tile, c_tile, lda, ldb, ldc, block_size,
                      dynamic_k_size, accum_c);
        break;
      case 7:
      case 8:
        gemm_micro<8>(a_tile, b_tile, c_tile, lda, ldb, ldc, block_size,
                      dynamic_k_size, accum_c);
        break;
      case 9:
      case 10:
      case 11:
      case 12:
        gemm_micro<12>(a_tile, b_tile, c_tile, lda, ldb, ldc, block_size,
                       dynamic_k_size, accum_c);
        break;
      case 13:
      case 14:
      case 15:
      case 16:
        gemm_micro<16>(a_tile, b_tile, c_tile, lda, ldb, ldc, block_size,
                       dynamic_k_size, accum_c);
        break;
    }
  }

  template <int32_t M>
  static void gemm_micro(float* __restrict__ a_tile,
                         kv_cache_t* __restrict__ b_tile,
                         float* __restrict__ c_tile, const int64_t lda,
                         const int64_t ldb, const int64_t ldc,
                         const int32_t block_size,
                         const int32_t dynamic_k_size, const bool accum_c) {
    static_assert(0 < M && M <= 16);
    using load_vec_t = typename VecTypeTrait<kv_cache_t>::vec_t;

    kv_cache_t* __restrict__ curr_b_0 = b_tile;
    float* __restrict__ curr_c_0 = c_tile;

    // Accumulators: one 16-wide FP32 register per row of the [M, 16] C tile.
    vec_op::FP32Vec16 c_regs[M];
    if (accum_c) {
      float* __restrict__ curr_m_c_0 = curr_c_0;
      vec_op::unroll_loop<int32_t, M>([&](int32_t i) {
        c_regs[i] = vec_op::FP32Vec16(curr_m_c_0);
        // update: next row of C
        curr_m_c_0 += ldc;
      });
    }

    float* __restrict__ curr_a = a_tile;
    for (int32_t k = 0; k < dynamic_k_size; ++k) {
      // Load one [1, 16] row of B and convert it to FP32.
      load_vec_t b_0_reg(curr_b_0);
      vec_op::FP32Vec16 fp32_b_0_reg(b_0_reg);

      float* __restrict__ curr_m_a = curr_a;
      vec_op::unroll_loop<int32_t, M>([&](int32_t i) {
        // Broadcast A[i, k] and accumulate into row i of C.
        float v = *curr_m_a;
        vec_op::FP32Vec16 a_reg(v);
        c_regs[i] = c_regs[i] + a_reg * fp32_b_0_reg;
        // update: next row of A
        curr_m_a += lda;
      });

      // update: next column of A, next row of B
      curr_a += 1;
      curr_b_0 += ldb;
    }

    vec_op::unroll_loop<int32_t, M>([&](int32_t i) {
      c_regs[i].save(curr_c_0);
      // update: next row of C
      curr_c_0 += ldc;
    });
  }
};
}  // namespace

// This is a general but naive implementation based on vector instructions
template <typename scalar_t, int64_t head_dim>
class AttentionImpl<ISA::VEC16, scalar_t, head_dim>
    : public AttentionImpl<ISA::VEC, scalar_t, head_dim> {
 public:
  using query_t = scalar_t;
  using q_buffer_t = float;
  using kv_cache_t = scalar_t;
  using logits_buffer_t = float;
  using partial_output_buffer_t = float;
  using prob_buffer_t = float;

  constexpr static int64_t BlockSizeAlignment =
      16;  // KV token num unit of QK and PV phases
  constexpr static int64_t HeadDimAlignment =
      16;  // headdim num unit of PV phase
  constexpr static int64_t MaxQHeadNumPerIteration = 16;
  constexpr static int64_t HeadDim = head_dim;
  constexpr static ISA ISAType = ISA::VEC16;
  constexpr static bool scale_on_logits = false;  // apply scale on q_buffer

 public:
  template