#ifndef CPU_ATTN_AMX_HPP
#define CPU_ATTN_AMX_HPP

#include "cpu_attn_impl.hpp"

namespace cpu_attention {
namespace {
// AMX specific
constexpr static int64_t AMX_TILE_ROW_BYTES = 64;
constexpr static int64_t AMX_TILE_ROW_NUM = 16;
constexpr static int64_t AMX_TILE_BYTES =
    AMX_TILE_ROW_BYTES * AMX_TILE_ROW_NUM;

typedef struct __tile_config {
  uint8_t palette_id = 1;
  uint8_t start_row = 0;
  uint8_t reserved_0[14] = {0};
  uint16_t colsb[16] = {0};
  uint8_t rows[16] = {0};
} __tilecfg;

// 2-2-4 pattern, for 16 < m <= 32
// TILE 0, 1: load A matrix, row num should be 16, m - 16
// TILE 2, 3: load B matrix, row num should be 16
// TILE 4, 5, 6, 7: store results C matrix, row num should be 16, 16, m - 16,
// m - 16
template <typename kv_cache_t>
class TileGemm224 {
 public:
  template <AttentionGemmPhase phase, int32_t k_size>
  FORCE_INLINE static void gemm(const int32_t m_size,
                                void* __restrict__ a_tile,
                                void* __restrict__ b_tile,
                                float* __restrict__ c_tile, const int64_t lda,
                                const int64_t ldb, const int64_t ldc,
                                const int32_t block_size,
                                const int32_t dynamic_k_size,
                                const bool accum_c) {
    TORCH_CHECK(false, "Unsupported kv cache type for TileGemm224");
  }

  FORCE_INLINE static void init_tile_config(int32_t m, __tilecfg& config) {
    TORCH_CHECK(false, "Unsupported kv cache type for TileGemm224");
  }
};

template <>
class TileGemm224<c10::BFloat16> {
 public:
  template <AttentionGemmPhase phase, int32_t k_size>
  FORCE_INLINE static void gemm(const int32_t m_size,
                                c10::BFloat16* __restrict__ a_tile,
                                c10::BFloat16* __restrict__ b_tile,
                                float* __restrict__ c_tile, const int64_t lda,
                                const int64_t ldb, const int64_t ldc,
                                const int32_t block_size,
                                const int32_t dynamic_k_size,
                                const bool accum_c) {
    const int32_t k_times =
        dynamic_k_size / (AMX_TILE_ROW_NUM * 4 / sizeof(c10::BFloat16));

    c10::BFloat16* __restrict__ a_tile_0 = a_tile;
    c10::BFloat16* __restrict__ a_tile_1 = a_tile + lda * AMX_TILE_ROW_NUM;
    const int64_t a_tile_stride = [&]() {
      if constexpr (phase == AttentionGemmPhase::QK) {
        // q_buffer is prepacked
        return AMX_TILE_ROW_BYTES;
      } else if constexpr (phase == AttentionGemmPhase::PV) {
        // logits_buffer is row-major
        return lda * sizeof(c10::BFloat16);
      } else {
        TORCH_CHECK(false, "Unreachable");
      }
    }();

    c10::BFloat16* __restrict__ b_tile_2 = b_tile;
    c10::BFloat16* __restrict__ b_tile_3 = [&]() {
      if constexpr (phase == AttentionGemmPhase::QK) {
        // k_cache is prepacked
        return b_tile + (k_size * AMX_TILE_ROW_BYTES / 4);
      } else if constexpr (phase == AttentionGemmPhase::PV) {
        // v_cache is prepacked
        return b_tile + (block_size * AMX_TILE_ROW_BYTES / 4);
      } else {
        TORCH_CHECK(false, "Unreachable");
      }
    }();
    // k_cache, v_cache are prepacked
    const int32_t b_tile_stride = AMX_TILE_ROW_BYTES;

    // logits_buffer, output_buffer are not prepacked
    float* __restrict__ c_tile_4 = c_tile;
    float* __restrict__ c_tile_5 =
        c_tile_4 + AMX_TILE_ROW_BYTES / sizeof(float);
    float* __restrict__ c_tile_6 = c_tile + AMX_TILE_ROW_NUM * ldc;
    float* __restrict__ c_tile_7 =
        c_tile_6 + AMX_TILE_ROW_BYTES / sizeof(float);
    const int32_t c_tile_stride = ldc * sizeof(float);

    if (accum_c) {
      _tile_loadd(4, c_tile_4, c_tile_stride);
      _tile_loadd(5, c_tile_5, c_tile_stride);
      _tile_loadd(6, c_tile_6, c_tile_stride);
      _tile_loadd(7, c_tile_7, c_tile_stride);
    } else {
      _tile_zero(4);
      _tile_zero(5);
      _tile_zero(6);
      _tile_zero(7);
    }
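
    // Main reduction loop: each iteration consumes a 32-element slice of the
    // K dimension (one pair of bf16 tiles) and accumulates into the four
    // 16x16 fp32 C tiles (up to 32 rows x 32 columns in total).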
    for (int32_t k = 0; k < k_times; ++k) {
      _tile_loadd(0, a_tile_0, a_tile_stride);
      _tile_stream_loadd(2, b_tile_2, b_tile_stride);
      _tile_dpbf16ps(4, 0, 2);
      _tile_stream_loadd(3, b_tile_3, b_tile_stride);
      _tile_dpbf16ps(5, 0, 3);
      _tile_loadd(1, a_tile_1, a_tile_stride);
      _tile_dpbf16ps(6, 1, 2);
      _tile_dpbf16ps(7, 1, 3);

      // update ptrs
      if constexpr (phase == AttentionGemmPhase::QK) {
        // Q buffer is prepacked
        a_tile_0 += AMX_TILE_BYTES / sizeof(c10::BFloat16);
        a_tile_1 += AMX_TILE_BYTES / sizeof(c10::BFloat16);
      } else if constexpr (phase == AttentionGemmPhase::PV) {
        // P buffer is not prepacked
        a_tile_0 += AMX_TILE_ROW_BYTES / sizeof(c10::BFloat16);
        a_tile_1 += AMX_TILE_ROW_BYTES / sizeof(c10::BFloat16);
      } else {
        TORCH_CHECK(false, "Unreachable");
      }
      b_tile_2 += AMX_TILE_BYTES / sizeof(c10::BFloat16);
      b_tile_3 += AMX_TILE_BYTES / sizeof(c10::BFloat16);
    }

    _tile_stored(4, c_tile_4, c_tile_stride);
    _tile_stored(5, c_tile_5, c_tile_stride);
    _tile_stored(6, c_tile_6, c_tile_stride);
    _tile_stored(7, c_tile_7, c_tile_stride);
  }

  FORCE_INLINE static void init_tile_config(int32_t m, __tilecfg& config) {
    const int32_t m_0 = AMX_TILE_ROW_NUM;
    const int32_t m_1 = m - AMX_TILE_ROW_NUM;
    config.rows[0] = m_0;
    config.rows[1] = m_1;
    config.rows[2] = AMX_TILE_ROW_NUM;
    config.rows[3] = AMX_TILE_ROW_NUM;
    config.rows[4] = m_0;
    config.rows[5] = m_0;
    config.rows[6] = m_1;
    config.rows[7] = m_1;
    _tile_loadconfig(&config);
  }
};

// 1-2-2 pattern, for 0 < m <= 16
// TILE 0, (1): load A matrix, use extra 1 tile for prefetch, row num should be
// m, m
// TILE 2, 3, (4, 5): load B matrix, use extra 2 tiles for prefetch, row num
// should be 16
// TILE 6, 7, (6, 7): store results C matrix, row num should be m
template <typename kv_cache_t>
class TileGemm122 {
 public:
  template <AttentionGemmPhase phase, int32_t k_size>
  FORCE_INLINE static void gemm(const int32_t m_size,
                                void* __restrict__ a_tile,
                                void* __restrict__ b_tile,
                                float* __restrict__ c_tile, const int64_t lda,
                                const int64_t ldb, const int64_t ldc,
                                const int32_t block_size,
                                const int32_t dynamic_k_size,
                                const bool accum_c) {
    TORCH_CHECK(false, "Unsupported kv cache type for TileGemm122");
  }

  FORCE_INLINE static void init_tile_config(int32_t m, __tilecfg& config) {
    TORCH_CHECK(false, "Unsupported kv cache type for TileGemm122");
  }
};

template <>
class TileGemm122<c10::BFloat16> {
 public:
  template <AttentionGemmPhase phase, int32_t k_size>
  FORCE_INLINE static void gemm(const int32_t m_size,
                                c10::BFloat16* __restrict__ a_tile,
                                c10::BFloat16* __restrict__ b_tile,
                                float* __restrict__ c_tile, const int64_t lda,
                                const int64_t ldb, const int64_t ldc,
                                const int32_t block_size,
                                const int32_t dynamic_k_size,
                                const bool accum_c) {
    c10::BFloat16* __restrict__ a_tile_0 = a_tile;
    c10::BFloat16* __restrict__ a_tile_1 = [&]() {
      if constexpr (phase == AttentionGemmPhase::QK) {
        // q_buffer is prepacked
        return a_tile + AMX_TILE_BYTES / sizeof(c10::BFloat16);
      } else if constexpr (phase == AttentionGemmPhase::PV) {
        // logits_buffer is row-major
        return a_tile + AMX_TILE_ROW_BYTES / sizeof(c10::BFloat16);
      } else {
        TORCH_CHECK(false, "Unreachable");
      }
    }();
    const int64_t a_tile_stride = [&]() {
      if constexpr (phase == AttentionGemmPhase::QK) {
        // q_buffer is prepacked
        return AMX_TILE_ROW_BYTES;
      } else if constexpr (phase == AttentionGemmPhase::PV) {
        // logits_buffer is row-major
        return lda * sizeof(c10::BFloat16);
      } else {
        TORCH_CHECK(false, "Unreachable");
      }
    }();

    c10::BFloat16* __restrict__ b_tile_2 = b_tile;
    c10::BFloat16* __restrict__ b_tile_3 = [&]() {
      if constexpr (phase == AttentionGemmPhase::QK) {
        // k_cache is prepacked
        return b_tile + (k_size * AMX_TILE_ROW_BYTES / 4);
      } else if constexpr (phase == AttentionGemmPhase::PV) {
        // v_cache is prepacked
        return b_tile + (block_size * AMX_TILE_ROW_BYTES / 4);
      } else {
        TORCH_CHECK(false, "Unreachable");
      }
    }();
    c10::BFloat16* __restrict__ b_tile_4 =
        b_tile_2 + AMX_TILE_BYTES / sizeof(c10::BFloat16);
    c10::BFloat16* __restrict__ b_tile_5 =
        b_tile_3 + AMX_TILE_BYTES / sizeof(c10::BFloat16);
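    // b_tile_4/b_tile_5 point one prepacked AMX tile past b_tile_2/b_tile_3,
    // so the unrolled loop below streams two consecutive K panels of the
    // prepacked cache into separate tile registers (the "prefetch" tiles).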
    int64_t b_stride = AMX_TILE_ROW_BYTES;

    float* __restrict__ c_tile_6 = c_tile;
    float* __restrict__ c_tile_7 = c_tile + AMX_TILE_ROW_BYTES / sizeof(float);
    int64_t c_stride = ldc * sizeof(float);

    const int32_t k_times =
        dynamic_k_size / (AMX_TILE_ROW_NUM * 4 / sizeof(c10::BFloat16));
    const int32_t k_group_times = k_times / 2;
    const bool has_tail = (k_times % 2 == 1);

    if (accum_c) {
      _tile_loadd(6, c_tile_6, c_stride);
      _tile_loadd(7, c_tile_7, c_stride);
    } else {
      _tile_zero(6);
      _tile_zero(7);
    }

    for (int32_t k = 0; k < k_group_times; ++k) {
      _tile_loadd(0, a_tile_0, a_tile_stride);
      _tile_stream_loadd(2, b_tile_2, b_stride);
      _tile_dpbf16ps(6, 0, 2);
      _tile_stream_loadd(3, b_tile_3, b_stride);
      _tile_dpbf16ps(7, 0, 3);
      _tile_loadd(1, a_tile_1, a_tile_stride);
      _tile_stream_loadd(4, b_tile_4, b_stride);
      _tile_dpbf16ps(6, 1, 4);
      _tile_stream_loadd(5, b_tile_5, b_stride);
      _tile_dpbf16ps(7, 1, 5);

      // update ptrs
      if constexpr (phase == AttentionGemmPhase::QK) {
        // Q buffer is prepacked
        a_tile_0 += 2 * AMX_TILE_BYTES / sizeof(c10::BFloat16);
        a_tile_1 += 2 * AMX_TILE_BYTES / sizeof(c10::BFloat16);
      } else if constexpr (phase == AttentionGemmPhase::PV) {
        // P buffer is not prepacked
        a_tile_0 += 2 * AMX_TILE_ROW_BYTES / sizeof(c10::BFloat16);
        a_tile_1 += 2 * AMX_TILE_ROW_BYTES / sizeof(c10::BFloat16);
      }
      b_tile_2 += 2 * AMX_TILE_BYTES / sizeof(c10::BFloat16);
      b_tile_3 += 2 * AMX_TILE_BYTES / sizeof(c10::BFloat16);
      b_tile_4 += 2 * AMX_TILE_BYTES / sizeof(c10::BFloat16);
      b_tile_5 += 2 * AMX_TILE_BYTES / sizeof(c10::BFloat16);
    }

    if (has_tail) {
      _tile_loadd(0, a_tile_0, a_tile_stride);
      _tile_stream_loadd(2, b_tile_2, b_stride);
      _tile_dpbf16ps(6, 0, 2);
      _tile_stream_loadd(3, b_tile_3, b_stride);
      _tile_dpbf16ps(7, 0, 3);
    }

    _tile_stored(6, c_tile_6, c_stride);
    _tile_stored(7, c_tile_7, c_stride);
  }

  FORCE_INLINE static void init_tile_config(int32_t m, __tilecfg& config) {
    config.rows[0] = m;
    config.rows[1] = m;
    config.rows[2] = AMX_TILE_ROW_NUM;
    config.rows[3] = AMX_TILE_ROW_NUM;
    config.rows[4] = AMX_TILE_ROW_NUM;
    config.rows[5] = AMX_TILE_ROW_NUM;
    config.rows[6] = m;
    config.rows[7] = m;
    _tile_loadconfig(&config);
  }
};
}  // namespace

template <typename scalar_t, int64_t head_dim>
class AttentionImpl<scalar_t, head_dim, ISA::AMX> {
 public:
  using query_t = scalar_t;
  using q_buffer_t = scalar_t;
  using kv_cache_t = scalar_t;
  using logits_buffer_t = float;
  using partial_output_buffer_t = float;
  using prob_buffer_t = scalar_t;

  // KV token num unit of QK and PV phases
  constexpr static int64_t BlockSizeAlignment =
      AMX_TILE_ROW_BYTES / sizeof(kv_cache_t);
  // headdim num unit of PV phase
  constexpr static int64_t HeadDimAlignment = 2 * (AMX_TILE_ROW_BYTES / 4);
  constexpr static int64_t MaxQHeadNumPerIteration = 32;
  constexpr static int64_t HeadDim = head_dim;
  constexpr static ISA ISAType = ISA::AMX;
  constexpr static bool scale_on_logits = true;

 public:
  AttentionImpl() : current_q_head_num_(0) {
    // Use all columns in AMX tiles
    vec_op::unroll_loop<int, 16>(
        [&](int i) { amx_tile_config_.colsb[i] = 64; });
  }

  ~AttentionImpl() { _tile_release(); }

  template