From a88371f84e408bce591076e1711ba5e3b67250ea Mon Sep 17 00:00:00 2001
From: chenlang <10346245@zte.com.cn>
Date: Thu, 25 Sep 2025 20:46:11 +0800
Subject: [PATCH] [Hardware][RISC-V] Add riscv64 support for vLLM with scalar
 (#22112)

Signed-off-by: chenlang <10346245@zte.com.cn>
Co-authored-by: chenlang <10346245@zte.com.cn>
Signed-off-by: yewentao256
---
 cmake/cpu_extension.cmake     |   9 +-
 csrc/cpu/cpu_types.hpp        |   3 +-
 csrc/cpu/cpu_types_scalar.hpp | 513 ++++++++++++++++++++++++++++++++++
 csrc/cpu/float_convert.hpp    | 106 +++++++
 vllm/platforms/interface.py   |   3 +
 5 files changed, 632 insertions(+), 2 deletions(-)
 create mode 100644 csrc/cpu/cpu_types_scalar.hpp
 create mode 100644 csrc/cpu/float_convert.hpp

diff --git a/cmake/cpu_extension.cmake b/cmake/cpu_extension.cmake
index 2a2ec08f86951..e6d0012c1a4be 100644
--- a/cmake/cpu_extension.cmake
+++ b/cmake/cpu_extension.cmake
@@ -101,6 +101,7 @@ else()
     find_isa(${CPUINFO} "asimd" ASIMD_FOUND) # Check for ARM NEON support
     find_isa(${CPUINFO} "bf16" ARM_BF16_FOUND) # Check for ARM BF16 support
     find_isa(${CPUINFO} "S390" S390_FOUND)
+    find_isa(${CPUINFO} "v" RVV_FOUND) # Check for RISC-V RVV support
 endif()
 
 if (AVX512_FOUND AND NOT AVX512_DISABLED)
@@ -177,8 +178,14 @@
         "-mzvector"
         "-march=native"
         "-mtune=native")
+elseif (CMAKE_SYSTEM_PROCESSOR MATCHES "riscv64")
+    if(RVV_FOUND)
+        message(FATAL_ERROR "RVV is not supported yet.")
+    else()
+        list(APPEND CXX_COMPILE_FLAGS "-march=rv64gc")
+    endif()
 else()
-    message(FATAL_ERROR "vLLM CPU backend requires AVX512, AVX2, Power9+ ISA, S390X ISA or ARMv8 support.")
+    message(FATAL_ERROR "vLLM CPU backend requires AVX512, AVX2, Power9+ ISA, S390X ISA, ARMv8 or RISC-V support.")
 endif()
 
 #
diff --git a/csrc/cpu/cpu_types.hpp b/csrc/cpu/cpu_types.hpp
index c3a21796881c9..9cdcd2edacfdb 100644
--- a/csrc/cpu/cpu_types.hpp
+++ b/csrc/cpu/cpu_types.hpp
@@ -14,7 +14,8 @@
   // arm implementation
   #include "cpu_types_arm.hpp"
 #else
-  #warning "unsupported vLLM cpu implementation"
+  #warning "unsupported vLLM cpu implementation, falling back to the scalar implementation"
+  #include "cpu_types_scalar.hpp"
 #endif
 
 #ifdef _OPENMP
diff --git a/csrc/cpu/cpu_types_scalar.hpp b/csrc/cpu/cpu_types_scalar.hpp
new file mode 100644
index 0000000000000..1a9278bc662e5
--- /dev/null
+++ b/csrc/cpu/cpu_types_scalar.hpp
@@ -0,0 +1,513 @@
+#include <cmath>
+#include <cstdint>
+#include <cstring>
+#include <torch/all.h>
+#include "float_convert.hpp"
+
+namespace vec_op {
+
+#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...)            \
+  AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)    \
+  AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
+  AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)
+
+#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
+  AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
+
+#ifndef CPU_OP_GUARD
+  #define CPU_KERNEL_GUARD_IN(NAME)
+  #define CPU_KERNEL_GUARD_OUT(NAME)
+#else
+  #define CPU_KERNEL_GUARD_IN(NAME) \
+    std::cout << #NAME << " invoked." << std::endl;
+  #define CPU_KERNEL_GUARD_OUT(NAME) \
+    std::cout << #NAME << " exit." << std::endl;
+#endif
+
+#define FORCE_INLINE __attribute__((always_inline)) inline
+
+#define __max(a, b) ((a) > (b) ? (a) : (b))
+#define __min(a, b) ((a) < (b) ? (a) : (b))
+#define __abs(a) ((a) < (0) ? (0 - a) : (a))
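+
+// Scalar stand-ins for the SIMD register types provided by the other CPU
+// backends: fp16/bf16 lanes are stored as raw uint16_t bit patterns and
+// widened through the helpers in float_convert.hpp, and all element-wise
+// loops are plain C++ that the compiler is free to auto-vectorize.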
+
+typedef struct f16x8_t {
+  uint16_t val[8];
+} f16x8_t;
+
+typedef struct f16x16_t {
+  uint16_t val[16];
+} f16x16_t;
+
+typedef struct f16x32_t {
+  uint16_t val[32];
+} f16x32_t;
+
+typedef struct f32x4_t {
+  float val[4];
+} f32x4_t;
+
+typedef struct f32x8_t {
+  float val[8];
+} f32x8_t;
+
+typedef struct f32x16_t {
+  float val[16];
+} f32x16_t;
+
+namespace {
+template <typename T, T... indexes, typename F>
+constexpr void unroll_loop_item(std::integer_sequence<T, indexes...>, F&& f) {
+  (f(std::integral_constant<T, indexes>{}), ...);
+};
+};  // namespace
+
+template <typename T, T count, typename F,
+          typename = std::enable_if_t<std::is_invocable_v<F, T> > >
+constexpr void unroll_loop(F&& f) {
+  unroll_loop_item(std::make_integer_sequence<T, count>{}, std::forward<F>(f));
+}
+
+template <typename T>
+struct Vec {
+  constexpr static int get_elem_num() { return T::VEC_ELEM_NUM; }
+};
+
+struct FP32Vec8;
+struct FP32Vec16;
+
+struct FP16Vec8 : public Vec<FP16Vec8> {
+  constexpr static int VEC_ELEM_NUM = 8;
+  f16x8_t reg;
+
+  explicit FP16Vec8(const void* ptr)
+      : reg(*reinterpret_cast<const f16x8_t*>(ptr)) {};
+
+  explicit FP16Vec8(const FP32Vec8&);
+
+  void save(void* ptr) const { *reinterpret_cast<f16x8_t*>(ptr) = reg; }
+};
+
+struct FP16Vec16 : public Vec<FP16Vec16> {
+  constexpr static int VEC_ELEM_NUM = 16;
+  f16x16_t reg;
+
+  explicit FP16Vec16(const void* ptr)
+      : reg(*reinterpret_cast<const f16x16_t*>(ptr)) {};
+
+  explicit FP16Vec16(const FP32Vec16&);
+
+  void save(void* ptr) const { *reinterpret_cast<f16x16_t*>(ptr) = reg; }
+
+  void save(void* ptr, const int elem_num) const {
+    int num = __min(elem_num, VEC_ELEM_NUM);
+    std::memcpy(ptr, &(reg.val[0]), num * sizeof(uint16_t));
+  }
+};
+
+struct BF16Vec8 : public Vec<BF16Vec8> {
+  constexpr static int VEC_ELEM_NUM = 8;
+  f16x8_t reg;
+
+  explicit BF16Vec8(const void* ptr)
+      : reg(*reinterpret_cast<const f16x8_t*>(ptr)) {};
+
+  explicit BF16Vec8(const FP32Vec8&);
+
+  void save(void* ptr) const { *reinterpret_cast<f16x8_t*>(ptr) = reg; }
+};
+
+struct BF16Vec16 : public Vec<BF16Vec16> {
+  constexpr static int VEC_ELEM_NUM = 16;
+  f16x16_t reg;
+
+  explicit BF16Vec16(const void* ptr)
+      : reg(*reinterpret_cast<const f16x16_t*>(ptr)) {};
+
+  explicit BF16Vec16(const FP32Vec16&);
+
+  void save(void* ptr) const { *reinterpret_cast<f16x16_t*>(ptr) = reg; }
+
+  void save(void* ptr, const int elem_num) const {
+    int num = __min(elem_num, VEC_ELEM_NUM);
+    std::memcpy(ptr, &(reg.val[0]), num * sizeof(uint16_t));
+  }
+};
+
+struct BF16Vec32 : public Vec<BF16Vec32> {
+  constexpr static int VEC_ELEM_NUM = 32;
+  f16x32_t reg;
+
+  explicit BF16Vec32(const void* ptr)
+      : reg(*reinterpret_cast<const f16x32_t*>(ptr)) {};
+
+  explicit BF16Vec32(f16x32_t data) : reg(data) {};
+
+  explicit BF16Vec32(BF16Vec8& vec8_data) {
+    for (int i = 0; i < VEC_ELEM_NUM; ++i) {
+      reg.val[i] = vec8_data.reg.val[i % BF16Vec8::VEC_ELEM_NUM];
+    }
+  }
+
+  void save(void* ptr) const { *reinterpret_cast<f16x32_t*>(ptr) = reg; }
+};
+
+struct FP32Vec4 : public Vec<FP32Vec4> {
+  constexpr static int VEC_ELEM_NUM = 4;
+
+  f32x4_t reg;
+
+  explicit FP32Vec4(float v) {
+    for (int i = 0; i < VEC_ELEM_NUM; ++i) {
+      reg.val[i] = v;
+    }
+  }
+
+  explicit FP32Vec4() {
+    for (int i = 0; i < VEC_ELEM_NUM; ++i) {
+      reg.val[i] = 0.0f;
+    }
+  }
+
+  explicit FP32Vec4(const float* ptr)
+      : reg(*reinterpret_cast<const f32x4_t*>(ptr)) {};
+
+  explicit FP32Vec4(f32x4_t data) : reg(data) {};
+
+  explicit FP32Vec4(const FP32Vec4& data) : reg(data.reg) {};
+};
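+
+// FP32Vec8 does the actual math: fp16/bf16 inputs are widened to float on
+// construction, operated on lane by lane, and narrowed again on save.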
+
+struct FP32Vec8 : public Vec<FP32Vec8> {
+  constexpr static int VEC_ELEM_NUM = 8;
+
+  f32x8_t reg;
+
+  explicit FP32Vec8(float v) {
+    for (int i = 0; i < VEC_ELEM_NUM; ++i) {
+      reg.val[i] = v;
+    }
+  }
+
+  explicit FP32Vec8() {
+    for (int i = 0; i < VEC_ELEM_NUM; ++i) {
+      reg.val[i] = 0.0f;
+    }
+  }
+
+  explicit FP32Vec8(const float* ptr)
+      : reg(*reinterpret_cast<const f32x8_t*>(ptr)) {};
+
+  explicit FP32Vec8(f32x8_t data) : reg(data) {};
+
+  explicit FP32Vec8(const FP32Vec8& data) : reg(data.reg) {};
+
+  explicit FP32Vec8(const FP16Vec8& v) {
+    for (int i = 0; i < VEC_ELEM_NUM; ++i) {
+      reg.val[i] = fp16_to_float(v.reg.val[i]);
+    }
+  }
+
+  FP32Vec8(const BF16Vec8& v) {
+    for (int i = 0; i < VEC_ELEM_NUM; ++i) {
+      reg.val[i] = bf16_to_float(v.reg.val[i]);
+    }
+  }
+
+  float reduce_sum() const {
+    float result = 0;
+    for (int i = 0; i < VEC_ELEM_NUM; ++i) {
+      result += reg.val[i];
+    }
+    return result;
+  }
+
+  FP32Vec8 exp() const {
+    f32x8_t ret;
+    for (int i = 0; i < VEC_ELEM_NUM; ++i) {
+      ret.val[i] = expf(reg.val[i]);
+    }
+    return FP32Vec8(ret);
+  }
+
+  FP32Vec8 tanh() const {
+    f32x8_t ret;
+    for (int i = 0; i < VEC_ELEM_NUM; ++i) {
+      ret.val[i] = tanhf(reg.val[i]);
+    }
+    return FP32Vec8(ret);
+  }
+
+  FP32Vec8 er() const {
+    f32x8_t ret;
+    for (int i = 0; i < VEC_ELEM_NUM; ++i) {
+      ret.val[i] = erf(reg.val[i]);
+    }
+    return FP32Vec8(ret);
+  }
+
+  FP32Vec8 operator*(const FP32Vec8& b) const {
+    f32x8_t ret;
+    for (int i = 0; i < VEC_ELEM_NUM; ++i) {
+      ret.val[i] = reg.val[i] * b.reg.val[i];
+    }
+    return FP32Vec8(ret);
+  }
+
+  FP32Vec8 operator+(const FP32Vec8& b) const {
+    f32x8_t ret;
+    for (int i = 0; i < VEC_ELEM_NUM; ++i) {
+      ret.val[i] = reg.val[i] + b.reg.val[i];
+    }
+    return FP32Vec8(ret);
+  }
+
+  FP32Vec8 operator-(const FP32Vec8& b) const {
+    f32x8_t ret;
+    for (int i = 0; i < VEC_ELEM_NUM; ++i) {
+      ret.val[i] = reg.val[i] - b.reg.val[i];
+    }
+    return FP32Vec8(ret);
+  }
+
+  FP32Vec8 operator/(const FP32Vec8& b) const {
+    f32x8_t ret;
+    for (int i = 0; i < VEC_ELEM_NUM; ++i) {
+      ret.val[i] = reg.val[i] / b.reg.val[i];
+    }
+    return FP32Vec8(ret);
+  }
+
+  void save(void* ptr) const { *reinterpret_cast<f32x8_t*>(ptr) = reg; }
+};
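+
+// FP32Vec16 mirrors FP32Vec8 and adds the horizontal reductions
+// (sum/max/min) that the CPU kernels depend on.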
+
+struct FP32Vec16 : public Vec<FP32Vec16> {
+  constexpr static int VEC_ELEM_NUM = 16;
+  f32x16_t reg;
+
+  explicit FP32Vec16(float v) {
+    for (int i = 0; i < VEC_ELEM_NUM; ++i) {
+      reg.val[i] = v;
+    }
+  }
+
+  explicit FP32Vec16() {
+    for (int i = 0; i < VEC_ELEM_NUM; ++i) {
+      reg.val[i] = 0.0f;
+    }
+  }
+
+  explicit FP32Vec16(const float* ptr)
+      : reg(*reinterpret_cast<const f32x16_t*>(ptr)) {};
+
+  explicit FP32Vec16(f32x16_t data) : reg(data) {};
+
+  FP32Vec16(const FP32Vec4& data) {
+    for (int i = 0; i < VEC_ELEM_NUM; ++i) {
+      reg.val[i] = data.reg.val[i % FP32Vec4::VEC_ELEM_NUM];
+    }
+  }
+
+  FP32Vec16(const FP32Vec8& data) {
+    for (int i = 0; i < VEC_ELEM_NUM; ++i) {
+      reg.val[i] = data.reg.val[i % FP32Vec8::VEC_ELEM_NUM];
+    }
+  }
+
+  FP32Vec16(const FP32Vec16& data) : reg(data.reg) {};
+
+  explicit FP32Vec16(const FP16Vec16& v) {
+    for (int i = 0; i < VEC_ELEM_NUM; ++i) {
+      reg.val[i] = fp16_to_float(v.reg.val[i]);
+    }
+  }
+
+  explicit FP32Vec16(const BF16Vec16& v) {
+    for (int i = 0; i < VEC_ELEM_NUM; ++i) {
+      reg.val[i] = bf16_to_float(v.reg.val[i]);
+    }
+  }
+
+  explicit FP32Vec16(const FP16Vec8& v) : FP32Vec16(FP32Vec8(v)) {};
+
+  FP32Vec16(const BF16Vec8& v) : FP32Vec16(FP32Vec8(v)) {};
+
+  FP32Vec16 operator*(const FP32Vec16& b) const {
+    FP32Vec16 result(0.0f);
+    for (int i = 0; i < VEC_ELEM_NUM; ++i) {
+      result.reg.val[i] = reg.val[i] * b.reg.val[i];
+    }
+    return result;
+  }
+
+  FP32Vec16 operator+(const FP32Vec16& b) const {
+    FP32Vec16 result(0.0f);
+    for (int i = 0; i < VEC_ELEM_NUM; ++i) {
+      result.reg.val[i] = reg.val[i] + b.reg.val[i];
+    }
+    return result;
+  }
+
+  FP32Vec16 operator-(const FP32Vec16& b) const {
+    FP32Vec16 result(0.0f);
+    for (int i = 0; i < VEC_ELEM_NUM; ++i) {
+      result.reg.val[i] = reg.val[i] - b.reg.val[i];
+    }
+    return result;
+  }
+
+  FP32Vec16 operator/(const FP32Vec16& b) const {
+    FP32Vec16 result(0.0f);
+    for (int i = 0; i < VEC_ELEM_NUM; ++i) {
+      result.reg.val[i] = reg.val[i] / b.reg.val[i];
+    }
+    return result;
+  }
+
+  FP32Vec16 max(const FP32Vec16& b) const {
+    FP32Vec16 result(0.0f);
+    for (int i = 0; i < VEC_ELEM_NUM; ++i) {
+      result.reg.val[i] = __max(reg.val[i], b.reg.val[i]);
+    }
+    return result;
+  }
+
+  FP32Vec16 min(const FP32Vec16& b) const {
+    FP32Vec16 result(0.0f);
+    for (int i = 0; i < VEC_ELEM_NUM; ++i) {
+      result.reg.val[i] = __min(reg.val[i], b.reg.val[i]);
+    }
+    return result;
+  }
+
+  FP32Vec16 abs() const {
+    FP32Vec16 result(0.0f);
+    for (int i = 0; i < VEC_ELEM_NUM; ++i) {
+      result.reg.val[i] = __abs(reg.val[i]);
+    }
+    return result;
+  }
+
+  float reduce_sum() const {
+    float result = 0.0f;
+    for (int i = 0; i < VEC_ELEM_NUM; ++i) {
+      result += reg.val[i];
+    }
+    return result;
+  }
+
+  float reduce_max() const {
+    float result = reg.val[0];
+    for (int i = 0; i < VEC_ELEM_NUM; ++i) {
+      result = __max(reg.val[i], result);
+    }
+    return result;
+  }
+
+  float reduce_min() const {
+    float result = reg.val[0];
+    for (int i = 0; i < VEC_ELEM_NUM; ++i) {
+      result = __min(reg.val[i], result);
+    }
+    return result;
+  }
+
+  template <int group_size>
+  float reduce_sub_sum(int idx) {
+    static_assert(VEC_ELEM_NUM % group_size == 0);
+    float sum = 0.0;
+    int start = idx * group_size;
+    int end = (idx + 1) * group_size;
+
+    for (; (start < VEC_ELEM_NUM) && (start < end); ++start) {
+      sum += reg.val[start];
+    }
+
+    return sum;
+  }
+
+  void save(void* ptr) const { *reinterpret_cast<f32x16_t*>(ptr) = reg; }
+};
+
+template <typename T>
+struct VecType {
+  using vec_type = void;
+};
+
+template <typename T>
+using vec_t = typename VecType<T>::vec_type;
+
+template <>
+struct VecType<float> {
+  using vec_type = FP32Vec8;
+};
+
+template <>
+struct VecType<c10::Half> {
+  using vec_type = FP16Vec8;
+};
+
+template <>
+struct VecType<c10::BFloat16> {
+  using vec_type = BF16Vec8;
+};
+
+template <typename T>
+void storeFP32(float v, T* ptr) {
+  *ptr = v;
+}
+
+/*
+template <> inline void storeFP32<c10::Half>(float v, c10::Half *ptr) {
+  c10::Half __attribute__((__may_alias__)) *v_ptr =
+      reinterpret_cast<c10::Half *>(&v);
+  *ptr = *(v_ptr + 1);
+}
+*/
+
+template <>
+inline void storeFP32<c10::Half>(float v, c10::Half* ptr) {
+  uint16_t fp16 = float_to_fp16(v);
+  *reinterpret_cast<uint16_t*>(ptr) = fp16;
+}
+
+template <>
+inline void storeFP32<c10::BFloat16>(float v, c10::BFloat16* ptr) {
+  c10::BFloat16 __attribute__((__may_alias__))* v_ptr =
+      reinterpret_cast<c10::BFloat16*>(&v);
+  *ptr = *(v_ptr + 1);
+}
+
+inline FP16Vec16::FP16Vec16(const FP32Vec16& v) {
+  int i = 0;
+  for (i = 0; i < FP16Vec16::VEC_ELEM_NUM; ++i) {
+    reg.val[i] = float_to_fp16(v.reg.val[i]);
+  }
+}
+
+inline FP16Vec8::FP16Vec8(const FP32Vec8& v) {
+  int i = 0;
+  for (i = 0; i < FP16Vec8::VEC_ELEM_NUM; ++i) {
+    reg.val[i] = float_to_fp16(v.reg.val[i]);
+  }
+}
+
+inline void fma(FP32Vec16& acc, FP32Vec16& a, FP32Vec16& b) {
+  acc = acc + a * b;
+}
+
+inline BF16Vec8::BF16Vec8(const FP32Vec8& v) {
+  int i = 0;
+  for (i = 0; i < BF16Vec8::VEC_ELEM_NUM; ++i) {
+    reg.val[i] = float_to_bf16(v.reg.val[i]);
+  }
+}
+
+inline BF16Vec16::BF16Vec16(const FP32Vec16& v) {
+  int i = 0;
+  for (i = 0; i < BF16Vec16::VEC_ELEM_NUM; ++i) {
+    reg.val[i] = float_to_bf16(v.reg.val[i]);
+  }
+}
+
+inline void prefetch(const void* addr) { __builtin_prefetch(addr, 0, 3); }
+
+};  // namespace vec_op
diff --git a/csrc/cpu/float_convert.hpp b/csrc/cpu/float_convert.hpp
new file mode 100644
index 0000000000000..c792bf131ccdc
--- /dev/null
+++ b/csrc/cpu/float_convert.hpp
@@ -0,0 +1,106 @@
+
+static float bf16_to_float(uint16_t bf16) {
+  uint32_t bits = static_cast<uint32_t>(bf16) << 16;
+  float fp32;
+  std::memcpy(&fp32, &bits, sizeof(fp32));
+  return fp32;
+}
+
+static uint16_t float_to_bf16(float fp32) {
+  uint32_t bits;
+  std::memcpy(&bits, &fp32, sizeof(fp32));
+  return static_cast<uint16_t>(bits >> 16);
+}
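+
+// Note that float_to_bf16 truncates (round-toward-zero): the low 16
+// mantissa bits are simply dropped. The fp16 path below instead performs
+// full round-to-nearest-even with denormal handling.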
+
+/************************************************
+ * Copyright (c) 2015 Princeton Vision Group
+ * Licensed under the MIT license.
+ * Codes below copied from
+ * https://github.com/PrincetonVision/marvin/tree/master/tools/tensorIO_matlab
+ *************************************************/
+static uint16_t float_to_fp16(float fp32) {
+  uint16_t fp16;
+
+  unsigned x;
+  unsigned u, remainder, shift, lsb, lsb_s1, lsb_m1;
+  unsigned sign, exponent, mantissa;
+
+  std::memcpy(&x, &fp32, sizeof(fp32));
+  u = (x & 0x7fffffff);
+
+  // Get rid of +NaN/-NaN case first.
+  if (u > 0x7f800000) {
+    fp16 = 0x7fffU;
+    return fp16;
+  }
+
+  sign = ((x >> 16) & 0x8000);
+
+  // Get rid of +Inf/-Inf, +0/-0.
+  if (u > 0x477fefff) {
+    fp16 = sign | 0x7c00U;
+    return fp16;
+  }
+  if (u < 0x33000001) {
+    fp16 = (sign | 0x0000);
+    return fp16;
+  }
+
+  exponent = ((u >> 23) & 0xff);
+  mantissa = (u & 0x7fffff);
+
+  if (exponent > 0x70) {
+    shift = 13;
+    exponent -= 0x70;
+  } else {
+    shift = 0x7e - exponent;
+    exponent = 0;
+    mantissa |= 0x800000;
+  }
+  lsb = (1 << shift);
+  lsb_s1 = (lsb >> 1);
+  lsb_m1 = (lsb - 1);
+
+  // Round to nearest even.
+  remainder = (mantissa & lsb_m1);
+  mantissa >>= shift;
+  if (remainder > lsb_s1 || (remainder == lsb_s1 && (mantissa & 0x1))) {
+    ++mantissa;
+    if (!(mantissa & 0x3ff)) {
+      ++exponent;
+      mantissa = 0;
+    }
+  }
+
+  fp16 = (sign | (exponent << 10) | mantissa);
+
+  return fp16;
+}
+
+static float fp16_to_float(uint16_t fp16) {
+  unsigned sign = ((fp16 >> 15) & 1);
+  unsigned exponent = ((fp16 >> 10) & 0x1f);
+  unsigned mantissa = ((fp16 & 0x3ff) << 13);
+  int temp;
+  float fp32;
+  if (exponent == 0x1f) { /* NaN or Inf */
+    mantissa = (mantissa ? (sign = 0, 0x7fffff) : 0);
+    exponent = 0xff;
+  } else if (!exponent) { /* Denorm or Zero */
+    if (mantissa) {
+      unsigned int msb;
+      exponent = 0x71;
+      do {
+        msb = (mantissa & 0x400000);
+        mantissa <<= 1; /* normalize */
+        --exponent;
+      } while (!msb);
+      mantissa &= 0x7fffff; /* 1.mantissa is implicit */
+    }
+  } else {
+    exponent += 0x70;
+  }
+  temp = ((sign << 31) | (exponent << 23) | mantissa);
+  std::memcpy(&fp32, &temp, sizeof(temp));
+  return fp32;
+}
diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py
index 7dd935d2eb31c..73b97dafcd6eb 100644
--- a/vllm/platforms/interface.py
+++ b/vllm/platforms/interface.py
@@ -85,6 +85,7 @@ class CpuArchEnum(enum.Enum):
     ARM = enum.auto()
     POWERPC = enum.auto()
     S390X = enum.auto()
+    RISCV = enum.auto()
     OTHER = enum.auto()
     UNKNOWN = enum.auto()
 
@@ -374,6 +375,8 @@ class Platform:
             return CpuArchEnum.POWERPC
         elif machine == "s390x":
            return CpuArchEnum.S390X
+        elif machine.startswith("riscv"):
+            return CpuArchEnum.RISCV
 
         return CpuArchEnum.OTHER if machine else CpuArchEnum.UNKNOWN
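
---
A quick sanity check for the conversion helpers (an illustrative sketch, not
part of the patch): it assumes float_convert.hpp is included after <cstdint>
and <cstring>, matching the include order in cpu_types_scalar.hpp; the file
name and include path are hypothetical.

    // test_float_convert.cpp (hypothetical)
    #include <cassert>
    #include <cstdint>
    #include <cstring>
    #include "csrc/cpu/float_convert.hpp"

    int main() {
      // Every finite fp16 bit pattern survives fp16 -> float -> fp16:
      // fp16 values are exactly representable as float, and float_to_fp16
      // rounds to nearest even, so the round trip is the identity.
      for (uint32_t bits = 0; bits < 0x10000; ++bits) {
        uint16_t fp16 = static_cast<uint16_t>(bits);
        if ((fp16 & 0x7c00) == 0x7c00) continue;  // skip Inf/NaN encodings
        assert(float_to_fp16(fp16_to_float(fp16)) == fp16);
      }
      // bf16 truncates the low mantissa bits, so floats whose low 16 bits
      // are zero (e.g. 1.0f) round-trip exactly.
      assert(bf16_to_float(float_to_bf16(1.0f)) == 1.0f);
      return 0;
    }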