diff --git a/csrc/cpu/cpu_types_scalar.hpp b/csrc/cpu/cpu_types_scalar.hpp index 1a9278bc662e5..f9da78283da5e 100644 --- a/csrc/cpu/cpu_types_scalar.hpp +++ b/csrc/cpu/cpu_types_scalar.hpp @@ -26,10 +26,6 @@ namespace vec_op { #define FORCE_INLINE __attribute__((always_inline)) inline -#define __max(a, b) ((a) > (b) ? (a) : (b)) -#define __min(a, b) ((a) < (b) ? (a) : (b)) -#define __abs(a) ((a) < (0) ? (0 - a) : (a)) - typedef struct f16x8_t { uint16_t val[8]; } f16x8_t; @@ -99,7 +95,7 @@ struct FP16Vec16 : public Vec { void save(void* ptr) const { *reinterpret_cast(ptr) = reg; } void save(void* ptr, const int elem_num) const { - int num = __min(elem_num, VEC_ELEM_NUM); + int num = std::min(elem_num, VEC_ELEM_NUM); std::memcpy(ptr, &(reg.val[0]), num * sizeof(uint16_t)); } }; @@ -128,7 +124,7 @@ struct BF16Vec16 : public Vec { void save(void* ptr) const { *reinterpret_cast(ptr) = reg; } void save(void* ptr, const int elem_num) const { - int num = __min(elem_num, VEC_ELEM_NUM); + int num = std::min(elem_num, VEC_ELEM_NUM); std::memcpy(ptr, &(reg.val[0]), num * sizeof(uint16_t)); } }; @@ -143,9 +139,9 @@ struct BF16Vec32 : public Vec { explicit BF16Vec32(f16x32_t data) : reg(data) {}; explicit BF16Vec32(BF16Vec8& vec8_data) { - for (int i = 0; i < VEC_ELEM_NUM; ++i) { + unroll_loop([&vec8_data, this](int i) { reg.val[i] = vec8_data.reg.val[i % BF16Vec8::VEC_ELEM_NUM]; - } + }); } void save(void* ptr) const { *reinterpret_cast(ptr) = reg; } @@ -157,15 +153,11 @@ struct FP32Vec4 : public Vec { f32x4_t reg; explicit FP32Vec4(float v) { - for (int i = 0; i < VEC_ELEM_NUM; ++i) { - reg.val[i] = v; - } + unroll_loop([&v, this](int i) { reg.val[i] = v; }); } explicit FP32Vec4() { - for (int i = 0; i < VEC_ELEM_NUM; ++i) { - reg.val[i] = 0.0f; - } + unroll_loop([this](int i) { reg.val[i] = 0.0f; }); } explicit FP32Vec4(const float* ptr) @@ -182,15 +174,11 @@ struct FP32Vec8 : public Vec { f32x8_t reg; explicit FP32Vec8(float v) { - for (int i = 0; i < VEC_ELEM_NUM; ++i) { - reg.val[i] = v; - } + unroll_loop([&v, this](int i) { reg.val[i] = v; }); } explicit FP32Vec8() { - for (int i = 0; i < VEC_ELEM_NUM; ++i) { - reg.val[i] = 0.0f; - } + unroll_loop([this](int i) { reg.val[i] = 0.0f; }); } explicit FP32Vec8(const float* ptr) @@ -201,78 +189,68 @@ struct FP32Vec8 : public Vec { explicit FP32Vec8(const FP32Vec8& data) : reg(data.reg) {}; explicit FP32Vec8(const FP16Vec8& v) { - for (int i = 0; i < VEC_ELEM_NUM; ++i) { - reg.val[i] = fp16_to_float(v.reg.val[i]); - } + unroll_loop( + [&v, this](int i) { reg.val[i] = fp16_to_float(v.reg.val[i]); }); } FP32Vec8(const BF16Vec8& v) { - for (int i = 0; i < VEC_ELEM_NUM; ++i) { - reg.val[i] = bf16_to_float(v.reg.val[i]); - } + unroll_loop( + [&v, this](int i) { reg.val[i] = bf16_to_float(v.reg.val[i]); }); } float reduce_sum() const { float result = 0; - for (int i = 0; i < VEC_ELEM_NUM; ++i) { - result += reg.val[i]; - } + unroll_loop( + [&result, this](int i) { result += reg.val[i]; }); return result; } FP32Vec8 exp() const { f32x8_t ret; - for (int i = 0; i < VEC_ELEM_NUM; ++i) { - ret.val[i] = expf(reg.val[i]); - } + unroll_loop( + [&ret, this](int i) { ret.val[i] = expf(reg.val[i]); }); return FP32Vec8(ret); } FP32Vec8 tanh() const { f32x8_t ret; - for (int i = 0; i < VEC_ELEM_NUM; ++i) { - ret.val[i] = tanhf(reg.val[i]); - } + unroll_loop( + [&ret, this](int i) { ret.val[i] = tanhf(reg.val[i]); }); return FP32Vec8(ret); } FP32Vec8 er() const { f32x8_t ret; - for (int i = 0; i < VEC_ELEM_NUM; ++i) { - ret.val[i] = erf(reg.val[i]); - } + unroll_loop( + [&ret, this](int i) { ret.val[i] = erf(reg.val[i]); }); return FP32Vec8(ret); } FP32Vec8 operator*(const FP32Vec8& b) const { f32x8_t ret; - for (int i = 0; i < VEC_ELEM_NUM; ++i) { - ret.val[i] = reg.val[i] * b.reg.val[i]; - } + unroll_loop( + [&ret, &b, this](int i) { ret.val[i] = reg.val[i] * b.reg.val[i]; }); return FP32Vec8(ret); } FP32Vec8 operator+(const FP32Vec8& b) const { f32x8_t ret; - for (int i = 0; i < VEC_ELEM_NUM; ++i) { - ret.val[i] = reg.val[i] + b.reg.val[i]; - } + unroll_loop( + [&ret, &b, this](int i) { ret.val[i] = reg.val[i] + b.reg.val[i]; }); return FP32Vec8(ret); } FP32Vec8 operator-(const FP32Vec8& b) const { f32x8_t ret; - for (int i = 0; i < VEC_ELEM_NUM; ++i) { - ret.val[i] = reg.val[i] - b.reg.val[i]; - } + unroll_loop( + [&ret, &b, this](int i) { ret.val[i] = reg.val[i] - b.reg.val[i]; }); return FP32Vec8(ret); } FP32Vec8 operator/(const FP32Vec8& b) const { f32x8_t ret; - for (int i = 0; i < VEC_ELEM_NUM; ++i) { - ret.val[i] = reg.val[i] / b.reg.val[i]; - } + unroll_loop( + [&ret, &b, this](int i) { ret.val[i] = reg.val[i] / b.reg.val[i]; }); return FP32Vec8(ret); } @@ -284,15 +262,11 @@ struct FP32Vec16 : public Vec { f32x16_t reg; explicit FP32Vec16(float v) { - for (int i = 0; i < VEC_ELEM_NUM; ++i) { - reg.val[i] = v; - } + unroll_loop([&v, this](int i) { reg.val[i] = v; }); } explicit FP32Vec16() { - for (int i = 0; i < VEC_ELEM_NUM; ++i) { - reg.val[i] = 0.0f; - } + unroll_loop([this](int i) { reg.val[i] = 0.0f; }); } explicit FP32Vec16(const float* ptr) @@ -301,29 +275,27 @@ struct FP32Vec16 : public Vec { explicit FP32Vec16(f32x16_t data) : reg(data) {}; FP32Vec16(const FP32Vec4& data) { - for (int i = 0; i < VEC_ELEM_NUM; ++i) { + unroll_loop([&data, this](int i) { reg.val[i] = data.reg.val[i % FP32Vec4::VEC_ELEM_NUM]; - } + }); } FP32Vec16(const FP32Vec8& data) { - for (int i = 0; i < VEC_ELEM_NUM; ++i) { + unroll_loop([&data, this](int i) { reg.val[i] = data.reg.val[i % FP32Vec8::VEC_ELEM_NUM]; - } + }); } FP32Vec16(const FP32Vec16& data) : reg(data.reg) {}; explicit FP32Vec16(const FP16Vec16& v) { - for (int i = 0; i < VEC_ELEM_NUM; ++i) { - reg.val[i] = fp16_to_float(v.reg.val[i]); - } + unroll_loop( + [&v, this](int i) { reg.val[i] = fp16_to_float(v.reg.val[i]); }); } explicit FP32Vec16(const BF16Vec16& v) { - for (int i = 0; i < VEC_ELEM_NUM; ++i) { - reg.val[i] = bf16_to_float(v.reg.val[i]); - } + unroll_loop( + [&v, this](int i) { reg.val[i] = bf16_to_float(v.reg.val[i]); }); } explicit FP32Vec16(const FP16Vec8& v) : FP32Vec16(FP32Vec8(v)) {}; @@ -331,82 +303,74 @@ struct FP32Vec16 : public Vec { FP32Vec16(const BF16Vec8& v) : FP32Vec16(FP32Vec8(v)) {}; FP32Vec16 operator*(const FP32Vec16& b) const { - FP32Vec16 result(0.0f); - for (int i = 0; i < VEC_ELEM_NUM; ++i) { - result.reg.val[i] = reg.val[i] * b.reg.val[i]; - } - return result; + f32x16_t ret; + unroll_loop( + [&ret, &b, this](int i) { ret.val[i] = reg.val[i] * b.reg.val[i]; }); + return FP32Vec16(ret); } FP32Vec16 operator+(const FP32Vec16& b) const { - FP32Vec16 result(0.0f); - for (int i = 0; i < VEC_ELEM_NUM; ++i) { - result.reg.val[i] = reg.val[i] + b.reg.val[i]; - } - return result; + f32x16_t ret; + unroll_loop( + [&ret, &b, this](int i) { ret.val[i] = reg.val[i] + b.reg.val[i]; }); + return FP32Vec16(ret); } FP32Vec16 operator-(const FP32Vec16& b) const { - FP32Vec16 result(0.0f); - for (int i = 0; i < VEC_ELEM_NUM; ++i) { - result.reg.val[i] = reg.val[i] - b.reg.val[i]; - } - return result; + f32x16_t ret; + unroll_loop( + [&ret, &b, this](int i) { ret.val[i] = reg.val[i] - b.reg.val[i]; }); + return FP32Vec16(ret); } FP32Vec16 operator/(const FP32Vec16& b) const { - FP32Vec16 result(0.0f); - for (int i = 0; i < VEC_ELEM_NUM; ++i) { - result.reg.val[i] = reg.val[i] / b.reg.val[i]; - } - return result; + f32x16_t ret; + unroll_loop( + [&ret, &b, this](int i) { ret.val[i] = reg.val[i] / b.reg.val[i]; }); + return FP32Vec16(ret); } FP32Vec16 max(const FP32Vec16& b) const { - FP32Vec16 result(0.0f); - for (int i = 0; i < VEC_ELEM_NUM; ++i) { - result.reg.val[i] = __max(reg.val[i], b.reg.val[i]); - } - return result; + f32x16_t ret; + unroll_loop([&ret, &b, this](int i) { + ret.val[i] = std::max(reg.val[i], b.reg.val[i]); + }); + return FP32Vec16(ret); } FP32Vec16 min(const FP32Vec16& b) const { - FP32Vec16 result(0.0f); - for (int i = 0; i < VEC_ELEM_NUM; ++i) { - result.reg.val[i] = __min(reg.val[i], b.reg.val[i]); - } - return result; + f32x16_t ret; + unroll_loop([&ret, &b, this](int i) { + ret.val[i] = std::min(reg.val[i], b.reg.val[i]); + }); + return FP32Vec16(ret); } FP32Vec16 abs() const { - FP32Vec16 result(0.0f); - for (int i = 0; i < VEC_ELEM_NUM; ++i) { - result.reg.val[i] = __abs(reg.val[i]); - } - return result; + f32x16_t ret; + unroll_loop( + [&ret, this](int i) { ret.val[i] = std::abs(reg.val[i]); }); + return FP32Vec16(ret); } float reduce_sum() const { float result = 0.0f; - for (int i = 0; i < VEC_ELEM_NUM; ++i) { - result += reg.val[i]; - } + unroll_loop( + [&result, this](int i) { result += reg.val[i]; }); return result; } float reduce_max() const { - float result = reg.val[0]; - for (int i = 0; i < VEC_ELEM_NUM; ++i) { - result = __max(reg.val[i], result); - } + float result = std::numeric_limits::lowest(); + unroll_loop( + [&result, this](int i) { result = std::max(reg.val[i], result); }); return result; } float reduce_min() const { - float result = reg.val[0]; - for (int i = 0; i < VEC_ELEM_NUM; ++i) { - result = __min(reg.val[i], result); - } + float result = std::numeric_limits::max(); + unroll_loop( + [&result, this](int i) { result = std::min(reg.val[i], result); }); return result; } @@ -414,13 +378,9 @@ struct FP32Vec16 : public Vec { float reduce_sub_sum(int idx) { static_assert(VEC_ELEM_NUM % group_size == 0); float sum = 0.0; - int start = idx * group_size; - int end = (idx + 1) * group_size; - - for (; (start < VEC_ELEM_NUM) && (start < end); ++start) { - sum += reg.val[start]; - } - + const int start = idx * group_size; + unroll_loop( + [&sum, &start, this](int i) { sum += reg.val[start + i]; }); return sum; } @@ -477,17 +437,13 @@ inline void storeFP32(float v, c10::BFloat16* ptr) { } inline FP16Vec16::FP16Vec16(const FP32Vec16& v) { - int i = 0; - for (i = 0; i < FP16Vec16::VEC_ELEM_NUM; ++i) { - reg.val[i] = float_to_fp16(v.reg.val[i]); - } + unroll_loop( + [&v, this](int i) { reg.val[i] = float_to_fp16(v.reg.val[i]); }); } inline FP16Vec8 ::FP16Vec8(const FP32Vec8& v) { - int i = 0; - for (i = 0; i < FP16Vec8::VEC_ELEM_NUM; ++i) { - reg.val[i] = float_to_fp16(v.reg.val[i]); - } + unroll_loop( + [&v, this](int i) { reg.val[i] = float_to_fp16(v.reg.val[i]); }); } inline void fma(FP32Vec16& acc, FP32Vec16& a, FP32Vec16& b) { @@ -495,17 +451,13 @@ inline void fma(FP32Vec16& acc, FP32Vec16& a, FP32Vec16& b) { } inline BF16Vec8::BF16Vec8(const FP32Vec8& v) { - int i = 0; - for (i = 0; i < BF16Vec8::VEC_ELEM_NUM; ++i) { - reg.val[i] = float_to_bf16(v.reg.val[i]); - } + unroll_loop( + [&v, this](int i) { reg.val[i] = float_to_bf16(v.reg.val[i]); }); } inline BF16Vec16::BF16Vec16(const FP32Vec16& v) { - int i = 0; - for (i = 0; i < BF16Vec16::VEC_ELEM_NUM; ++i) { - reg.val[i] = float_to_bf16(v.reg.val[i]); - } + unroll_loop( + [&v, this](int i) { reg.val[i] = float_to_bf16(v.reg.val[i]); }); } inline void prefetch(const void* addr) { __builtin_prefetch(addr, 0, 3); }