[CPU][IBM Z] Fix BF16 support and vectorize math operations for s390x (#28926)

Signed-off-by: Rehan Khan <Rehan.Khan7@ibm.com>
R3hankhan 2025-11-24 17:38:09 +05:30 committed by GitHub
parent eca7a8fb59
commit 4de87866a8
2 changed files with 531 additions and 57 deletions

View File

@@ -847,7 +847,7 @@ struct VecTypeTrait<c10::BFloat16> {
};
#endif
#if !defined(__powerpc__)
#if !defined(__powerpc__) && !defined(__s390x__)
template <>
struct VecTypeTrait<c10::Half> {
using vec_t = vec_op::FP16Vec16;

View File

@@ -4,6 +4,7 @@
#include <vecintrin.h>
#include <cmath>
#include <limits>
#include <torch/all.h>
namespace vec_op {
@@ -174,8 +175,9 @@ struct FP32Vec8 : public Vec<FP32Vec8> {
}
explicit FP32Vec8(const BF16Vec8& v) {
reg.val[0] = (__vector float)vec_mergeh(zero, v.reg);
reg.val[1] = (__vector float)vec_mergel(zero, v.reg);
// On big-endian s390x, place BF16 first to get correct byte order
reg.val[0] = (__vector float)vec_mergeh(v.reg, zero);
reg.val[1] = (__vector float)vec_mergel(v.reg, zero);
}
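For reference, a BF16 value is simply the top 16 bits of the corresponding FP32 value, so widening amounts to a 16-bit left shift of the raw bits; on big-endian s390x the BF16 half-word must land in the high half of each 32-bit lane, which is exactly what the swapped merge order achieves. A minimal scalar sketch (portable C++, illustration only, not part of the patch):

#include <cstdint>
#include <cstring>

// Scalar reference for BF16 -> FP32 widening: shift the raw bits left by 16.
inline float bf16_to_fp32_ref(uint16_t bf16_bits) {
  uint32_t fp32_bits = static_cast<uint32_t>(bf16_bits) << 16;
  float out;
  std::memcpy(&out, &fp32_bits, sizeof(out));
  return out;
}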
float reduce_sum() const {
@@ -189,51 +191,257 @@ struct FP32Vec8 : public Vec<FP32Vec8> {
}
FP32Vec8 exp() const {
// TODO: Vectorize this
AliasReg ar;
ar.reg = reg;
f32x4x4_t ret;
ret.val[0][0] = std::exp(ar.values[0]);
ret.val[0][1] = std::exp(ar.values[1]);
ret.val[0][2] = std::exp(ar.values[2]);
ret.val[0][3] = std::exp(ar.values[3]);
ret.val[1][0] = std::exp(ar.values[4]);
ret.val[1][1] = std::exp(ar.values[5]);
ret.val[1][2] = std::exp(ar.values[6]);
ret.val[1][3] = std::exp(ar.values[7]);
return FP32Vec8(f32x4x2_t({ret.val[0], ret.val[1]}));
f32x4x2_t out;
const __vector float log2e = vec_splats(1.44269504088896341f);
const __vector float one = vec_splats(1.0f);
const __vector float min_x = vec_splats(-87.3f);
const __vector float max_x = vec_splats(88.7f);
// 5th-degree minimax polynomial for 2^r (r in [0,1))
const __vector float c1 = vec_splats(0.6931471805599453f);
const __vector float c2 = vec_splats(0.240226506959101f);
const __vector float c3 = vec_splats(0.05550410866482158f);
const __vector float c4 = vec_splats(0.009618129107628477f);
const __vector float c5 = vec_splats(0.0013333558146428443f);
for (int i = 0; i < 2; i++) {
__vector float x = reg.val[i];
x = vec_max(x, min_x);
x = vec_min(x, max_x);
__vector float y = vec_mul(x, log2e);
__vector float kf = vec_floor(y);
__vector float r = vec_sub(y, kf);
__vector signed int k = vec_signed(kf);
const __vector signed int min_k = vec_splats((signed int)-126);
const __vector signed int max_k = vec_splats((signed int)127);
k = vec_min(vec_max(k, min_k), max_k);
// Build 2^k from exponent bits
__vector signed int exp_int = vec_add(k, vec_splats((signed int)127));
__vector unsigned int bits = (__vector unsigned int)exp_int;
bits = vec_sl(bits, vec_splats((unsigned int)23));
__vector float pow2k = (__vector float)bits;
// Improved minimax polynomial
__vector float poly = vec_madd(c5, r, c4);
poly = vec_madd(poly, r, c3);
poly = vec_madd(poly, r, c2);
poly = vec_madd(poly, r, c1);
poly = vec_madd(poly, r, one);
out.val[i] = vec_mul(pow2k, poly);
}
return FP32Vec8(out);
}
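The loop implements textbook range reduction: with y = x * log2(e), k = floor(y) and r = y - k in [0, 1), exp(x) = 2^k * 2^r; 2^k is assembled directly in the FP32 exponent field and 2^r comes from the degree-5 polynomial. A scalar sketch using the same constants (illustration only, not part of the patch):

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstring>

// Scalar reference for the vectorized exp() above.
inline float exp_ref(float x) {
  x = std::min(88.7f, std::max(-87.3f, x));   // same clamp as the kernel
  float y = x * 1.44269504088896341f;         // x * log2(e)
  float kf = std::floor(y);
  float r = y - kf;                           // fractional part, in [0, 1)
  int k = std::min(127, std::max(-126, static_cast<int>(kf)));
  uint32_t bits = static_cast<uint32_t>(k + 127) << 23;  // 2^k via exponent bits
  float pow2k;
  std::memcpy(&pow2k, &bits, sizeof(pow2k));
  // Degree-5 polynomial for 2^r, evaluated in Horner form.
  float poly = 0.0013333558146428443f;
  poly = poly * r + 0.009618129107628477f;
  poly = poly * r + 0.05550410866482158f;
  poly = poly * r + 0.240226506959101f;
  poly = poly * r + 0.6931471805599453f;
  poly = poly * r + 1.0f;
  return pow2k * poly;
}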
FP32Vec8 tanh() const {
// TODO: Vectorize this
AliasReg ar;
ar.reg = reg;
f32x4x4_t ret;
ret.val[0][0] = std::tanh(ar.values[0]);
ret.val[0][1] = std::tanh(ar.values[1]);
ret.val[0][2] = std::tanh(ar.values[2]);
ret.val[0][3] = std::tanh(ar.values[3]);
ret.val[1][0] = std::tanh(ar.values[4]);
ret.val[1][1] = std::tanh(ar.values[5]);
ret.val[1][2] = std::tanh(ar.values[6]);
ret.val[1][3] = std::tanh(ar.values[7]);
return FP32Vec8(f32x4x2_t({ret.val[0], ret.val[1]}));
// tanh(x) = (exp(2x) - 1) / (exp(2x) + 1)
const __vector float one = vec_splats(1.0f);
const __vector float two = vec_splats(2.0f);
const __vector float zero = vec_splats(0.0f);
const __vector float sat =
vec_splats(9.0f); // beyond this, tanh(x) ~ sign(x)
f32x4x2_t out;
for (int i = 0; i < 2; i++) {
__vector float x = reg.val[i];
__vector float ax = vec_abs(x);
// sign(x): +1 or -1
__vector float sign = vec_sel(vec_splats(-1.0f), one, vec_cmpgt(x, zero));
// saturation mask: |x| > sat
__vector __bool int saturated = vec_cmpgt(ax, sat);
// 2x
__vector float two_x = vec_mul(x, two);
// Broadcast 2x into both 4-float halves of a temporary FP32Vec8 and reuse exp()
f32x4x2_t tmp;
tmp.val[0] = two_x;
tmp.val[1] = two_x;
FP32Vec8 exp_2x_vec(tmp);
FP32Vec8 e2x = exp_2x_vec.exp();
__vector float e = e2x.reg.val[i];
// tanh(x) = (e - 1) / (e + 1)
__vector float num = vec_sub(e, one);
__vector float den = vec_add(e, one);
__vector float t = vec_div(num, den);
// For large |x|, clamp to sign(x)
out.val[i] = vec_sel(t, sign, saturated);
}
return FP32Vec8(out);
}
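A scalar sketch of the same identity and saturation cutoff (illustration only; tanh(9) is already within about 3e-8 of 1, below FP32 resolution around 1):

#include <cmath>

// Scalar reference for the vectorized tanh() above.
inline float tanh_ref(float x) {
  if (std::fabs(x) > 9.0f) return x > 0.0f ? 1.0f : -1.0f;  // saturate to sign(x)
  float e2x = std::exp(2.0f * x);
  return (e2x - 1.0f) / (e2x + 1.0f);
}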
FP32Vec8 er() const {
// TODO: Vectorize this
AliasReg ar;
ar.reg = reg;
f32x4x4_t ret;
ret.val[0][0] = std::erf(ar.values[0]);
ret.val[0][1] = std::erf(ar.values[1]);
ret.val[0][2] = std::erf(ar.values[2]);
ret.val[0][3] = std::erf(ar.values[3]);
ret.val[1][0] = std::erf(ar.values[4]);
ret.val[1][1] = std::erf(ar.values[5]);
ret.val[1][2] = std::erf(ar.values[6]);
ret.val[1][3] = std::erf(ar.values[7]);
return FP32Vec8(f32x4x2_t({ret.val[0], ret.val[1]}));
// A&S 7.1.26 approximation:
// erf(x) = sign(x) * (1 - ((((a5*t + a4)*t + a3)*t + a2)*t + a1) * t * exp(-x^2)),
// where t = 1 / (1 + p*|x|), p = 0.3275911
const __vector float one = vec_splats(1.0f);
const __vector float zero = vec_splats(0.0f);
const __vector float p = vec_splats(0.3275911f);
// Polynomial coeffs
const __vector float a1 = vec_splats(0.254829592f);
const __vector float a2 = vec_splats(-0.284496736f);
const __vector float a3 = vec_splats(1.421413741f);
const __vector float a4 = vec_splats(-1.453152027f);
const __vector float a5 = vec_splats(1.061405429f);
// Threshold where erf(x) ~ sign(x)
const __vector float sat = vec_splats(6.0f);
f32x4x2_t out;
for (int lane = 0; lane < 2; lane++) {
__vector float x = reg.val[lane];
__vector float ax = vec_abs(x);
// sign(x)
__vector float sign = vec_sel(vec_splats(-1.0f), one, vec_cmpgt(x, zero));
// |x| > 6 → erf(x) = ±1
__vector __bool int saturated = vec_cmpgt(ax, sat);
// t = 1 / (1 + p * |x|)
__vector float t = vec_madd(p, ax, one);
t = vec_div(one, t);
// poly = a5
__vector float poly = a5;
poly = vec_madd(poly, t, a4);
poly = vec_madd(poly, t, a3);
poly = vec_madd(poly, t, a2);
poly = vec_madd(poly, t, a1);
// full polynomial: poly = poly * t
poly = vec_mul(poly, t);
// Compute exp(-x^2)
__vector float x2 = vec_mul(x, x);
__vector float neg_x2 = vec_neg(x2);
f32x4x2_t tmp;
tmp.val[0] = neg_x2;
tmp.val[1] = neg_x2;
FP32Vec8 exp_neg_x2(tmp);
FP32Vec8 e = exp_neg_x2.exp();
__vector float ex = e.reg.val[lane];
// erf(x) = sign * (1 - poly * exp(-x^2))
__vector float term = vec_mul(poly, ex);
__vector float y = vec_sub(one, term);
y = vec_mul(y, sign);
// saturated → ±1 (sign is already ±1, so select it directly)
out.val[lane] = vec_sel(y, sign, saturated);
}
return FP32Vec8(out);
}
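A scalar sketch of the same Abramowitz & Stegun 7.1.26 formula, whose maximum absolute error (about 1.5e-7) sits comfortably inside FP32 tolerance (illustration only):

#include <cmath>

// Scalar reference for the vectorized er() (erf) above.
inline float erf_ref(float x) {
  float sign = x < 0.0f ? -1.0f : 1.0f;
  float ax = std::fabs(x);
  if (ax > 6.0f) return sign;  // erf(x) is ±1 to FP32 precision here
  float t = 1.0f / (1.0f + 0.3275911f * ax);
  float poly = ((((1.061405429f * t - 1.453152027f) * t + 1.421413741f) * t -
                 0.284496736f) * t + 0.254829592f) * t;
  return sign * (1.0f - poly * std::exp(-ax * ax));
}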
// Elementwise sigmoid(x) = 1 / (1 + exp(-x))
FP32Vec8 sigmoid() const {
const __vector float one = vec_splats(1.0f);
f32x4x2_t neg;
for (int i = 0; i < 2; ++i) {
neg.val[i] = vec_neg(reg.val[i]);
}
FP32Vec8 neg_x(neg);
FP32Vec8 e = neg_x.exp(); // exp(-x)
f32x4x2_t denom;
for (int i = 0; i < 2; ++i) {
denom.val[i] = vec_add(one, e.reg.val[i]);
}
FP32Vec8 denom_vec(denom);
FP32Vec8 one_vec(1.0f);
return one_vec / denom_vec;
}
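The scalar equivalent, for reference (illustration only); because exp() clamps its argument, the vector denominator stays finite over the whole FP32 range:

#include <cmath>

// Scalar reference for sigmoid(): underflows cleanly to 0 for very
// negative x and saturates to 1 for very positive x.
inline float sigmoid_ref(float x) {
  return 1.0f / (1.0f + std::exp(-x));
}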
// Tanh-based GELU:
// gelu(x) = 0.5 * x * (1 + tanh(√(2/π) * (x + 0.044715 * x^3)))
FP32Vec8 gelu_tanh() const {
const __vector float k_s2pi = vec_splats(0.7978845608028654f); // √(2/π)
const __vector float k_0_0447 = vec_splats(0.044715f);
f32x4x2_t x2, x3, inner;
for (int i = 0; i < 2; ++i) {
__vector float x = reg.val[i];
x2.val[i] = vec_mul(x, x); // x^2
x3.val[i] = vec_mul(x2.val[i], x); // x^3
__vector float t = vec_madd(k_0_0447, x3.val[i], x); // x + 0.044715*x^3
inner.val[i] = vec_mul(k_s2pi, t); // √(2/π)*(...)
}
FP32Vec8 inner_vec(inner);
FP32Vec8 t = inner_vec.tanh(); // tanh part
FP32Vec8 one_vec(1.0f);
FP32Vec8 half_vec(0.5f);
FP32Vec8 x_vec(*this);
return x_vec * half_vec * (one_vec + t);
}
// Erf-based GELU:
// gelu(x) = 0.5 * x * (1 + erf(x / √2))
FP32Vec8 gelu_erf() const {
const __vector float inv_sqrt2 = vec_splats(0.7071067811865476f); // 1/√2
FP32Vec8 x_vec(*this);
f32x4x2_t scaled;
for (int i = 0; i < 2; ++i) {
scaled.val[i] = vec_mul(reg.val[i], inv_sqrt2);
}
FP32Vec8 x_scaled(scaled);
FP32Vec8 erf_x = x_scaled.er();
FP32Vec8 one_vec(1.0f);
FP32Vec8 half_vec(0.5f);
return x_vec * half_vec * (one_vec + erf_x);
}
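Scalar sketches of both GELU variants using the same constants (illustration only, not part of the patch):

#include <cmath>

// Tanh-based GELU: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
inline float gelu_tanh_ref(float x) {
  float inner = 0.7978845608028654f * (x + 0.044715f * x * x * x);
  return 0.5f * x * (1.0f + std::tanh(inner));
}

// Erf-based GELU: 0.5 * x * (1 + erf(x / sqrt(2)))
inline float gelu_erf_ref(float x) {
  return 0.5f * x * (1.0f + std::erf(x * 0.7071067811865476f));
}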
// Elementwise reciprocal: 1/x (scalar per lane, for correctness)
FP32Vec8 rcp() const {
AliasReg in, out;
in.reg = reg;
for (int i = 0; i < VEC_ELEM_NUM; ++i) {
out.values[i] = 1.0f / in.values[i];
}
return FP32Vec8(out.reg);
}
// Elementwise rsqrt(x) = 1 / sqrt(x) (scalar per lane, for correctness)
FP32Vec8 rsqrt() const {
AliasReg in, out;
in.reg = reg;
for (int i = 0; i < VEC_ELEM_NUM; ++i) {
out.values[i] = 1.0f / std::sqrt(in.values[i]);
}
return FP32Vec8(out.reg);
}
FP32Vec8 operator*(const FP32Vec8& b) const {
@@ -316,10 +524,11 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
}
explicit FP32Vec16(const BF16Vec16& v) {
reg.val[0] = (__vector float)vec_mergeh(zero, v.reg.val[0]);
reg.val[1] = (__vector float)vec_mergel(zero, v.reg.val[0]);
reg.val[2] = (__vector float)vec_mergeh(zero, v.reg.val[1]);
reg.val[3] = (__vector float)vec_mergel(zero, v.reg.val[1]);
// On big-endian s390x, place BF16 first to get correct byte order
reg.val[0] = (__vector float)vec_mergeh(v.reg.val[0], zero);
reg.val[1] = (__vector float)vec_mergel(v.reg.val[0], zero);
reg.val[2] = (__vector float)vec_mergeh(v.reg.val[1], zero);
reg.val[3] = (__vector float)vec_mergel(v.reg.val[1], zero);
}
explicit FP32Vec16(const BF16Vec8& v) : FP32Vec16(FP32Vec8(v)) {}
@@ -376,6 +585,23 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
return result;
}
FP32Vec16 max(const FP32Vec16& b) const {
return FP32Vec16(f32x4x4_t({vec_max(reg.val[0], b.reg.val[0]),
vec_max(reg.val[1], b.reg.val[1]),
vec_max(reg.val[2], b.reg.val[2]),
vec_max(reg.val[3], b.reg.val[3])}));
}
float reduce_max() const {
AliasReg ar;
ar.reg = reg;
float result = ar.values[0];
unroll_loop<int, VEC_ELEM_NUM>([&result, &ar](int i) {
if (ar.values[i] > result) result = ar.values[i];
});
return result;
}
void save(float* ptr) const {
vec_xst(reg.val[0], 0, ptr);
vec_xst(reg.val[1], 16, ptr);
@@ -402,15 +628,14 @@ struct VecType<c10::BFloat16> {
using vec_type = BF16Vec8;
};
// On s390x, FP16 (Half) is not natively supported; fall back to FP32 vectors
using FP16Vec16 = FP32Vec16;
template <typename T>
void storeFP32(float v, T* ptr) {
*ptr = v;
}
inline void fma(FP32Vec16& acc, FP32Vec16& a, FP32Vec16& b) {
acc = acc + a * b;
}
namespace c10 {
struct BFloat16 {
uint16_t value; // Assume BFloat16 is defined as a struct containing a 16-bit
@@ -429,6 +654,79 @@ inline void storeFP32<c10::BFloat16>(float v, c10::BFloat16* ptr) {
#define __VEC_CLASS_FP_NAN (1 << 6)
#endif
// Optimized FMA (Fused Multiply-Add) implementations using IBM Z vector
// intrinsics
// FP32Vec4 FMA: acc = acc + (a * b) or equivalently acc = fma(a, b, acc)
FORCE_INLINE void fma(FP32Vec4& acc, const FP32Vec4& a, const FP32Vec4& b) {
acc.reg = vec_madd(a.reg, b.reg, acc.reg);
}
// FP32Vec8 FMA: acc = acc + (a * b)
FORCE_INLINE void fma(FP32Vec8& acc, const FP32Vec8& a, const FP32Vec8& b) {
acc.reg.val[0] = vec_madd(a.reg.val[0], b.reg.val[0], acc.reg.val[0]);
acc.reg.val[1] = vec_madd(a.reg.val[1], b.reg.val[1], acc.reg.val[1]);
}
// FP32Vec16 FMA: acc = acc + (a * b)
FORCE_INLINE void fma(FP32Vec16& acc, const FP32Vec16& a, const FP32Vec16& b) {
acc.reg.val[0] = vec_madd(a.reg.val[0], b.reg.val[0], acc.reg.val[0]);
acc.reg.val[1] = vec_madd(a.reg.val[1], b.reg.val[1], acc.reg.val[1]);
acc.reg.val[2] = vec_madd(a.reg.val[2], b.reg.val[2], acc.reg.val[2]);
acc.reg.val[3] = vec_madd(a.reg.val[3], b.reg.val[3], acc.reg.val[3]);
}
// Multiply-Subtract: acc = (a * b) - acc (vec_msub computes a*b - c)
FORCE_INLINE void fms(FP32Vec4& acc, const FP32Vec4& a, const FP32Vec4& b) {
acc.reg = vec_msub(a.reg, b.reg, acc.reg);
}
FORCE_INLINE void fms(FP32Vec8& acc, const FP32Vec8& a, const FP32Vec8& b) {
acc.reg.val[0] = vec_msub(a.reg.val[0], b.reg.val[0], acc.reg.val[0]);
acc.reg.val[1] = vec_msub(a.reg.val[1], b.reg.val[1], acc.reg.val[1]);
}
FORCE_INLINE void fms(FP32Vec16& acc, const FP32Vec16& a, const FP32Vec16& b) {
acc.reg.val[0] = vec_msub(a.reg.val[0], b.reg.val[0], acc.reg.val[0]);
acc.reg.val[1] = vec_msub(a.reg.val[1], b.reg.val[1], acc.reg.val[1]);
acc.reg.val[2] = vec_msub(a.reg.val[2], b.reg.val[2], acc.reg.val[2]);
acc.reg.val[3] = vec_msub(a.reg.val[3], b.reg.val[3], acc.reg.val[3]);
}
// Negative Multiply-Add: acc = -((a * b) + acc) (vec_nmadd negates the whole multiply-add)
FORCE_INLINE void nfma(FP32Vec4& acc, const FP32Vec4& a, const FP32Vec4& b) {
acc.reg = vec_nmadd(a.reg, b.reg, acc.reg);
}
FORCE_INLINE void nfma(FP32Vec8& acc, const FP32Vec8& a, const FP32Vec8& b) {
acc.reg.val[0] = vec_nmadd(a.reg.val[0], b.reg.val[0], acc.reg.val[0]);
acc.reg.val[1] = vec_nmadd(a.reg.val[1], b.reg.val[1], acc.reg.val[1]);
}
FORCE_INLINE void nfma(FP32Vec16& acc, const FP32Vec16& a, const FP32Vec16& b) {
acc.reg.val[0] = vec_nmadd(a.reg.val[0], b.reg.val[0], acc.reg.val[0]);
acc.reg.val[1] = vec_nmadd(a.reg.val[1], b.reg.val[1], acc.reg.val[1]);
acc.reg.val[2] = vec_nmadd(a.reg.val[2], b.reg.val[2], acc.reg.val[2]);
acc.reg.val[3] = vec_nmadd(a.reg.val[3], b.reg.val[3], acc.reg.val[3]);
}
// Negative Multiply-Subtract: acc = -((a * b) - acc), i.e. acc - (a * b)
FORCE_INLINE void nfms(FP32Vec4& acc, const FP32Vec4& a, const FP32Vec4& b) {
acc.reg = vec_nmsub(a.reg, b.reg, acc.reg);
}
FORCE_INLINE void nfms(FP32Vec8& acc, const FP32Vec8& a, const FP32Vec8& b) {
acc.reg.val[0] = vec_nmsub(a.reg.val[0], b.reg.val[0], acc.reg.val[0]);
acc.reg.val[1] = vec_nmsub(a.reg.val[1], b.reg.val[1], acc.reg.val[1]);
}
FORCE_INLINE void nfms(FP32Vec16& acc, const FP32Vec16& a, const FP32Vec16& b) {
acc.reg.val[0] = vec_nmsub(a.reg.val[0], b.reg.val[0], acc.reg.val[0]);
acc.reg.val[1] = vec_nmsub(a.reg.val[1], b.reg.val[1], acc.reg.val[1]);
acc.reg.val[2] = vec_nmsub(a.reg.val[2], b.reg.val[2], acc.reg.val[2]);
acc.reg.val[3] = vec_nmsub(a.reg.val[3], b.reg.val[3], acc.reg.val[3]);
}
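Spelled out per element, and assuming the documented IBM Z semantics of the intrinsics (vec_madd(a,b,c) = a*b + c, vec_msub = a*b - c, vec_nmadd = -(a*b + c), vec_nmsub = -(a*b - c)), the four updates reduce to:

// Scalar equivalents of the accumulator updates above (illustration only).
inline void fma_ref(float& acc, float a, float b) { acc = a * b + acc; }
inline void fms_ref(float& acc, float a, float b) { acc = a * b - acc; }
inline void nfma_ref(float& acc, float a, float b) { acc = -(a * b + acc); }
inline void nfms_ref(float& acc, float a, float b) { acc = acc - a * b; }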
const static __vector unsigned char omask = {2, 3, 6, 7, 10, 11, 14, 15,
18, 19, 22, 23, 26, 27, 30, 31};
const static __vector unsigned int bias = {0x00007fff, 0x00007fff, 0x00007fff,
@@ -441,13 +739,24 @@ const static __vector unsigned int one = {1, 1, 1, 1};
inline BF16Vec8::BF16Vec8(const FP32Vec8& v) {
__vector unsigned int inp0 = (__vector unsigned int)(v.reg.val[0]);
__vector unsigned int inp1 = (__vector unsigned int)(v.reg.val[1]);
__vector unsigned int lsb0 = inp0 >> sh16;
__vector unsigned int lsb1 = inp1 >> sh16;
lsb0 = lsb0 & one;
lsb1 = lsb1 & one;
__vector unsigned int rnd0 = lsb0 + bias;
__vector unsigned int rnd1 = lsb1 + bias;
inp0 = inp0 + rnd0;
inp1 = inp1 + rnd1;
int cc;
__vector __bool int sel0 =
vec_fp_test_data_class(v.reg.val[0], __VEC_CLASS_FP_NAN, &cc);
__vector __bool int sel1 =
vec_fp_test_data_class(v.reg.val[1], __VEC_CLASS_FP_NAN, &cc);
inp0 = vec_sel(inp0, nan, sel0) >> sh16;
inp1 = vec_sel(inp1, nan, sel1) >> sh16;
inp0 = vec_sel(inp0, nan, sel0);
inp1 = vec_sel(inp1, nan, sel1);
inp0 = inp0 >> sh16;
inp1 = inp1 >> sh16;
reg = (__vector signed short)vec_perm(inp0, inp1, omask);
}
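The sequence above is the standard round-to-nearest-even FP32 -> BF16 truncation: add 0x7FFF plus the least significant surviving bit, replace NaNs via the select, then keep the high 16 bits of each word. A scalar sketch (the quiet-NaN pattern below is an assumption for illustration; the vector code substitutes its own `nan` constant):

#include <cmath>
#include <cstdint>
#include <cstring>

// Scalar reference: FP32 -> BF16 with round-to-nearest-even and NaN quieting.
inline uint16_t fp32_to_bf16_rne(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  if (std::isnan(f)) return uint16_t(0x7FC0);  // assumed quiet-NaN pattern
  uint32_t lsb = (bits >> 16) & 1u;            // ties round to even
  bits += 0x00007FFFu + lsb;                   // same bias as the vector code
  return static_cast<uint16_t>(bits >> 16);
}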
@@ -456,6 +765,22 @@ inline BF16Vec16::BF16Vec16(const FP32Vec16& v) {
__vector unsigned int inp1 = (__vector unsigned int)(v.reg.val[1]);
__vector unsigned int inp2 = (__vector unsigned int)(v.reg.val[2]);
__vector unsigned int inp3 = (__vector unsigned int)(v.reg.val[3]);
__vector unsigned int lsb0 = inp0 >> sh16;
__vector unsigned int lsb1 = inp1 >> sh16;
__vector unsigned int lsb2 = inp2 >> sh16;
__vector unsigned int lsb3 = inp3 >> sh16;
lsb0 = lsb0 & one;
lsb1 = lsb1 & one;
lsb2 = lsb2 & one;
lsb3 = lsb3 & one;
__vector unsigned int rnd0 = lsb0 + bias;
__vector unsigned int rnd1 = lsb1 + bias;
__vector unsigned int rnd2 = lsb2 + bias;
__vector unsigned int rnd3 = lsb3 + bias;
inp0 = inp0 + rnd0;
inp1 = inp1 + rnd1;
inp2 = inp2 + rnd2;
inp3 = inp3 + rnd3;
int cc;
__vector __bool int sel0 =
vec_fp_test_data_class(v.reg.val[0], __VEC_CLASS_FP_NAN, &cc);
@@ -465,15 +790,164 @@ inline BF16Vec16::BF16Vec16(const FP32Vec16& v) {
vec_fp_test_data_class(v.reg.val[2], __VEC_CLASS_FP_NAN, &cc);
__vector __bool int sel3 =
vec_fp_test_data_class(v.reg.val[3], __VEC_CLASS_FP_NAN, &cc);
inp0 = vec_sel(inp0, nan, sel0) >> sh16;
inp1 = vec_sel(inp1, nan, sel1) >> sh16;
inp2 = vec_sel(inp2, nan, sel2) >> sh16;
inp3 = vec_sel(inp3, nan, sel3) >> sh16;
inp0 = vec_sel(inp0, nan, sel0);
inp1 = vec_sel(inp1, nan, sel1);
inp2 = vec_sel(inp2, nan, sel2);
inp3 = vec_sel(inp3, nan, sel3);
inp0 = inp0 >> sh16;
inp1 = inp1 >> sh16;
inp2 = inp2 >> sh16;
inp3 = inp3 >> sh16;
reg.val[0] = (__vector signed short)vec_perm(inp0, inp1, omask);
reg.val[1] = (__vector signed short)vec_perm(inp2, inp3, omask);
}
inline void prefetch(const void* addr) { void __dcbt(const void* addr); }
// 1D softmax over `n` elements in `input`, writes result to `output`.
// Uses FP32Vec8 for main body, scalar tail handling.
// n <= 0 is treated as a no-op.
FORCE_INLINE void softmax_fp32vec8(float* output, const float* input, int n) {
if (n <= 0) return;
// ---------- Pass 1: find max ----------
float max_val = -std::numeric_limits<float>::infinity();
int i = 0;
for (; i + FP32Vec8::VEC_ELEM_NUM <= n; i += FP32Vec8::VEC_ELEM_NUM) {
FP32Vec8 v(input + i);
FP32Vec8::AliasReg ar;
ar.reg = v.reg;
for (int j = 0; j < FP32Vec8::VEC_ELEM_NUM; ++j) {
if (ar.values[j] > max_val) max_val = ar.values[j];
}
}
for (; i < n; ++i) {
if (input[i] > max_val) max_val = input[i];
}
// ---------- Pass 2: compute exp(x - max) and sum ----------
float sum = 0.0f;
i = 0;
for (; i + FP32Vec8::VEC_ELEM_NUM <= n; i += FP32Vec8::VEC_ELEM_NUM) {
float tmp[FP32Vec8::VEC_ELEM_NUM];
for (int j = 0; j < FP32Vec8::VEC_ELEM_NUM; ++j) {
tmp[j] = input[i + j] - max_val;
}
FP32Vec8 v(tmp);
FP32Vec8 e = v.exp();
FP32Vec8::AliasReg ar;
ar.reg = e.reg;
for (int j = 0; j < FP32Vec8::VEC_ELEM_NUM; ++j) {
output[i + j] = ar.values[j];
sum += ar.values[j];
}
}
// Tail
for (; i < n; ++i) {
float x = input[i] - max_val;
float ex = std::exp(x); // scalar tail
output[i] = ex;
sum += ex;
}
// ---------- Pass 3: normalize ----------
float inv_sum = 1.0f / sum;
i = 0;
for (; i + FP32Vec8::VEC_ELEM_NUM <= n; i += FP32Vec8::VEC_ELEM_NUM) {
float tmp[FP32Vec8::VEC_ELEM_NUM];
for (int j = 0; j < FP32Vec8::VEC_ELEM_NUM; ++j) {
tmp[j] = output[i + j] * inv_sum;
}
FP32Vec8 v(tmp);
v.save(output + i);
}
for (; i < n; ++i) {
output[i] *= inv_sum;
}
}
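A hypothetical usage sketch (assumes this header is included; n = 10 exercises both the 8-wide vector body and the 2-element scalar tail):

#include <cassert>

void softmax_example() {
  float logits[10] = {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f};
  float probs[10];
  vec_op::softmax_fp32vec8(probs, logits, 10);
  float total = 0.0f;
  for (float p : probs) total += p;
  assert(total > 0.999f && total < 1.001f);  // outputs sum to ~1
}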
// 1D RMSNorm kernel:
// input: x[0..n-1]
// weight: w[0..n-1] (gamma), may be nullptr
// output: y[i] = x[i] * inv_rms * (weight[i] if weight != nullptr else 1)
// eps: small epsilon for numerical stability
FORCE_INLINE void rmsnorm_fp32vec8(float* output, const float* input,
const float* weight, int n, float eps) {
if (n <= 0) return;
// ---------- Pass 1: compute sum of squares ----------
float sum_sq = 0.0f;
int i = 0;
for (; i + FP32Vec8::VEC_ELEM_NUM <= n; i += FP32Vec8::VEC_ELEM_NUM) {
FP32Vec8 x_vec(input + i);
FP32Vec8 sq = x_vec * x_vec;
FP32Vec8::AliasReg ar;
ar.reg = sq.reg;
for (int j = 0; j < FP32Vec8::VEC_ELEM_NUM; ++j) {
sum_sq += ar.values[j];
}
}
// Tail
for (; i < n; ++i) {
float v = input[i];
sum_sq += v * v;
}
float mean_sq = sum_sq / static_cast<float>(n);
float inv_rms = 1.0f / std::sqrt(mean_sq + eps);
// ---------- Pass 2: scale (and apply weight if given) ----------
const float inv_rms_f = inv_rms;
i = 0;
if (weight) {
// with gamma
for (; i + FP32Vec8::VEC_ELEM_NUM <= n; i += FP32Vec8::VEC_ELEM_NUM) {
FP32Vec8 x_vec(input + i);
float wtmp[FP32Vec8::VEC_ELEM_NUM];
for (int j = 0; j < FP32Vec8::VEC_ELEM_NUM; ++j) {
wtmp[j] = weight[i + j];
}
FP32Vec8 w_vec(wtmp);
FP32Vec8 scale_vec(inv_rms_f);
FP32Vec8 y = x_vec * scale_vec * w_vec;
y.save(output + i);
}
for (; i < n; ++i) {
output[i] = input[i] * inv_rms_f * weight[i];
}
} else {
// without gamma
for (; i + FP32Vec8::VEC_ELEM_NUM <= n; i += FP32Vec8::VEC_ELEM_NUM) {
FP32Vec8 x_vec(input + i);
FP32Vec8 scale_vec(inv_rms_f);
FP32Vec8 y = x_vec * scale_vec;
y.save(output + i);
}
for (; i < n; ++i) {
output[i] = input[i] * inv_rms_f;
}
}
}
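A hypothetical usage sketch for the RMSNorm kernel (assumes this header is included; passing nullptr for the weight skips the gamma multiply):

void rmsnorm_example() {
  constexpr int kHidden = 16;
  float x[kHidden], w[kHidden], y[kHidden];
  for (int i = 0; i < kHidden; ++i) {
    x[i] = 0.1f * static_cast<float>(i);
    w[i] = 1.0f;
  }
  vec_op::rmsnorm_fp32vec8(y, x, w, kHidden, 1e-6f);        // with gamma
  vec_op::rmsnorm_fp32vec8(y, x, nullptr, kHidden, 1e-6f);  // without gamma
  // With w == 1 both calls agree: y[i] = x[i] / sqrt(mean(x^2) + eps).
}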
// Prefetch data to cache for better memory access performance
FORCE_INLINE void prefetch(const void* addr) {
__builtin_prefetch(addr, 0, 3); // 0=read, 3=high temporal locality
}
}; // namespace vec_op