mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-23 13:14:34 +08:00
refactor(cpu_types_scalar.hpp): Unify scalar loop implementations using unroll_loop (#28847)
Signed-off-by: ihb2032 <1355790728@qq.com> Co-authored-by: lyd1992 <liuyudong@iscas.ac.cn>
This commit is contained in:
parent
fdf93486d6
commit
8151609583
@ -26,10 +26,6 @@ namespace vec_op {
|
||||
|
||||
#define FORCE_INLINE __attribute__((always_inline)) inline
|
||||
|
||||
#define __max(a, b) ((a) > (b) ? (a) : (b))
|
||||
#define __min(a, b) ((a) < (b) ? (a) : (b))
|
||||
#define __abs(a) ((a) < (0) ? (0 - a) : (a))
|
||||
|
||||
typedef struct f16x8_t {
|
||||
uint16_t val[8];
|
||||
} f16x8_t;
|
||||
@ -99,7 +95,7 @@ struct FP16Vec16 : public Vec<FP16Vec16> {
|
||||
void save(void* ptr) const { *reinterpret_cast<f16x16_t*>(ptr) = reg; }
|
||||
|
||||
void save(void* ptr, const int elem_num) const {
|
||||
int num = __min(elem_num, VEC_ELEM_NUM);
|
||||
int num = std::min(elem_num, VEC_ELEM_NUM);
|
||||
std::memcpy(ptr, &(reg.val[0]), num * sizeof(uint16_t));
|
||||
}
|
||||
};
|
||||
@ -128,7 +124,7 @@ struct BF16Vec16 : public Vec<BF16Vec16> {
|
||||
void save(void* ptr) const { *reinterpret_cast<f16x16_t*>(ptr) = reg; }
|
||||
|
||||
void save(void* ptr, const int elem_num) const {
|
||||
int num = __min(elem_num, VEC_ELEM_NUM);
|
||||
int num = std::min(elem_num, VEC_ELEM_NUM);
|
||||
std::memcpy(ptr, &(reg.val[0]), num * sizeof(uint16_t));
|
||||
}
|
||||
};
|
||||
@ -143,9 +139,9 @@ struct BF16Vec32 : public Vec<BF16Vec32> {
|
||||
explicit BF16Vec32(f16x32_t data) : reg(data) {};
|
||||
|
||||
explicit BF16Vec32(BF16Vec8& vec8_data) {
|
||||
for (int i = 0; i < VEC_ELEM_NUM; ++i) {
|
||||
unroll_loop<int, VEC_ELEM_NUM>([&vec8_data, this](int i) {
|
||||
reg.val[i] = vec8_data.reg.val[i % BF16Vec8::VEC_ELEM_NUM];
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
void save(void* ptr) const { *reinterpret_cast<f16x32_t*>(ptr) = reg; }
|
||||
@ -157,15 +153,11 @@ struct FP32Vec4 : public Vec<FP32Vec4> {
|
||||
f32x4_t reg;
|
||||
|
||||
explicit FP32Vec4(float v) {
|
||||
for (int i = 0; i < VEC_ELEM_NUM; ++i) {
|
||||
reg.val[i] = v;
|
||||
}
|
||||
unroll_loop<int, VEC_ELEM_NUM>([&v, this](int i) { reg.val[i] = v; });
|
||||
}
|
||||
|
||||
explicit FP32Vec4() {
|
||||
for (int i = 0; i < VEC_ELEM_NUM; ++i) {
|
||||
reg.val[i] = 0.0f;
|
||||
}
|
||||
unroll_loop<int, VEC_ELEM_NUM>([this](int i) { reg.val[i] = 0.0f; });
|
||||
}
|
||||
|
||||
explicit FP32Vec4(const float* ptr)
|
||||
@ -182,15 +174,11 @@ struct FP32Vec8 : public Vec<FP32Vec8> {
|
||||
f32x8_t reg;
|
||||
|
||||
explicit FP32Vec8(float v) {
|
||||
for (int i = 0; i < VEC_ELEM_NUM; ++i) {
|
||||
reg.val[i] = v;
|
||||
}
|
||||
unroll_loop<int, VEC_ELEM_NUM>([&v, this](int i) { reg.val[i] = v; });
|
||||
}
|
||||
|
||||
explicit FP32Vec8() {
|
||||
for (int i = 0; i < VEC_ELEM_NUM; ++i) {
|
||||
reg.val[i] = 0.0f;
|
||||
}
|
||||
unroll_loop<int, VEC_ELEM_NUM>([this](int i) { reg.val[i] = 0.0f; });
|
||||
}
|
||||
|
||||
explicit FP32Vec8(const float* ptr)
|
||||
@ -201,78 +189,68 @@ struct FP32Vec8 : public Vec<FP32Vec8> {
|
||||
explicit FP32Vec8(const FP32Vec8& data) : reg(data.reg) {};
|
||||
|
||||
explicit FP32Vec8(const FP16Vec8& v) {
|
||||
for (int i = 0; i < VEC_ELEM_NUM; ++i) {
|
||||
reg.val[i] = fp16_to_float(v.reg.val[i]);
|
||||
}
|
||||
unroll_loop<int, VEC_ELEM_NUM>(
|
||||
[&v, this](int i) { reg.val[i] = fp16_to_float(v.reg.val[i]); });
|
||||
}
|
||||
|
||||
FP32Vec8(const BF16Vec8& v) {
|
||||
for (int i = 0; i < VEC_ELEM_NUM; ++i) {
|
||||
reg.val[i] = bf16_to_float(v.reg.val[i]);
|
||||
}
|
||||
unroll_loop<int, VEC_ELEM_NUM>(
|
||||
[&v, this](int i) { reg.val[i] = bf16_to_float(v.reg.val[i]); });
|
||||
}
|
||||
|
||||
float reduce_sum() const {
|
||||
float result = 0;
|
||||
for (int i = 0; i < VEC_ELEM_NUM; ++i) {
|
||||
result += reg.val[i];
|
||||
}
|
||||
unroll_loop<int, VEC_ELEM_NUM>(
|
||||
[&result, this](int i) { result += reg.val[i]; });
|
||||
return result;
|
||||
}
|
||||
|
||||
FP32Vec8 exp() const {
|
||||
f32x8_t ret;
|
||||
for (int i = 0; i < VEC_ELEM_NUM; ++i) {
|
||||
ret.val[i] = expf(reg.val[i]);
|
||||
}
|
||||
unroll_loop<int, VEC_ELEM_NUM>(
|
||||
[&ret, this](int i) { ret.val[i] = expf(reg.val[i]); });
|
||||
return FP32Vec8(ret);
|
||||
}
|
||||
|
||||
FP32Vec8 tanh() const {
|
||||
f32x8_t ret;
|
||||
for (int i = 0; i < VEC_ELEM_NUM; ++i) {
|
||||
ret.val[i] = tanhf(reg.val[i]);
|
||||
}
|
||||
unroll_loop<int, VEC_ELEM_NUM>(
|
||||
[&ret, this](int i) { ret.val[i] = tanhf(reg.val[i]); });
|
||||
return FP32Vec8(ret);
|
||||
}
|
||||
|
||||
FP32Vec8 er() const {
|
||||
f32x8_t ret;
|
||||
for (int i = 0; i < VEC_ELEM_NUM; ++i) {
|
||||
ret.val[i] = erf(reg.val[i]);
|
||||
}
|
||||
unroll_loop<int, VEC_ELEM_NUM>(
|
||||
[&ret, this](int i) { ret.val[i] = erf(reg.val[i]); });
|
||||
return FP32Vec8(ret);
|
||||
}
|
||||
|
||||
FP32Vec8 operator*(const FP32Vec8& b) const {
|
||||
f32x8_t ret;
|
||||
for (int i = 0; i < VEC_ELEM_NUM; ++i) {
|
||||
ret.val[i] = reg.val[i] * b.reg.val[i];
|
||||
}
|
||||
unroll_loop<int, VEC_ELEM_NUM>(
|
||||
[&ret, &b, this](int i) { ret.val[i] = reg.val[i] * b.reg.val[i]; });
|
||||
return FP32Vec8(ret);
|
||||
}
|
||||
|
||||
FP32Vec8 operator+(const FP32Vec8& b) const {
|
||||
f32x8_t ret;
|
||||
for (int i = 0; i < VEC_ELEM_NUM; ++i) {
|
||||
ret.val[i] = reg.val[i] + b.reg.val[i];
|
||||
}
|
||||
unroll_loop<int, VEC_ELEM_NUM>(
|
||||
[&ret, &b, this](int i) { ret.val[i] = reg.val[i] + b.reg.val[i]; });
|
||||
return FP32Vec8(ret);
|
||||
}
|
||||
|
||||
FP32Vec8 operator-(const FP32Vec8& b) const {
|
||||
f32x8_t ret;
|
||||
for (int i = 0; i < VEC_ELEM_NUM; ++i) {
|
||||
ret.val[i] = reg.val[i] - b.reg.val[i];
|
||||
}
|
||||
unroll_loop<int, VEC_ELEM_NUM>(
|
||||
[&ret, &b, this](int i) { ret.val[i] = reg.val[i] - b.reg.val[i]; });
|
||||
return FP32Vec8(ret);
|
||||
}
|
||||
|
||||
FP32Vec8 operator/(const FP32Vec8& b) const {
|
||||
f32x8_t ret;
|
||||
for (int i = 0; i < VEC_ELEM_NUM; ++i) {
|
||||
ret.val[i] = reg.val[i] / b.reg.val[i];
|
||||
}
|
||||
unroll_loop<int, VEC_ELEM_NUM>(
|
||||
[&ret, &b, this](int i) { ret.val[i] = reg.val[i] / b.reg.val[i]; });
|
||||
return FP32Vec8(ret);
|
||||
}
|
||||
|
||||
@ -284,15 +262,11 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
f32x16_t reg;
|
||||
|
||||
explicit FP32Vec16(float v) {
|
||||
for (int i = 0; i < VEC_ELEM_NUM; ++i) {
|
||||
reg.val[i] = v;
|
||||
}
|
||||
unroll_loop<int, VEC_ELEM_NUM>([&v, this](int i) { reg.val[i] = v; });
|
||||
}
|
||||
|
||||
explicit FP32Vec16() {
|
||||
for (int i = 0; i < VEC_ELEM_NUM; ++i) {
|
||||
reg.val[i] = 0.0f;
|
||||
}
|
||||
unroll_loop<int, VEC_ELEM_NUM>([this](int i) { reg.val[i] = 0.0f; });
|
||||
}
|
||||
|
||||
explicit FP32Vec16(const float* ptr)
|
||||
@ -301,29 +275,27 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
explicit FP32Vec16(f32x16_t data) : reg(data) {};
|
||||
|
||||
FP32Vec16(const FP32Vec4& data) {
|
||||
for (int i = 0; i < VEC_ELEM_NUM; ++i) {
|
||||
unroll_loop<int, VEC_ELEM_NUM>([&data, this](int i) {
|
||||
reg.val[i] = data.reg.val[i % FP32Vec4::VEC_ELEM_NUM];
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
FP32Vec16(const FP32Vec8& data) {
|
||||
for (int i = 0; i < VEC_ELEM_NUM; ++i) {
|
||||
unroll_loop<int, VEC_ELEM_NUM>([&data, this](int i) {
|
||||
reg.val[i] = data.reg.val[i % FP32Vec8::VEC_ELEM_NUM];
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
FP32Vec16(const FP32Vec16& data) : reg(data.reg) {};
|
||||
|
||||
explicit FP32Vec16(const FP16Vec16& v) {
|
||||
for (int i = 0; i < VEC_ELEM_NUM; ++i) {
|
||||
reg.val[i] = fp16_to_float(v.reg.val[i]);
|
||||
}
|
||||
unroll_loop<int, VEC_ELEM_NUM>(
|
||||
[&v, this](int i) { reg.val[i] = fp16_to_float(v.reg.val[i]); });
|
||||
}
|
||||
|
||||
explicit FP32Vec16(const BF16Vec16& v) {
|
||||
for (int i = 0; i < VEC_ELEM_NUM; ++i) {
|
||||
reg.val[i] = bf16_to_float(v.reg.val[i]);
|
||||
}
|
||||
unroll_loop<int, VEC_ELEM_NUM>(
|
||||
[&v, this](int i) { reg.val[i] = bf16_to_float(v.reg.val[i]); });
|
||||
}
|
||||
|
||||
explicit FP32Vec16(const FP16Vec8& v) : FP32Vec16(FP32Vec8(v)) {};
|
||||
@ -331,82 +303,74 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
FP32Vec16(const BF16Vec8& v) : FP32Vec16(FP32Vec8(v)) {};
|
||||
|
||||
FP32Vec16 operator*(const FP32Vec16& b) const {
|
||||
FP32Vec16 result(0.0f);
|
||||
for (int i = 0; i < VEC_ELEM_NUM; ++i) {
|
||||
result.reg.val[i] = reg.val[i] * b.reg.val[i];
|
||||
}
|
||||
return result;
|
||||
f32x16_t ret;
|
||||
unroll_loop<int, VEC_ELEM_NUM>(
|
||||
[&ret, &b, this](int i) { ret.val[i] = reg.val[i] * b.reg.val[i]; });
|
||||
return FP32Vec16(ret);
|
||||
}
|
||||
|
||||
FP32Vec16 operator+(const FP32Vec16& b) const {
|
||||
FP32Vec16 result(0.0f);
|
||||
for (int i = 0; i < VEC_ELEM_NUM; ++i) {
|
||||
result.reg.val[i] = reg.val[i] + b.reg.val[i];
|
||||
}
|
||||
return result;
|
||||
f32x16_t ret;
|
||||
unroll_loop<int, VEC_ELEM_NUM>(
|
||||
[&ret, &b, this](int i) { ret.val[i] = reg.val[i] + b.reg.val[i]; });
|
||||
return FP32Vec16(ret);
|
||||
}
|
||||
|
||||
FP32Vec16 operator-(const FP32Vec16& b) const {
|
||||
FP32Vec16 result(0.0f);
|
||||
for (int i = 0; i < VEC_ELEM_NUM; ++i) {
|
||||
result.reg.val[i] = reg.val[i] - b.reg.val[i];
|
||||
}
|
||||
return result;
|
||||
f32x16_t ret;
|
||||
unroll_loop<int, VEC_ELEM_NUM>(
|
||||
[&ret, &b, this](int i) { ret.val[i] = reg.val[i] - b.reg.val[i]; });
|
||||
return FP32Vec16(ret);
|
||||
}
|
||||
|
||||
FP32Vec16 operator/(const FP32Vec16& b) const {
|
||||
FP32Vec16 result(0.0f);
|
||||
for (int i = 0; i < VEC_ELEM_NUM; ++i) {
|
||||
result.reg.val[i] = reg.val[i] / b.reg.val[i];
|
||||
}
|
||||
return result;
|
||||
f32x16_t ret;
|
||||
unroll_loop<int, VEC_ELEM_NUM>(
|
||||
[&ret, &b, this](int i) { ret.val[i] = reg.val[i] / b.reg.val[i]; });
|
||||
return FP32Vec16(ret);
|
||||
}
|
||||
|
||||
FP32Vec16 max(const FP32Vec16& b) const {
|
||||
FP32Vec16 result(0.0f);
|
||||
for (int i = 0; i < VEC_ELEM_NUM; ++i) {
|
||||
result.reg.val[i] = __max(reg.val[i], b.reg.val[i]);
|
||||
}
|
||||
return result;
|
||||
f32x16_t ret;
|
||||
unroll_loop<int, VEC_ELEM_NUM>([&ret, &b, this](int i) {
|
||||
ret.val[i] = std::max(reg.val[i], b.reg.val[i]);
|
||||
});
|
||||
return FP32Vec16(ret);
|
||||
}
|
||||
|
||||
FP32Vec16 min(const FP32Vec16& b) const {
|
||||
FP32Vec16 result(0.0f);
|
||||
for (int i = 0; i < VEC_ELEM_NUM; ++i) {
|
||||
result.reg.val[i] = __min(reg.val[i], b.reg.val[i]);
|
||||
}
|
||||
return result;
|
||||
f32x16_t ret;
|
||||
unroll_loop<int, VEC_ELEM_NUM>([&ret, &b, this](int i) {
|
||||
ret.val[i] = std::min(reg.val[i], b.reg.val[i]);
|
||||
});
|
||||
return FP32Vec16(ret);
|
||||
}
|
||||
|
||||
FP32Vec16 abs() const {
|
||||
FP32Vec16 result(0.0f);
|
||||
for (int i = 0; i < VEC_ELEM_NUM; ++i) {
|
||||
result.reg.val[i] = __abs(reg.val[i]);
|
||||
}
|
||||
return result;
|
||||
f32x16_t ret;
|
||||
unroll_loop<int, VEC_ELEM_NUM>(
|
||||
[&ret, this](int i) { ret.val[i] = std::abs(reg.val[i]); });
|
||||
return FP32Vec16(ret);
|
||||
}
|
||||
|
||||
float reduce_sum() const {
|
||||
float result = 0.0f;
|
||||
for (int i = 0; i < VEC_ELEM_NUM; ++i) {
|
||||
result += reg.val[i];
|
||||
}
|
||||
unroll_loop<int, VEC_ELEM_NUM>(
|
||||
[&result, this](int i) { result += reg.val[i]; });
|
||||
return result;
|
||||
}
|
||||
|
||||
float reduce_max() const {
|
||||
float result = reg.val[0];
|
||||
for (int i = 0; i < VEC_ELEM_NUM; ++i) {
|
||||
result = __max(reg.val[i], result);
|
||||
}
|
||||
float result = std::numeric_limits<float>::lowest();
|
||||
unroll_loop<int, VEC_ELEM_NUM>(
|
||||
[&result, this](int i) { result = std::max(reg.val[i], result); });
|
||||
return result;
|
||||
}
|
||||
|
||||
float reduce_min() const {
|
||||
float result = reg.val[0];
|
||||
for (int i = 0; i < VEC_ELEM_NUM; ++i) {
|
||||
result = __min(reg.val[i], result);
|
||||
}
|
||||
float result = std::numeric_limits<float>::max();
|
||||
unroll_loop<int, VEC_ELEM_NUM>(
|
||||
[&result, this](int i) { result = std::min(reg.val[i], result); });
|
||||
return result;
|
||||
}
|
||||
|
||||
@ -414,13 +378,9 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
float reduce_sub_sum(int idx) {
|
||||
static_assert(VEC_ELEM_NUM % group_size == 0);
|
||||
float sum = 0.0;
|
||||
int start = idx * group_size;
|
||||
int end = (idx + 1) * group_size;
|
||||
|
||||
for (; (start < VEC_ELEM_NUM) && (start < end); ++start) {
|
||||
sum += reg.val[start];
|
||||
}
|
||||
|
||||
const int start = idx * group_size;
|
||||
unroll_loop<int, group_size>(
|
||||
[&sum, &start, this](int i) { sum += reg.val[start + i]; });
|
||||
return sum;
|
||||
}
|
||||
|
||||
@ -477,17 +437,13 @@ inline void storeFP32<c10::BFloat16>(float v, c10::BFloat16* ptr) {
|
||||
}
|
||||
|
||||
inline FP16Vec16::FP16Vec16(const FP32Vec16& v) {
|
||||
int i = 0;
|
||||
for (i = 0; i < FP16Vec16::VEC_ELEM_NUM; ++i) {
|
||||
reg.val[i] = float_to_fp16(v.reg.val[i]);
|
||||
}
|
||||
unroll_loop<int, FP16Vec16::VEC_ELEM_NUM>(
|
||||
[&v, this](int i) { reg.val[i] = float_to_fp16(v.reg.val[i]); });
|
||||
}
|
||||
|
||||
inline FP16Vec8 ::FP16Vec8(const FP32Vec8& v) {
|
||||
int i = 0;
|
||||
for (i = 0; i < FP16Vec8::VEC_ELEM_NUM; ++i) {
|
||||
reg.val[i] = float_to_fp16(v.reg.val[i]);
|
||||
}
|
||||
unroll_loop<int, FP16Vec8::VEC_ELEM_NUM>(
|
||||
[&v, this](int i) { reg.val[i] = float_to_fp16(v.reg.val[i]); });
|
||||
}
|
||||
|
||||
inline void fma(FP32Vec16& acc, FP32Vec16& a, FP32Vec16& b) {
|
||||
@ -495,17 +451,13 @@ inline void fma(FP32Vec16& acc, FP32Vec16& a, FP32Vec16& b) {
|
||||
}
|
||||
|
||||
inline BF16Vec8::BF16Vec8(const FP32Vec8& v) {
|
||||
int i = 0;
|
||||
for (i = 0; i < BF16Vec8::VEC_ELEM_NUM; ++i) {
|
||||
reg.val[i] = float_to_bf16(v.reg.val[i]);
|
||||
}
|
||||
unroll_loop<int, BF16Vec8::VEC_ELEM_NUM>(
|
||||
[&v, this](int i) { reg.val[i] = float_to_bf16(v.reg.val[i]); });
|
||||
}
|
||||
|
||||
inline BF16Vec16::BF16Vec16(const FP32Vec16& v) {
|
||||
int i = 0;
|
||||
for (i = 0; i < BF16Vec16::VEC_ELEM_NUM; ++i) {
|
||||
reg.val[i] = float_to_bf16(v.reg.val[i]);
|
||||
}
|
||||
unroll_loop<int, BF16Vec16::VEC_ELEM_NUM>(
|
||||
[&v, this](int i) { reg.val[i] = float_to_bf16(v.reg.val[i]); });
|
||||
}
|
||||
|
||||
inline void prefetch(const void* addr) { __builtin_prefetch(addr, 0, 3); }
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user