From 9843e332da56307597deb2739bb83b85c18c5dde Mon Sep 17 00:00:00 2001 From: Elham Date: Fri, 5 Dec 2025 08:09:20 -0500 Subject: [PATCH] [CPU][Perf] Add fast vectorized exp impl from Arm Optimized Routines (#30068) Signed-off-by: Ubuntu Signed-off-by: Elham Harirpoush Co-authored-by: Ubuntu --- csrc/cpu/cpu_attn_impl.hpp | 13 ---------- csrc/cpu/cpu_attn_macros.h | 50 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 50 insertions(+), 13 deletions(-) diff --git a/csrc/cpu/cpu_attn_impl.hpp b/csrc/cpu/cpu_attn_impl.hpp index 98f55d7c014be..02164ed3666e3 100644 --- a/csrc/cpu/cpu_attn_impl.hpp +++ b/csrc/cpu/cpu_attn_impl.hpp @@ -1246,14 +1246,8 @@ class AttentionMainLoop { // rescale sum and partial outputs if (need_rescale) { // compute rescale factor -#ifdef DEFINE_FAST_EXP - vec_op::FP32Vec16 rescale_factor_vec(rescale_factor); - rescale_factor_vec = fast_exp(rescale_factor_vec); - rescale_factor = rescale_factor_vec.get_last_elem(); -#else rescale_factor = std::exp(rescale_factor); vec_op::FP32Vec16 rescale_factor_vec(rescale_factor); -#endif // rescale sum new_sum_val += rescale_factor * init_sum_val; @@ -1889,15 +1883,8 @@ class AttentionMainLoop { : curr_output_buffer; float rescale_factor = final_max > curr_max ? curr_max - final_max : final_max - curr_max; - -#ifdef DEFINE_FAST_EXP - vec_op::FP32Vec16 rescale_factor_vec(rescale_factor); - rescale_factor_vec = fast_exp(rescale_factor_vec); - rescale_factor = rescale_factor_vec.get_last_elem(); -#else rescale_factor = std::exp(rescale_factor); vec_op::FP32Vec16 rescale_factor_vec(rescale_factor); -#endif local_sum[head_idx] = final_max > curr_max ? 
final_sum + rescale_factor * curr_sum
diff --git a/csrc/cpu/cpu_attn_macros.h b/csrc/cpu/cpu_attn_macros.h
index 6458e43419370..35716a0790ab3 100644
--- a/csrc/cpu/cpu_attn_macros.h
+++ b/csrc/cpu/cpu_attn_macros.h
@@ -60,4 +60,54 @@
 
 #endif
 
+#ifdef __aarch64__
+  // Implementation copied from Arm Optimized Routines (expf AdvSIMD)
+  // https://github.com/ARM-software/optimized-routines/blob/master/math/aarch64/advsimd/expf.c
+  #include <limits>
+  #define DEFINE_FAST_EXP \
+    const float32x4_t inv_ln2 = vdupq_n_f32(0x1.715476p+0f); \
+    const float ln2_hi = 0x1.62e4p-1f; \
+    const float ln2_lo = 0x1.7f7d1cp-20f; \
+    const float c0 = 0x1.0e4020p-7f; \
+    const float c2 = 0x1.555e66p-3f; \
+    const float32x4_t ln2_c02 = {ln2_hi, ln2_lo, c0, c2}; \
+    const uint32x4_t exponent_bias = vdupq_n_u32(0x3f800000); \
+    const float32x4_t c1 = vdupq_n_f32(0x1.573e2ep-5f); \
+    const float32x4_t c3 = vdupq_n_f32(0x1.fffdb6p-2f); \
+    const float32x4_t c4 = vdupq_n_f32(0x1.ffffecp-1f); \
+    const float32x4_t pos_special_bound = vdupq_n_f32(0x1.5d5e2ap+6f); \
+    const float32x4_t neg_special_bound = vnegq_f32(pos_special_bound); \
+    const float32x4_t inf = \
+        vdupq_n_f32(std::numeric_limits<float>::infinity()); \
+    const float32x4_t zero = vdupq_n_f32(0.0f); \
+    auto neon_expf = [&](float32x4_t values) __attribute__((always_inline)) { \
+      float32x4_t n = vrndaq_f32(vmulq_f32(values, inv_ln2)); \
+      float32x4_t r = vfmsq_laneq_f32(values, n, ln2_c02, 0); \
+      r = vfmsq_laneq_f32(r, n, ln2_c02, 1); \
+      uint32x4_t e = vshlq_n_u32(vreinterpretq_u32_s32(vcvtq_s32_f32(n)), 23); \
+      float32x4_t scale = vreinterpretq_f32_u32(vaddq_u32(e, exponent_bias)); \
+      float32x4_t r2 = vmulq_f32(r, r); \
+      float32x4_t p = vfmaq_laneq_f32(c1, r, ln2_c02, 2); \
+      float32x4_t q = vfmaq_laneq_f32(c3, r, ln2_c02, 3); \
+      q = vfmaq_f32(q, p, r2); \
+      p = vmulq_f32(c4, r); \
+      float32x4_t poly = vfmaq_f32(p, q, r2); \
+      poly = vfmaq_f32(scale, poly, scale); \
+      const uint32x4_t hi_mask = vcgeq_f32(values, pos_special_bound); \
+      const uint32x4_t lo_mask = vcleq_f32(values, neg_special_bound); \
+      poly = vbslq_f32(hi_mask, inf, poly); \
+      return vbslq_f32(lo_mask, zero, poly); \
+    }; \
+    auto fast_exp = [&](vec_op::FP32Vec16& vec) \
+                        __attribute__((always_inline)) { \
+      float32x4x4_t result; \
+      result.val[0] = neon_expf(vec.reg.val[0]); \
+      result.val[1] = neon_expf(vec.reg.val[1]); \
+      result.val[2] = neon_expf(vec.reg.val[2]); \
+      result.val[3] = neon_expf(vec.reg.val[3]); \
+      return vec_op::FP32Vec16(result); \
+    };
+
+#endif // __aarch64__
+
 #endif
\ No newline at end of file