From 6cc1e7d96dab6b9c344ec87dec6dc9ab07ad5d21 Mon Sep 17 00:00:00 2001 From: "Li, Jiang" Date: Tue, 1 Jul 2025 15:25:03 +0800 Subject: [PATCH] [CPU] Update custom ops for the CPU backend (#20255) Signed-off-by: jiang1.li --- .../scripts/hardware_ci/run-cpu-test.sh | 3 +- cmake/cpu_extension.cmake | 20 + csrc/cpu/sgl-kernels/common.h | 238 +++ csrc/cpu/sgl-kernels/gemm.cpp | 464 ++++++ csrc/cpu/sgl-kernels/gemm.h | 266 ++++ csrc/cpu/sgl-kernels/gemm_fp8.cpp | 530 +++++++ csrc/cpu/sgl-kernels/gemm_int8.cpp | 440 ++++++ csrc/cpu/sgl-kernels/moe.cpp | 1330 +++++++++++++++++ csrc/cpu/sgl-kernels/moe_fp8.cpp | 502 +++++++ csrc/cpu/sgl-kernels/moe_int8.cpp | 769 ++++++++++ csrc/cpu/sgl-kernels/vec.h | 308 ++++ csrc/cpu/shm.cpp | 178 +-- csrc/cpu/torch_bindings.cpp | 43 + docs/getting_started/installation/cpu.md | 1 + .../models/language/generation/test_common.py | 3 +- vllm/_custom_ops.py | 49 + vllm/envs.py | 5 + .../layers/fused_moe/cpu_fused_moe.py | 214 +++ vllm/model_executor/layers/fused_moe/layer.py | 41 +- vllm/model_executor/layers/linear.py | 25 +- vllm/model_executor/layers/utils.py | 25 +- .../layers/vocab_parallel_embedding.py | 2 +- vllm/platforms/cpu.py | 2 + 23 files changed, 5357 insertions(+), 101 deletions(-) create mode 100644 csrc/cpu/sgl-kernels/common.h create mode 100644 csrc/cpu/sgl-kernels/gemm.cpp create mode 100644 csrc/cpu/sgl-kernels/gemm.h create mode 100644 csrc/cpu/sgl-kernels/gemm_fp8.cpp create mode 100644 csrc/cpu/sgl-kernels/gemm_int8.cpp create mode 100644 csrc/cpu/sgl-kernels/moe.cpp create mode 100644 csrc/cpu/sgl-kernels/moe_fp8.cpp create mode 100644 csrc/cpu/sgl-kernels/moe_int8.cpp create mode 100644 csrc/cpu/sgl-kernels/vec.h create mode 100644 vllm/model_executor/layers/fused_moe/cpu_fused_moe.py diff --git a/.buildkite/scripts/hardware_ci/run-cpu-test.sh b/.buildkite/scripts/hardware_ci/run-cpu-test.sh index 8db8c3a05fb3..42506730e868 100644 --- a/.buildkite/scripts/hardware_ci/run-cpu-test.sh +++ b/.buildkite/scripts/hardware_ci/run-cpu-test.sh @@ -51,6 +51,7 @@ function cpu_tests() { pytest -v -s tests/kernels/attention/test_cache.py -m cpu_model pytest -v -s tests/kernels/attention/test_mla_decode_cpu.py -m cpu_model pytest -v -s tests/models/language/generation -m cpu_model + VLLM_CPU_SGL_KERNEL=1 pytest -v -s tests/models/language/generation -m cpu_model pytest -v -s tests/models/language/pooling -m cpu_model pytest -v -s tests/models/multimodal/generation \ --ignore=tests/models/multimodal/generation/test_mllama.py \ @@ -98,4 +99,4 @@ function cpu_tests() { # All of CPU tests are expected to be finished less than 40 mins. export -f cpu_tests -timeout 1h bash -c "cpu_tests $CORE_RANGE $NUMA_NODE" +timeout 1.5h bash -c "cpu_tests $CORE_RANGE $NUMA_NODE" diff --git a/cmake/cpu_extension.cmake b/cmake/cpu_extension.cmake index 5cd2c98f2343..264c970ef784 100644 --- a/cmake/cpu_extension.cmake +++ b/cmake/cpu_extension.cmake @@ -96,12 +96,21 @@ if (AVX512_FOUND AND NOT AVX512_DISABLED) if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12.3) list(APPEND CXX_COMPILE_FLAGS "-mavx512bf16") + set(ENABLE_AVX512BF16 ON) else() + set(ENABLE_AVX512BF16 OFF) message(WARNING "Disable AVX512-BF16 ISA support, requires gcc/g++ >= 12.3") endif() else() + set(ENABLE_AVX512BF16 OFF) message(WARNING "Disable AVX512-BF16 ISA support, no avx512_bf16 found in local CPU flags." 
" If cross-compilation is required, please set env VLLM_CPU_AVX512BF16=1.") endif() + + find_isa(${CPUINFO} "avx512_vnni" AVX512VNNI_FOUND) + if (AVX512VNNI_FOUND) + list(APPEND CXX_COMPILE_FLAGS "-mavx512vnni") + set(ENABLE_AVX512VNNI ON) + endif() elseif (AVX2_FOUND) list(APPEND CXX_COMPILE_FLAGS "-mavx2") @@ -231,6 +240,17 @@ if (AVX512_FOUND AND NOT AVX512_DISABLED) "csrc/cpu/quant.cpp" "csrc/cpu/shm.cpp" ${VLLM_EXT_SRC}) + if (ENABLE_AVX512BF16 AND ENABLE_AVX512VNNI) + set(VLLM_EXT_SRC + "csrc/cpu/sgl-kernels/gemm.cpp" + "csrc/cpu/sgl-kernels/gemm_int8.cpp" + "csrc/cpu/sgl-kernels/gemm_fp8.cpp" + "csrc/cpu/sgl-kernels/moe.cpp" + "csrc/cpu/sgl-kernels/moe_int8.cpp" + "csrc/cpu/sgl-kernels/moe_fp8.cpp" + ${VLLM_EXT_SRC}) + add_compile_definitions(-DCPU_CAPABILITY_AVX512) + endif() elseif(POWER10_FOUND) set(VLLM_EXT_SRC "csrc/cpu/quant.cpp" diff --git a/csrc/cpu/sgl-kernels/common.h b/csrc/cpu/sgl-kernels/common.h new file mode 100644 index 000000000000..20261c1ef3e8 --- /dev/null +++ b/csrc/cpu/sgl-kernels/common.h @@ -0,0 +1,238 @@ +// Adapted from +// https://github.com/sgl-project/sglang/tree/main/sgl-kernel/csrc/cpu + +#pragma once + +#include +#include +#include + +// clang-format off + +#if defined(_OPENMP) +#include +#endif + +namespace { + +// dispatch bool +#define AT_DISPATCH_BOOL(BOOL_V, BOOL_NAME, ...) \ + [&] { \ + if (BOOL_V) { \ + constexpr bool BOOL_NAME = true; \ + return __VA_ARGS__(); \ + } else { \ + constexpr bool BOOL_NAME = false; \ + return __VA_ARGS__(); \ + } \ + }() + +// dispatch: bfloat16, float16, int8_t, fp8_e4m3 +#define CPU_DISPATCH_PACKED_TYPES(TYPE, ...) \ + [&] { \ + switch (TYPE) { \ + case at::ScalarType::BFloat16 : { \ + using packed_t = at::BFloat16; \ + return __VA_ARGS__(); \ + } \ + case at::ScalarType::Half: { \ + using packed_t = at::Half; \ + return __VA_ARGS__(); \ + } \ + case at::ScalarType::Char : { \ + using packed_t = int8_t; \ + return __VA_ARGS__(); \ + } \ + case at::ScalarType::Float8_e4m3fn : { \ + using packed_t = at::Float8_e4m3fn; \ + return __VA_ARGS__(); \ + } \ + default: \ + TORCH_CHECK(false, "Unsupported floating data type.\n"); \ + } \ + }() + +#define UNUSED(x) (void)(x) + +#define CHECK_CPU(x) TORCH_CHECK(x.device().type() == at::kCPU, #x " must be a CPU tensor") + +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_LAST_DIM_CONTIGUOUS(x) \ + TORCH_CHECK(x.strides()[x.strides().size() - 1] == 1, #x "must be contiguous at last dimention") + +#define CHECK_INPUT(x) \ + CHECK_CPU(x); \ + CHECK_CONTIGUOUS(x) +#define CHECK_LAST_DIM_CONTIGUOUS_INPUT(x) \ + CHECK_CPU(x); \ + CHECK_LAST_DIM_CONTIGUOUS(x) + +#define CHECK_DIM(d, x) TORCH_CHECK(x.dim() == d, #x " must be a " #d "D tensor") + +#define CHECK_EQ(a, b) TORCH_CHECK((a) == (b), "CHECK_EQ(" #a ", " #b ") failed. ", a, " vs ", b) + +// parallel routines +constexpr int GRAIN_SIZE = 1024; + +template ::value, int>::type = 0> +inline T div_up(T x, T y) { return (x + y - 1) / y; } + +template +inline void balance211(T n, T nth, T ith, T& n_start, T& n_end) { +#if 0 + // onednn partition pattern + T& n_my = n_end; + if (nth <= 1 || n == 0) { + n_start = 0; + n_my = n; + } else { + T n1 = div_up(n, nth); + T n2 = n1 - 1; + T T1 = n - n2 * nth; + n_my = ith < T1 ? n1 : n2; + n_start = ith <= T1 ? 
ith*n1 : T1 * n1 + (ith - T1) * n2; + } + n_end += n_start; +#else + // pytorch aten partition pattern + T n_my = div_up(n, nth); + n_start = ith * n_my; + n_end = std::min(n_start + n_my, n); +#endif +} + +template +inline void parallel_for(int n, const func_t& f) { +#if defined(_OPENMP) +#pragma omp parallel +{ + int nth = omp_get_num_threads(); + int ith = omp_get_thread_num(); + int tbegin, tend; + balance211(n, nth, ith, tbegin, tend); + f(tbegin, tend); +} +#else + f(0, n); +#endif +} + +// for 1d parallel, use `actual_nth` +// for 2d parallel, use even nths, e.g. 43->42 +int inline adjust_num_threads(int m) { + int actual_nth = at::get_num_threads(); + if (m == 1) { + return actual_nth; + } + return std::max(1, (actual_nth >> 1) * 2); +} + +template +inline void parallel_2d(int m, int n, const func_t& f) { + + // make sure we have even num_threads + int nth = adjust_num_threads(m); + + // [NOTE] thread blocking: + // + // 1) prefer square block per thread + // 2) use even number of CPU cores + // 3) use all `num_threads` cores + // + // we have: + // TM * TN = T + // BM / TM = BN / TN + // then: + // TM = ((BM / BN) * T) ^ 0.5 + // + float r = float(m) / n; + int nth_m = std::ceil(std::sqrt(r * nth)); + int nth_n = 1; + for (; nth_m > 0; --nth_m) { + nth_n = nth / nth_m; + if (nth_m * nth_n == nth) { + break; + } + } + +#if defined(_OPENMP) +#pragma omp parallel num_threads(nth) +{ + int ith = omp_get_thread_num(); + int ith_m = ith / nth_n; + int ith_n = ith % nth_n; + + int thread_block_m = div_up(m, nth_m); + int thread_block_n = div_up(n, nth_n); + + int begin_m = ith_m * thread_block_m; + int end_m = std::min(m, begin_m + thread_block_m); + int begin_n = ith_n * thread_block_n; + int end_n = std::min(n, begin_n + thread_block_n); + + f(begin_m, end_m, begin_n, end_n); +} +#else + f(0, m, 0, n); +#endif +} + +template +int get_cache_blocks(int BLOCK_SIZE, int K) { + // L2 2MB and ratio of 50% + const int L2_size = 2048 * 1024 >> 1; + return std::max(1, int(L2_size / (BLOCK_SIZE * K * sizeof(T)))); +} + +// data indexing for dimension collapse +template +inline T data_index_init(T offset) { + return offset; +} + +template +inline T data_index_init(T offset, T& x, const T& X, Args&&... args) { + offset = data_index_init(offset, std::forward(args)...); + x = offset % X; + return offset / X; +} + +inline bool data_index_step() { + return true; +} + +template +inline bool data_index_step(T& x, const T& X, Args&&... args) { + if (data_index_step(std::forward(args)...)) { + x = ((x + 1) == X) ? 0 : (x + 1); + return x == 0; + } + return false; +} + +// forced unroll for perf critical path + +#if __has_attribute(always_inline) +#define ALWAYS_INLINE __attribute__((__always_inline__)) inline +#else +#define ALWAYS_INLINE inline +#endif + +template +struct Unroll { + template + ALWAYS_INLINE void operator()(const Func& f, Args... args) const { + Unroll{}(f, args...); + f(std::integral_constant{}, args...); + } +}; + +template <> +struct Unroll<1> { + template + ALWAYS_INLINE void operator()(const Func& f, Args... 
args) const { + f(std::integral_constant{}, args...); + } +}; + +} // anonymous namespace diff --git a/csrc/cpu/sgl-kernels/gemm.cpp b/csrc/cpu/sgl-kernels/gemm.cpp new file mode 100644 index 000000000000..c122d07185dd --- /dev/null +++ b/csrc/cpu/sgl-kernels/gemm.cpp @@ -0,0 +1,464 @@ +// Adapted from +// https://github.com/sgl-project/sglang/tree/main/sgl-kernel/csrc/cpu + +#include "common.h" +#include "vec.h" +#include "gemm.h" + +// clang-format off + +namespace { + +// packed layout: +// quants {N, K} int8_t +// comp {N} int32_t +template +inline void s8s8_compensation(int8_t* __restrict__ packed, int K) { +#if defined(CPU_CAPABILITY_AVX512) + constexpr int COLS = BLOCK_N / 16; + __m512i vcomp[COLS]; + + for (int col = 0; col < COLS; ++col) { + vcomp[col] = _mm512_setzero_si512(); + } + + const int64_t offset = BLOCK_N * K; + const __m512i off = _mm512_set1_epi8(static_cast(0x80)); + for (int k = 0; k < K / 4; ++k) { + for (int col = 0; col < COLS; ++col) { + __m512i vb = _mm512_loadu_si512((const __m512i *)(packed + k * BLOCK_N * 4 + col * 64)); + vcomp[col] = _mm512_dpbusd_epi32(vcomp[col], off, vb); + } + } + + for (int col = 0; col < COLS; ++col) { + _mm512_storeu_si512((__m512i *)(packed + offset + col * 64), vcomp[col]); + } +#else + TORCH_CHECK(false, "s8s8_compensation not implemented!"); +#endif +} + +// convert to vnni format +// from [N, K] to [K/2, N, 2] for bfloat16 and float16 +template +inline void pack_vnni(packed_t* __restrict__ packed, const packed_t* __restrict__ weight, int N, int K) { + const int VNNI_BLK = 2; + for (int n = 0; n < N; ++n) { + for (int k = 0; k < K / VNNI_BLK; ++k) { + for (int d = 0; d < VNNI_BLK; ++d) { + packed[k * N * VNNI_BLK + n * VNNI_BLK + d] = weight[n * K + k * VNNI_BLK + d]; + } + } + } +} + +template <> +inline void pack_vnni(int8_t* __restrict__ packed, const int8_t* __restrict__ weight, int N, int K) { + constexpr int BLOCK_N = block_size_n(); + TORCH_CHECK(N == BLOCK_N); + + const int VNNI_BLK = 4; + for (int n = 0; n < N; ++n) { + for (int k = 0; k < K / VNNI_BLK; ++k) { + for (int d = 0; d < VNNI_BLK; ++d) { + packed[k * N * VNNI_BLK + n * VNNI_BLK + d] = weight[n * K + k * VNNI_BLK + d]; + } + } + } + s8s8_compensation(packed, K); +} + +template +inline void copy_stub(scalar_t* __restrict__ out, const float* __restrict__ input, int64_t size) { + using bVec = at::vec::Vectorized; + using fVec = at::vec::Vectorized; + constexpr int kVecSize = bVec::size(); + + int64_t d; + #pragma GCC unroll 4 + for (d = 0; d <= size - kVecSize; d += kVecSize) { + fVec data0 = fVec::loadu(input + d); + fVec data1 = fVec::loadu(input + d + fVec::size()); + bVec out_vec = convert_from_float_ext(data0, data1); + out_vec.store(out + d); + } + for (; d < size; ++d) { + out[d] = static_cast(input[d]); + } +} + +template +inline void copy_add_stub(scalar_t* __restrict__ out, const float* __restrict__ input, const float* __restrict__ bias, int64_t size) { + using bVec = at::vec::Vectorized; + using fVec = at::vec::Vectorized; + constexpr int kVecSize = bVec::size(); + + int64_t d; + #pragma GCC unroll 4 + for (d = 0; d <= size - kVecSize; d += kVecSize) { + fVec data0 = fVec::loadu(input + d) + fVec::loadu(bias + d); + fVec data1 = fVec::loadu(input + d + fVec::size()) + fVec::loadu(bias + d + fVec::size()); + bVec out_vec = convert_from_float_ext(data0, data1); + out_vec.store(out + d); + } + for (; d < size; ++d) { + out[d] = static_cast(input[d] + bias[d]); + } +} + +template +struct tinygemm_kernel_nn { + static inline void apply( + const scalar_t* 
__restrict__ A, const scalar_t* __restrict__ B, scalar_t* __restrict__ C, + const float* __restrict__ bias, int64_t K, int64_t lda, int64_t ldb, int64_t ldc) { + TORCH_CHECK(false, "tinygemm_kernel_nn: scalar path not implemented!"); + } +}; + +#if defined(CPU_CAPABILITY_AVX512) +template +struct tinygemm_kernel_nn { + static inline void apply( + const at::BFloat16* __restrict__ A, const at::BFloat16* __restrict__ B, at::BFloat16* __restrict__ C, + const float* __restrict__ bias, int64_t K, int64_t lda, int64_t ldb, int64_t ldc) { + + constexpr int ROWS = BLOCK_M; + constexpr int COLS = BLOCK_N / 16; + + // prefetch distance + constexpr int PREFETCH_SIZE_K = 0; + + __m512bh va; + __m512bh vb[COLS]; + __m512 vc[ROWS * COLS]; + + auto loadc = [&](auto i) { + constexpr int col = i % COLS; + if constexpr (has_bias) { + vc[i] = _mm512_loadu_ps(bias + col * 16); + } else { + vc[i] = _mm512_set1_ps(0.f); + } + }; + Unroll{}(loadc); + + const int64_t K2 = K >> 1; + const int64_t lda2 = lda >> 1; + const int64_t ldb2 = ldb; // ldb * 2 >> 1; + const float* a_ptr = reinterpret_cast(A); + const float* b_ptr = reinterpret_cast(B); + + auto compute = [&](auto i, int64_t k) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + + if constexpr (col == 0) { + va = (__m512bh)(_mm512_set1_ps(a_ptr[row * lda2 + k])); + } + if constexpr (row == 0) { + vb[col] = (__m512bh)(_mm512_loadu_si512(b_ptr + k * ldb2 + col * 16)); + if constexpr (PREFETCH_SIZE_K > 0) { + _mm_prefetch(b_ptr + (k + PREFETCH_SIZE_K) * ldb2 + col * 16, _MM_HINT_T0); + } + } + vc[i] = _mm512_dpbf16_ps(vc[i], va, vb[col]); + }; + for (int64_t k = 0; k < K2; ++k) { + Unroll{}(compute, k); + } + + auto storec = [&](auto i) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + // for COLS = 2, 4 use 512bit store + // for COLS = 1, 3 use 256bit store + if constexpr (COLS % 2 == 0) { + if constexpr (col % 2 == 0) { + _mm512_storeu_si512( + reinterpret_cast<__m512i*>((C + row * ldc + col * 16)), + (__m512i)(_mm512_cvtne2ps_pbh(vc[row * COLS + col + 1], vc[row * COLS + col]))); + } + } else { + _mm256_storeu_si256( + reinterpret_cast<__m256i*>(C + row * ldc + col * 16), + (__m256i)(_mm512_cvtneps_pbh(vc[i]))); + } + }; + Unroll{}(storec); + } +}; +#endif + +#define LAUNCH_TINYGEMM_KERNEL_NN(MB_SIZE, NB_SIZE) \ + tinygemm_kernel_nn::apply( \ + A + mb_start * lda, B + nb_start * 2, C + mb_start * ldc + nb_start, \ + has_bias ? 
bias + nb_start : nullptr, K, lda, ldb, ldc); + +template +struct brgemm { + static inline void apply( + const scalar_t* __restrict__ A, const scalar_t* __restrict__ B, scalar_t* __restrict__ C, + float* __restrict__ Ctmp, const float* __restrict__ bias, + int64_t M, int64_t N, int64_t K, int64_t lda, int64_t ldb, int64_t ldc) { + + constexpr int BLOCK_N = block_size_n(); + at::native::cpublas::brgemm( + M, N, K, lda, ldb, BLOCK_N, /* add_C */false, + A, B, Ctmp); + + // copy from Ctmp to C + for (int64_t m = 0; m < M; ++m) { + if constexpr (has_bias) { + copy_add_stub(C + m * ldc, Ctmp + m * BLOCK_N, bias, N); + } else { + copy_stub(C + m * ldc, Ctmp + m * BLOCK_N, N); + } + } + } +}; + +template +void tinygemm_kernel( + const scalar_t* __restrict__ A, + const scalar_t* __restrict__ B, + scalar_t* __restrict__ C, + float* __restrict__ Ctmp, + const float* __restrict__ bias, + int64_t M, + int64_t N, + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc, + bool brg) { + + if (brg) { + brgemm::apply( + A, B, C, Ctmp, bias, + M, N, K, lda, ldb, ldc); + return; + } + + // pattern: 1-4-16 + constexpr int64_t BLOCK_M = 4; + constexpr int64_t BLOCK_N = 64; + const int64_t MB = div_up(M, BLOCK_M); + const int64_t NB = div_up(N, BLOCK_N); + for (int mb = 0; mb < MB; ++mb) { + int64_t mb_start = mb * BLOCK_M; + int64_t mb_size = std::min(BLOCK_M, M - mb_start); + for (int64_t nb = 0; nb < NB; ++nb) { + int64_t nb_start = nb * BLOCK_N; + int64_t nb_size = std::min(BLOCK_N, N - nb_start); + + switch(mb_size << 4 | nb_size >> 4) { + // mb_size = 1 + case 0x12: LAUNCH_TINYGEMM_KERNEL_NN(1, 32); break; + case 0x14: LAUNCH_TINYGEMM_KERNEL_NN(1, 64); break; + // mb_size = 2 + case 0x22: LAUNCH_TINYGEMM_KERNEL_NN(2, 32); break; + case 0x24: LAUNCH_TINYGEMM_KERNEL_NN(2, 64); break; + // mb_size = 3 + case 0x32: LAUNCH_TINYGEMM_KERNEL_NN(3, 32); break; + case 0x34: LAUNCH_TINYGEMM_KERNEL_NN(3, 64); break; + // mb_size = 4 + case 0x42: LAUNCH_TINYGEMM_KERNEL_NN(4, 32); break; + case 0x44: LAUNCH_TINYGEMM_KERNEL_NN(4, 64); break; + default: TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", "nb_size"); + } + } + } +} + +template +void weight_packed_linear_kernel_impl( + scalar_t* __restrict__ out, + const scalar_t* __restrict__ mat1, + const scalar_t* __restrict__ mat2, + const float* __restrict__ bias, + int64_t M, + int64_t N, + int64_t K, + int64_t mat1_strideM, + int64_t out_strideM) { + + constexpr int64_t BLOCK_M = block_size_m(); + constexpr int64_t BLOCK_N = block_size_n(); + const int64_t MB = div_up(M, BLOCK_M); + const int64_t NB = div_up(N, BLOCK_N); + + // use avx512-bf16 when a) M is small; b) dtype is bfloat16, otherwise use amx + const bool use_brgemm = (M > 4) || (!std::is_same_v); + + // l2 cache block for n + int64_t cache_blocks_nb = get_cache_blocks(BLOCK_N, K); + + // parallel on [MB, NB] + AT_DISPATCH_BOOL(bias != nullptr, has_bias, [&] { + parallel_2d(MB, NB, [&](int64_t begin_mb, int64_t end_mb, int64_t begin_nb, int64_t end_nb) { + + // for brgemm, use float32 for accumulate + alignas(64) float Ctmp[BLOCK_M * BLOCK_N]; + + for (int64_t nbb = begin_nb; nbb < end_nb; nbb += cache_blocks_nb) { + for (int64_t mb = begin_mb; mb < end_mb; ++mb) { + for (int64_t nb = nbb; nb < std::min(nbb + cache_blocks_nb, end_nb); ++nb) { + + int64_t mb_start = mb * BLOCK_M; + int64_t mb_size = std::min(M - mb_start, BLOCK_M); + int64_t nb_start = nb * BLOCK_N; + int64_t nb_size = std::min(N - nb_start, BLOCK_N); + + tinygemm_kernel( + /* A */ mat1 + mb_start * mat1_strideM, + /* B */ mat2 + 
nb_start * K /* nb * BLOCK_N * K */, + /* C */ out + mb_start * out_strideM + nb_start, + /* Ctmp*/ Ctmp, + /* bias*/ bias + nb_start, + /* M */ mb_size, + /* N */ nb_size, + /* K */ K, + /* lda */ mat1_strideM, + /* ldb */ nb_size, + /* ldc */ out_strideM, + /* brg */ use_brgemm); + }}} + + if (use_brgemm) { + at::native::cpublas::brgemm_release(); + } + }); + }); +} + +} // anonymous namespace + +// tinygemm interface +template +void tinygemm_kernel(const scalar_t* __restrict__ A, const scalar_t* __restrict__ B, scalar_t* __restrict__ C, + float* __restrict__ Ctmp, int64_t M, int64_t N, int64_t K, int64_t lda, int64_t ldb, int64_t ldc, bool brg) { + tinygemm_kernel(A, B, C, Ctmp, nullptr, M, N, K, lda, ldb, ldc, brg); +} + +#define INSTANTIATE_TINYGEMM_TEMPLATE(TYPE) \ + template void tinygemm_kernel( \ + const TYPE* __restrict__ A, const TYPE* __restrict__ B, TYPE* __restrict__ C, \ + float* __restrict__ Ctmp, int64_t M, int64_t N, int64_t K, int64_t lda, \ + int64_t ldb, int64_t ldc, bool brg) + +INSTANTIATE_TINYGEMM_TEMPLATE(at::BFloat16); +INSTANTIATE_TINYGEMM_TEMPLATE(at::Half); + +at::Tensor convert_weight_packed(at::Tensor& weight) { + // for 3d moe weights + // weight : [E, OC, IC] + // w1 : [E, 2N, K] + // w2 : [E, K, N] + CHECK_INPUT(weight); + + const int64_t ndim = weight.ndimension(); + TORCH_CHECK(ndim == 2 || ndim == 3, "expect weight to be 2d or 3d, got ", ndim, "d tensor."); + const auto st = weight.scalar_type(); + const int64_t E = ndim == 3 ? weight.size(0) : 1; + const int64_t OC = ndim == 3 ? weight.size(1) : weight.size(0); + const int64_t IC = ndim == 3 ? weight.size(2) : weight.size(1); + + // we handle 2 TILE_N at a time. + TORCH_CHECK(OC % TILE_N == 0, "invalid weight out features ", OC); + TORCH_CHECK(IC % TILE_K == 0, "invalid weight input features ", IC); + + constexpr int64_t BLOCK_N = block_size_n(); + const int64_t NB = div_up(OC, BLOCK_N); + + // use phony sizes here [E, OC, IC], for each [E], [OC, IC] -> [IC / 2, OC, 2] + auto packed_weight = at::empty({}, weight.options()); + const int64_t stride = OC * IC; + + TORCH_CHECK(st == at::kBFloat16 || st == at::kHalf || st == at::kChar || st == at::kFloat8_e4m3fn, + "expect weight to be bfloat16, float16, int8 or fp8_e4m3."); + + CPU_DISPATCH_PACKED_TYPES(st, [&] { + // adjust most inner dimension size + const int packed_row_size = get_row_size(IC); + auto sizes = weight.sizes().vec(); + sizes[ndim - 1] = packed_row_size; + packed_weight.resize_(sizes); + + const packed_t* w_data = weight.data_ptr(); + packed_t* packed_data = packed_weight.data_ptr(); + + // parallel on {E, NB} + at::parallel_for(0, E * NB, 0, [&](int64_t begin, int64_t end) { + int64_t e{0}, nb{0}; + data_index_init(begin, e, E, nb, NB); + + for (int64_t i = begin; i < end; ++i) { + UNUSED(i); + + int64_t n = nb * BLOCK_N; + int64_t n_size = std::min(BLOCK_N, OC - n); + pack_vnni( + packed_data + e * OC * packed_row_size + n * packed_row_size, + w_data + e * stride + n * IC, + n_size, + IC); + + // move to the next index + data_index_step(e, E, nb, NB); + } + }); + }); + return packed_weight; +} + +// mat1 : [M, K] +// mat2 : [N, K] +// bias : [N] +// out : [M, N] +// +at::Tensor weight_packed_linear(at::Tensor& mat1, at::Tensor& mat2, + const std::optional& bias, bool is_vnni) { + RECORD_FUNCTION( + "sgl-kernel::weight_packed_linear", std::vector({mat1, mat2, bias})); + + auto packed_w = is_vnni ? 
mat2 : convert_weight_packed(mat2); + + CHECK_LAST_DIM_CONTIGUOUS_INPUT(mat1); + CHECK_INPUT(mat2); + + int64_t M = mat1.size(0); + int64_t N = mat2.size(0); + int64_t K = mat2.size(1); + CHECK_EQ(mat1.size(1), K); + CHECK_DIM(2, mat1); + CHECK_DIM(2, mat2); + + auto out = at::empty({M, N}, mat1.options()); + + // strides + int64_t mat1_strideM = mat1.stride(0); + int64_t out_strideM = out.stride(0); + + const bool has_bias = bias.has_value(); + const float* bias_data = nullptr; + if (has_bias) { + CHECK_EQ(bias.value().size(0), N); + bias_data = bias.value().data_ptr(); + } + + AT_DISPATCH_REDUCED_FLOATING_TYPES(mat1.scalar_type(), "weight_packed_linear_kernel_impl", [&] { + weight_packed_linear_kernel_impl( + out.data_ptr(), + mat1.data_ptr(), + packed_w.data_ptr(), + bias_data, + M, + N, + K, + mat1_strideM, + out_strideM); + }); + + return out; +} diff --git a/csrc/cpu/sgl-kernels/gemm.h b/csrc/cpu/sgl-kernels/gemm.h new file mode 100644 index 000000000000..afae19721ae9 --- /dev/null +++ b/csrc/cpu/sgl-kernels/gemm.h @@ -0,0 +1,266 @@ +#pragma once + +#include + +// clang-format off + +// amx-bf16 +#define TILE_M 16 +#define TILE_N 16 +#define TILE_K 32 + +// block size for AMX gemm +constexpr int block_size_m() { return 2 * TILE_M; } +constexpr int block_size_n() { return 2 * TILE_N; } + +// define threshold using brgemm (intel AMX) +template inline bool can_use_brgemm(int M); +template <> inline bool can_use_brgemm(int M) { return M > 4; } +template <> inline bool can_use_brgemm(int M) { return true; } +// TODO: add u8s8 brgemm, this requires PyTorch 2.7 +template <> inline bool can_use_brgemm(int M) { return false; } +template <> inline bool can_use_brgemm(int M) { return M > 4; } +template <> inline bool can_use_brgemm(int M) { return M > 4; } + +// work around compiler internal error +#define BLOCK_K 128 // 4 * TILE_K + +// adjust leading dimension size for K +template +inline int64_t get_row_size(int64_t K) { + return K; +} + +template <> +inline int64_t get_row_size(int64_t K) { + return K + sizeof(int32_t); +} + +inline int64_t get_row_size(int64_t K, bool use_int8_w8a8) { + return use_int8_w8a8 ? 
K + sizeof(int32_t) : K; +} + +// pack weight to vnni format +at::Tensor convert_weight_packed(at::Tensor& weight); + +// moe implementations for int8 w8a8 +template +void fused_experts_int8_kernel_impl( + scalar_t* __restrict__ output, + scalar_t* __restrict__ ic1, + scalar_t* __restrict__ ic2, + uint8_t* __restrict__ A_tmp, + float* __restrict__ C_tmp, + uint8_t* __restrict__ Aq_tmp, + float* __restrict__ As_tmp, + const scalar_t* __restrict__ input, + const int8_t* __restrict__ packed_w1, + const int8_t* __restrict__ packed_w2, + const float* __restrict__ w1s, + const float* __restrict__ w2s, + const float* __restrict__ topk_weights, + const int32_t* __restrict__ sorted_ids, + const int32_t* __restrict__ expert_ids, + const int32_t* __restrict__ offsets, + int64_t M, + int64_t N, + int64_t K, + int64_t E, + int64_t topk, + int64_t num_tokens_post_pad); + +// moe implementations for fp8 w8a16 +template +void fused_experts_fp8_kernel_impl( + scalar_t* __restrict__ output, + scalar_t* __restrict__ ic0, + scalar_t* __restrict__ ic1, + scalar_t* __restrict__ ic2, + scalar_t* __restrict__ A_tmp, + scalar_t* __restrict__ B_tmp, + float* __restrict__ C_tmp, + const scalar_t* __restrict__ input, + const at::Float8_e4m3fn* __restrict__ packed_w1, + const at::Float8_e4m3fn* __restrict__ packed_w2, + const float* __restrict__ w1s, + const float* __restrict__ w2s, + int64_t block_size_N, + int64_t block_size_K, + const float* __restrict__ topk_weights, + const int32_t* __restrict__ sorted_ids, + const int32_t* __restrict__ expert_ids, + const int32_t* __restrict__ offsets, + int64_t M, + int64_t N, + int64_t K, + int64_t E, + int64_t topk, + int64_t num_tokens_post_pad); + +// moe implementations for int4 w4a16 +template +void fused_experts_int4_w4a16_kernel_impl( + scalar_t* __restrict__ output, + scalar_t* __restrict__ ic0, + scalar_t* __restrict__ ic1, + scalar_t* __restrict__ ic2, + scalar_t* __restrict__ A_tmp, + scalar_t* __restrict__ B_tmp, + float* __restrict__ C_tmp, + const scalar_t* __restrict__ input, + const at::quint4x2* __restrict__ packed_w1, + const at::quint4x2* __restrict__ packed_w2, + const uint8_t* __restrict__ w1z, + const uint8_t* __restrict__ w2z, + const scalar_t* __restrict__ w1s, + const scalar_t* __restrict__ w2s, + int group_size, + const float* __restrict__ topk_weights, + const int32_t* __restrict__ sorted_ids, + const int32_t* __restrict__ expert_ids, + const int32_t* __restrict__ offsets, + int64_t M, + int64_t N, + int64_t K, + int64_t E, + int64_t topk, + int64_t num_tokens_post_pad); + +// shared expert implememntation for int8 w8a8 +template +void shared_expert_int8_kernel_impl( + scalar_t* __restrict__ output, + scalar_t* __restrict__ ic1, + float* __restrict__ C_tmp, + uint8_t* __restrict__ Aq_tmp, + float* __restrict__ As_tmp, + const scalar_t* __restrict__ input, + const int8_t* __restrict__ packed_w1, + const int8_t* __restrict__ packed_w2, + const float* __restrict__ w1s, + const float* __restrict__ w2s, + const scalar_t* __restrict__ fused_experts_out, + float routed_scaling_factor, + int64_t M, + int64_t N, + int64_t K); + +template +void shared_expert_fp8_kernel_impl( + scalar_t* __restrict__ output, + scalar_t* __restrict__ ic0, + scalar_t* __restrict__ ic1, + scalar_t* __restrict__ B_tmp, + float* __restrict__ C_tmp, + const scalar_t* __restrict__ input, + const at::Float8_e4m3fn* __restrict__ packed_w1, + const at::Float8_e4m3fn* __restrict__ packed_w2, + const float* __restrict__ w1s, + const float* __restrict__ w2s, + int64_t block_size_N, + 
int64_t block_size_K, + const scalar_t* __restrict__ fused_experts_out, + float routed_scaling_factor, + int64_t M, + int64_t N, + int64_t K); + +// tinygemm interface +template +void tinygemm_kernel( + const scalar_t* __restrict__ A, + const scalar_t* __restrict__ B, + scalar_t* __restrict__ C, + float* __restrict__ Ctmp, + int64_t M, + int64_t N, + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc, + bool brg); + +template +void tinygemm_kernel( + const uint8_t* __restrict__ A, + const int8_t* __restrict__ B, + scalar_t* __restrict__ C, + int32_t* __restrict__ Ctmp, + const float* __restrict__ As, + const float* __restrict__ Bs, + int64_t M, + int64_t N, + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc, + bool brg); + +template +void tinygemm_kernel( + const scalar_t* __restrict__ A, + const at::Float8_e4m3fn* __restrict__ B, + scalar_t* __restrict__ C, + scalar_t* __restrict__ Btmp, + float* __restrict__ Ctmp, + const float* __restrict__ scale, + int64_t M, + int64_t N, + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc, + bool brg, + int64_t block_size_K); + +template +void tinygemm_kernel( + const scalar_t* __restrict__ A, + const at::quint4x2* __restrict__ B, + scalar_t* __restrict__ C, + const uint8_t* __restrict__ Bz, + const scalar_t* __restrict__ Bs, + scalar_t* __restrict__ Btmp, + float* __restrict__ Ctmp, + int64_t M, + int64_t N, + int64_t K, + int group_size, + int64_t lda, + int64_t ldb, + int64_t ldc, + int64_t strideBz, + int64_t strideBs, + bool brg); + +// TODO: debug print, remove me later +inline void print_16x32i(const __m512i x) { + int32_t a[16]; + _mm512_storeu_si512((__m512i *)a, x); + + for (int i = 0; i < 16; i++){ + std::cout << a[i] << " "; + } + std::cout << std::endl; +} + +inline void print_16x32(const __m512 x) { + float a[16]; + _mm512_storeu_ps((__m512 *)a, x); + + for (int i = 0; i < 16; i++){ + std::cout << a[i] << " "; + } + std::cout << std::endl; +} + + +inline void print_32x8u(const __m256i x) { + uint8_t a[32]; + _mm256_storeu_si256((__m256i *)a, x); + + for (int i = 0; i < 32; ++i) { + std::cout << int32_t(a[i]) << " "; + } + std::cout << std::endl; +} diff --git a/csrc/cpu/sgl-kernels/gemm_fp8.cpp b/csrc/cpu/sgl-kernels/gemm_fp8.cpp new file mode 100644 index 000000000000..b5f2f07bad62 --- /dev/null +++ b/csrc/cpu/sgl-kernels/gemm_fp8.cpp @@ -0,0 +1,530 @@ +// Adapted from +// https://github.com/sgl-project/sglang/tree/main/sgl-kernel/csrc/cpu + +#include "common.h" +#include "vec.h" +#include "gemm.h" + +// clang-format off + +// we use 4x32 for BLOCK_M +#define BLOCK_SIZE_M_SCALE 4 + +namespace { + +template +inline void copy_stub(scalar_t* __restrict__ out, const float* __restrict__ input, int64_t size) { + using bVec = at::vec::Vectorized; + using fVec = at::vec::Vectorized; + constexpr int kVecSize = bVec::size(); + + int64_t d; + #pragma GCC unroll 4 + for (d = 0; d <= size - kVecSize; d += kVecSize) { + fVec data0 = fVec::loadu(input + d); + fVec data1 = fVec::loadu(input + d + fVec::size()); + bVec out_vec = convert_from_float_ext(data0, data1); + out_vec.store(out + d); + } + for (; d < size; ++d) { + out[d] = static_cast(input[d]); + } +} + +template +inline void copy_add_stub(scalar_t* __restrict__ out, const float* __restrict__ input, const float* __restrict__ bias, int64_t size) { + using bVec = at::vec::Vectorized; + using fVec = at::vec::Vectorized; + constexpr int kVecSize = bVec::size(); + + int64_t d; + #pragma GCC unroll 4 + for (d = 0; d <= size - kVecSize; d += kVecSize) { + fVec data0 = fVec::loadu(input 
+ d) + fVec::loadu(bias + d); + fVec data1 = fVec::loadu(input + d + fVec::size()) + fVec::loadu(bias + d + fVec::size()); + bVec out_vec = convert_from_float_ext(data0, data1); + out_vec.store(out + d); + } + for (; d < size; ++d) { + out[d] = static_cast(input[d] + bias[d]); + } +} + +inline void unpack_B( + at::BFloat16* __restrict__ Btmp, + const at::Float8_e4m3fn* __restrict__ packed_B, + int N, + int K, + int ldb, + int ldb_tmp, + float scale) { +#if defined(CPU_CAPABILITY_AVX512) + // [K/2, N, 2] + const int K2 = K >> 1; + const int ldb2 = ldb; // ldb * 2 >> 1; + const uint16_t* b_ptr = reinterpret_cast(packed_B); + const __m512 vd = _mm512_set1_ps(scale); + + constexpr int BLOCK_N = block_size_n(); + static_assert(BLOCK_N == 32); + + // prefetch distance + constexpr int PREFETCH_SIZE_K = 64; + +#pragma GCC unroll 4 + for (int k = 0; k < K2; ++k) { + __m512i b8 = _mm512_loadu_si512(b_ptr + k * ldb2); + if constexpr (PREFETCH_SIZE_K > 0) { + _mm_prefetch(b_ptr + (k + PREFETCH_SIZE_K) * ldb2, _MM_HINT_T0); + } + + __m256i b8_0 = _mm512_extracti32x8_epi32(b8, 0); + __m256i b8_1 = _mm512_extracti32x8_epi32(b8, 1); + + __m512bh bf16_0 = CVT_FP8_TO_BF16(b8_0); + __m512bh bf16_1 = CVT_FP8_TO_BF16(b8_1); + + // Apply scale + __m512 f0_lo = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_0, 0)); + __m512 f0_hi = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_0, 1)); + __m512 f1_lo = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_1, 0)); + __m512 f1_hi = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_1, 1)); + + f0_lo = _mm512_mul_ps(f0_lo, vd); + f0_hi = _mm512_mul_ps(f0_hi, vd); + f1_lo = _mm512_mul_ps(f1_lo, vd); + f1_hi = _mm512_mul_ps(f1_hi, vd); + + bf16_0 = _mm512_cvtne2ps_pbh(f0_hi, f0_lo); + bf16_1 = _mm512_cvtne2ps_pbh(f1_hi, f1_lo); + + _mm512_storeu_si512(Btmp + k * ldb_tmp * 2 + 0, (__m512i)bf16_0); + _mm512_storeu_si512(Btmp + k * ldb_tmp * 2 + 32, (__m512i)bf16_1); + } +#else + TORCH_CHECK(false, "unpack_B: scalar path not implemented!"); +#endif +} + +template +struct tinygemm_kernel_nn { + static inline void apply( + const scalar_t* __restrict__ A, const packed_t* __restrict__ B, scalar_t* __restrict__ C, + const float* __restrict__ bias, const float* __restrict__ scale, int K, int lda, int ldb, int ldc, int64_t block_size_K) { + TORCH_CHECK(false, "tinygemm_kernel_nn: scalar path not implemented!"); + } +}; + +#if defined(CPU_CAPABILITY_AVX512) +template +struct tinygemm_kernel_nn { + static inline void apply( + const at::BFloat16* __restrict__ A, const at::Float8_e4m3fn* __restrict__ B, at::BFloat16* __restrict__ C, + const float* __restrict__ bias, const float* __restrict__ scale, int K, int lda, int ldb, int ldc, int64_t block_size_K) { + + constexpr int ROWS = BLOCK_M; + constexpr int COLS = BLOCK_N / 16; + + const int KB = div_up(K, BLOCK_K); + + // prefetch distance + constexpr int PREFETCH_SIZE_K = 64; + constexpr int PREFETCH_SIZE_KB = 1; + + __m512bh va; + __m512bh vb[COLS]; + __m512 vc[ROWS * COLS]; + __m512 vsum[ROWS * COLS]; + + // block quant scale + __m512 vscale; + + auto loadc = [&](auto i) { + constexpr int col = i % COLS; + if constexpr (has_bias) { + vc[i] = _mm512_loadu_ps(bias + col * 16); + } else { + vc[i] = _mm512_setzero_ps(); + } + }; + Unroll{}(loadc); + + const int lda2 = lda >> 1; + const int ldb2 = ldb; // ldb * 2 >> 1; + const float* a_ptr = reinterpret_cast(A); + const uint16_t* b_ptr = reinterpret_cast(B); + + auto compute = [&](auto i, int k) { + constexpr int row = i / COLS; + constexpr int col = i % 
COLS; + + if constexpr (col == 0) { + va = (__m512bh)(_mm512_set1_ps(a_ptr[row * lda2 + k])); + if constexpr (PREFETCH_SIZE_K > 0) { + _mm_prefetch(a_ptr + row * lda2 + k + PREFETCH_SIZE_K, _MM_HINT_T0); + } + } + if constexpr (row == 0) { + if constexpr (col % 2 == 0) { + __m512i b8 = _mm512_loadu_si512(b_ptr + k * ldb2 + col * 16); + if constexpr (PREFETCH_SIZE_K > 0) { + _mm_prefetch(b_ptr + (k + PREFETCH_SIZE_K) * ldb2 + col * 16, _MM_HINT_T0); + } + vb[col + 0] = CVT_FP8_TO_BF16(_mm512_extracti32x8_epi32(b8, 0)); + vb[col + 1] = CVT_FP8_TO_BF16(_mm512_extracti32x8_epi32(b8, 1)); + } + } + vsum[i] = _mm512_dpbf16_ps(vsum[i], va, vb[col]); + }; + + constexpr int BLOCK_K2 = BLOCK_K >> 1; + for (int kb = 0; kb < KB; ++kb) { + int kb_start = kb * BLOCK_K2; + int kb_end = std::min(K, kb_start + BLOCK_K2); + // 1. load scale vector + vscale = _mm512_set1_ps(scale[kb]); + if constexpr (PREFETCH_SIZE_KB > 0) { + _mm_prefetch(scale + kb + PREFETCH_SIZE_KB, _MM_HINT_T0); + } + // 2. zero vsum for each block + Unroll{}([&](auto i) { + vsum[i] = _mm512_setzero_ps(); + }); + // 3. accumulate across each block + for (int k = kb_start; k < kb_end; ++k) { + Unroll{}(compute, k); + } + // 4. apply scale + Unroll{}([&](auto i) { + vc[i] = _mm512_fmadd_ps(vsum[i], vscale, vc[i]); + }); + } + + auto storec = [&](auto i) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + // for COLS = 2,4 use 512bit store + if constexpr (col % 2 == 0) { + _mm512_storeu_si512( + reinterpret_cast<__m512i*>((C + row * ldc + col * 16)), + (__m512i)(_mm512_cvtne2ps_pbh(vc[row * COLS + col + 1], vc[row * COLS + col]))); + } + }; + Unroll{}(storec); + } +}; +#endif + +#define LAUNCH_TINYGEMM_KERNEL_NN(MB_SIZE, NB_SIZE) \ + tinygemm_kernel_nn::apply( \ + A + mb_start * lda, B + nb_start * 2, C + mb_start * ldc + nb_start, \ + has_bias ? 
bias + nb_start : nullptr, scale, K, lda, ldb, ldc, block_size_K); + +template +struct brgemm { + static inline void apply( + const scalar_t* __restrict__ A, + const packed_t* __restrict__ B, + scalar_t* __restrict__ C, + scalar_t* __restrict__ Btmp, + float* __restrict__ Ctmp, + const float* __restrict__ bias, + const float* __restrict__ scale, + int M, + int N, + int K, + int lda, + int ldb, + int ldc) { + TORCH_CHECK(false, "struct brgemm: primary template not implemented!"); + } +}; + +template +struct brgemm { + static inline void apply( + const at::BFloat16* __restrict__ A, + const at::Float8_e4m3fn* __restrict__ B, + at::BFloat16* __restrict__ C, + at::BFloat16* __restrict__ Btmp, + float* __restrict__ Ctmp, + const float* __restrict__ bias, + const float* __restrict__ scale, + int M, + int N, + int K, + int lda, + int ldb, + int ldc) { + + constexpr int BLOCK_N = block_size_n(); + + // [K, BLOCK_N] -> [K / 2, BLOCK_N * 2] + const int ldb_tmp = BLOCK_N; + + for (int k = 0; k < K; k += BLOCK_K) { + int kb_size = std::min(BLOCK_K, K - k); + + int idx = k >> 7; // k / BLOCK_K where BLOCK_K = 128 + unpack_B(Btmp + k * ldb_tmp, B + k * ldb, N, kb_size, ldb, ldb_tmp, scale[idx]); + } + + at::native::cpublas::brgemm( + M, N, K, lda, ldb_tmp, BLOCK_N, /* add_C */ false, A, Btmp, Ctmp); + + // copy from Ctmp to C + for (int m = 0; m < M; ++m) { + if constexpr (has_bias) { + copy_add_stub(C + m * ldc, Ctmp + m * BLOCK_N, bias, N); + } else { + copy_stub(C + m * ldc, Ctmp + m * BLOCK_N, N); + } + } + } +}; + +template +void tinygemm_kernel( + const scalar_t* __restrict__ A, + const at::Float8_e4m3fn* __restrict__ B, + scalar_t* __restrict__ C, + scalar_t* __restrict__ Btmp, + float* __restrict__ Ctmp, + const float* __restrict__ scale, + const float* __restrict__ bias, + int64_t M, + int64_t N, + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc, + bool brg, + int64_t block_size_K) { + + if (brg) { + brgemm::apply( + A, B, C, Btmp, Ctmp, bias, scale, M, N, K, lda, ldb, ldc); + return; + } + + // pattern: 1-4-16 + constexpr int64_t BLOCK_M = 4; + constexpr int64_t BLOCK_N = 64; + const int64_t MB = div_up(M, BLOCK_M); + const int64_t NB = div_up(N, BLOCK_N); + for (int mb = 0; mb < MB; ++mb) { + int64_t mb_start = mb * BLOCK_M; + int64_t mb_size = std::min(BLOCK_M, M - mb_start); + for (int64_t nb = 0; nb < NB; ++nb) { + int64_t nb_start = nb * BLOCK_N; + int64_t nb_size = std::min(BLOCK_N, N - nb_start); + + switch(mb_size << 4 | nb_size >> 4) { + case 0x12: LAUNCH_TINYGEMM_KERNEL_NN(1, 32); break; + case 0x22: LAUNCH_TINYGEMM_KERNEL_NN(2, 32); break; + case 0x32: LAUNCH_TINYGEMM_KERNEL_NN(3, 32); break; + case 0x42: LAUNCH_TINYGEMM_KERNEL_NN(4, 32); break; + default: TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", "nb_size"); + } + } + } +} + +template +void fp8_scaled_mm_kernel_impl( + scalar_t* __restrict__ out, + const scalar_t* __restrict__ mat1, + const at::Float8_e4m3fn* __restrict__ mat2, + const float* __restrict__ scales2, + const float* __restrict__ bias, + scalar_t* __restrict__ buffer, + int64_t M, + int64_t N, + int64_t K, + int64_t mat1_strideM, + int64_t out_strideM, + int64_t block_size_N, + int64_t block_size_K, + int64_t buffer_size_per_thread) { + + constexpr int64_t BLOCK_M = block_size_m() * BLOCK_SIZE_M_SCALE; + constexpr int64_t BLOCK_N = block_size_n(); + const int64_t MB = div_up(M, BLOCK_M); + const int64_t NB = div_up(N, BLOCK_N); + + const int64_t scale_size_K = div_up(K, block_size_K); + const int64_t blocks_n_per_group = block_size_N / BLOCK_N; + 
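+  // Block-scale layout: scales2 has shape [div_up(N, block_size_N), div_up(K, block_size_K)];
+  // block_size_N is a multiple of BLOCK_N, so each group of `blocks_n_per_group`
+  // consecutive N-blocks shares one scale row, selected below via
+  // `scales2 + (nb / blocks_n_per_group) * scale_size_K`.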
+ const bool use_brgemm = can_use_brgemm(M); + + // parallel on [MB, NB] + AT_DISPATCH_BOOL(bias != nullptr, has_bias, [&] { + at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) { + int64_t mb{0}, nb{0}; + data_index_init(begin, mb, MB, nb, NB); + + int tid = at::get_thread_num(); + scalar_t* __restrict__ Btmp = buffer + tid * buffer_size_per_thread; + float* __restrict__ Ctmp = (float*)((void*)(Btmp + BLOCK_N * K)); + + for (int64_t i = begin; i < end; ++i) { + UNUSED(i); + const float* scale_ptr = scales2 + (nb / blocks_n_per_group) * scale_size_K; + + int64_t mb_start = mb * BLOCK_M; + int64_t mb_size = std::min(M - mb_start, BLOCK_M); + int64_t nb_start = nb * BLOCK_N; + int64_t nb_size = std::min(N - nb_start, BLOCK_N); + + tinygemm_kernel( + /* A */ mat1 + mb_start * mat1_strideM, + /* B */ mat2 + nb_start * K, // nb * BLOCK_N * K + /* C */ out + mb_start * out_strideM + nb_start, + /* Btmp */ Btmp, + /* Ctmp */ Ctmp, + /* scale */ scale_ptr, + /* bias */ bias + nb_start, + /* M */ mb_size, + /* N */ nb_size, + /* K */ K, + /* lda */ mat1_strideM, + /* ldb */ nb_size, + /* ldc */ out_strideM, + /* brg */ use_brgemm, + /* block_size_K */ block_size_K); + + // move to the next index + data_index_step(mb, MB, nb, NB); + } + + if (use_brgemm) { + at::native::cpublas::brgemm_release(); + } + }); + }); +} + +} // anonymous namespace + +// tinygemm interface +template +void tinygemm_kernel( + const scalar_t* __restrict__ A, + const at::Float8_e4m3fn* __restrict__ B, + scalar_t* __restrict__ C, + scalar_t* __restrict__ Btmp, + float* __restrict__ Ctmp, + const float* __restrict__ scale, + int64_t M, + int64_t N, + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc, + bool brg, + int64_t block_size_K) { + tinygemm_kernel(A, B, C, Btmp, Ctmp, scale, nullptr, M, N, K, lda, ldb, ldc, brg, block_size_K); +} + +#define INSTANTIATE_TINYGEMM_TEMPLATE(TYPE) \ + template void tinygemm_kernel( \ + const TYPE* __restrict__ A, \ + const at::Float8_e4m3fn* __restrict__ B, \ + TYPE* __restrict__ C, \ + TYPE* __restrict__ Btmp, \ + float* __restrict__ Ctmp, \ + const float* __restrict__ scale, \ + int64_t M, \ + int64_t N, \ + int64_t K, \ + int64_t lda, \ + int64_t ldb, \ + int64_t ldc, \ + bool brg, \ + int64_t block_size_K) + +INSTANTIATE_TINYGEMM_TEMPLATE(at::BFloat16); +INSTANTIATE_TINYGEMM_TEMPLATE(at::Half); + +at::Tensor fp8_scaled_mm_cpu(at::Tensor& mat1, at::Tensor& mat2, at::Tensor& scales2, + std::vector block_size, std::optional& bias, + at::ScalarType out_dtype, bool is_vnni) { + RECORD_FUNCTION("sgl-kernel::fp8_scaled_mm_cpu", std::vector({mat1, mat2, scales2, block_size, bias})); + + auto packed_w = is_vnni ? 
mat2 : convert_weight_packed(mat2); + + CHECK_LAST_DIM_CONTIGUOUS_INPUT(mat1); + CHECK_INPUT(mat2); + CHECK_INPUT(scales2); + TORCH_CHECK(scales2.scalar_type() == at::kFloat, + "fp8_scaled_mm_cpu: expect scales2 to be float32."); + + int64_t M = mat1.size(0); + int64_t N = mat2.size(0); + int64_t K = mat2.size(1); + + CHECK_EQ(mat1.size(1), K); + CHECK_DIM(2, mat1); + CHECK_DIM(2, mat2); + + TORCH_CHECK(block_size.size() == 2, + "fp8_scaled_mm_cpu: expect block_size.size() to be 2."); + + int64_t block_size_N = block_size[0]; + int64_t block_size_K = block_size[1]; + + constexpr int64_t BLOCK_M = block_size_m() * BLOCK_SIZE_M_SCALE; + constexpr int64_t BLOCK_N = block_size_n(); + TORCH_CHECK(block_size_N % BLOCK_N == 0, "fp8_scaled_mm_cpu: expect block_size_N to be multiples of BLOCK_N"); + TORCH_CHECK(block_size_K == BLOCK_K, "fp8_scaled_mm_cpu: expect block_size_K equals to BLOCK_K"); + CHECK_EQ(scales2.size(0), div_up(N, block_size_N)); + CHECK_EQ(scales2.size(1), div_up(K, block_size_K)); + + const auto st = mat1.scalar_type(); + TORCH_CHECK(st == at::kBFloat16 || st == at::kHalf, + "fp8_scaled_mm_cpu: expect A to be bfloat16 or half."); + TORCH_CHECK(st == out_dtype, + "fp8_scaled_mm_cpu: expect A has same dtype with out_dtype."); + TORCH_CHECK(mat2.scalar_type() == at::kFloat8_e4m3fn, + "fp8_scaled_mm_cpu: expect mat2 to be fp8_e4m3."); + TORCH_CHECK(scales2.scalar_type() == at::kFloat, + "fp8_scaled_mm_cpu: expect scales to be float32."); + auto out = at::empty({M, N}, mat1.options().dtype(out_dtype)); + + // strides + int64_t mat1_strideM = mat1.stride(0); + int64_t out_strideM = out.stride(0); + + const bool has_bias = bias.has_value(); + const float* bias_data = nullptr; + if (has_bias) { + CHECK_EQ(bias.value().size(0), N); + bias_data = bias.value().data_ptr(); + } + + // Btmp : [T, BLOCK_N * K] + // Ctmp : [T, BLOCK_M * BLOCK_N] + int num_threads = at::get_num_threads(); + int64_t size_per_thread = BLOCK_N * K + BLOCK_M * BLOCK_N * 2; + auto buffer = at::empty({num_threads, size_per_thread}, mat1.options()); + + AT_DISPATCH_REDUCED_FLOATING_TYPES(out_dtype, "fp8_scaled_mm_kernel_impl", [&] { + fp8_scaled_mm_kernel_impl( + out.data_ptr(), + mat1.data_ptr(), + packed_w.data_ptr(), + scales2.data_ptr(), + bias_data, + buffer.data_ptr(), + M, + N, + K, + mat1_strideM, + out_strideM, + block_size_N, + block_size_K, + size_per_thread); + }); + + return out; +} diff --git a/csrc/cpu/sgl-kernels/gemm_int8.cpp b/csrc/cpu/sgl-kernels/gemm_int8.cpp new file mode 100644 index 000000000000..5a0f65a9200d --- /dev/null +++ b/csrc/cpu/sgl-kernels/gemm_int8.cpp @@ -0,0 +1,440 @@ +// Adapted from +// https://github.com/sgl-project/sglang/tree/main/sgl-kernel/csrc/cpu + +#include "common.h" +#include "vec.h" +#include "gemm.h" + +// clang-format off + +namespace { + +template +struct tinygemm_kernel_nn { + static inline void apply( + const uint8_t* __restrict__ A, const int8_t* __restrict__ B, scalar_t* __restrict__ C, + const float* __restrict__ As, const float* __restrict__ Bs, const int32_t* __restrict__ Bcomp, + const float* __restrict__ bias, int64_t K, int64_t lda, int64_t ldb, int64_t ldc) { + TORCH_CHECK(false, "tinygemm_kernel_nn: scalar path not implemented!"); + } +}; + +#if defined(CPU_CAPABILITY_AVX512) +template +struct tinygemm_kernel_nn { + static inline void apply( + const uint8_t* __restrict__ A, const int8_t* __restrict__ B, at::BFloat16* __restrict__ C, + const float* __restrict__ As, const float* __restrict__ Bs, const int32_t* __restrict__ Bcomp, + const float* __restrict__ 
bias, int64_t K, int64_t lda, int64_t ldb, int64_t ldc) { + + constexpr int ROWS = BLOCK_M; + constexpr int COLS = BLOCK_N / 16; + static_assert(COLS % 2 == 0); + + // prefetch distance + constexpr int PREFETCH_SIZE_K = 0; + + __m512i va; + __m512i vb[COLS]; + __m512i vc[ROWS * COLS]; + __m512i vcomp[COLS]; + __m512 vd0; + __m512 vd1[COLS]; + + // oops! 4x4 spills but luckly we use 4x2 + __m512 vbias[COLS]; + + // [NOTE]: s8s8 igemm compensation in avx512-vnni + // + // avx512-vnni has no s8s8, so we need to change s8s8 to u8s8 with compensate: + // + // a * b = (a + 128) * b - 128 * b + // s s u s u s + // + // 1) 128 * b is pre-computed when packing B to vnni formats + // 2) a + 128 is fused when dynamically quantize A + // + auto loadc = [&](auto i) { + vc[i] = _mm512_set1_epi32(0); + }; + Unroll{}(loadc); + + const int64_t K4 = K >> 2; + const int64_t lda4 = lda >> 2; + const int64_t ldb4 = ldb; // ldb * 4 >> 2; + const int32_t* a_ptr = reinterpret_cast(A); + const int32_t* b_ptr = reinterpret_cast(B); + + auto compute = [&](auto i, int64_t k) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + + if constexpr (col == 0) { + va = _mm512_set1_epi32(a_ptr[row * lda4 + k]); + } + if constexpr (row == 0) { + vb[col] = _mm512_loadu_si512(b_ptr + k * ldb4 + col * 16); + if constexpr (PREFETCH_SIZE_K > 0) { + _mm_prefetch(b_ptr + (k + PREFETCH_SIZE_K) * ldb4 + col * 16, _MM_HINT_T0); + } + } + vc[i] = _mm512_dpbusd_epi32(vc[i], va, vb[col]); + }; + for (int64_t k = 0; k < K4; ++k) { + Unroll{}(compute, k); + } + + auto storec = [&](auto i) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + + // load a scale + if constexpr(col == 0) { + vd0 = _mm512_set1_ps(As[row]); + } + // load b scale and vcomp per 2 vectors + // also load bias if any + if constexpr (row == 0) { + if constexpr (col % 2 == 0) { + vd1[col + 0] = _mm512_loadu_ps(Bs + col * 16); + vd1[col + 1] = _mm512_loadu_ps(Bs + col * 16 + 16); + vcomp[col + 0] = _mm512_loadu_si512(Bcomp + col * 16); + vcomp[col + 1] = _mm512_loadu_si512(Bcomp + col * 16 + 16); + if constexpr (has_bias) { + vbias[col + 0] = _mm512_loadu_ps(bias + col * 16); + vbias[col + 1] = _mm512_loadu_ps(bias + col * 16 + 16); + } + } + } + + // for COLS = 2, 4 use 512bit store + if constexpr (col % 2 == 0) { + __m512 vc0 = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc[row * COLS + col + 0], vcomp[col + 0])); + __m512 vc1 = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc[row * COLS + col + 1], vcomp[col + 1])); + if constexpr (has_bias) { + vc0 = _mm512_fmadd_ps(_mm512_mul_ps(vc0, vd0), vd1[col + 0], vbias[col + 0]); + vc1 = _mm512_fmadd_ps(_mm512_mul_ps(vc1, vd0), vd1[col + 1], vbias[col + 1]); + } else { + vc0 = _mm512_mul_ps(_mm512_mul_ps(vc0, vd0), vd1[col + 0]); + vc1 = _mm512_mul_ps(_mm512_mul_ps(vc1, vd0), vd1[col + 1]); + } + + _mm512_storeu_si512( + reinterpret_cast<__m512i*>((C + row * ldc + col * 16)), + (__m512i)(_mm512_cvtne2ps_pbh(vc1, vc0))); + } + }; + Unroll{}(storec); + } +}; +#endif + +#define LAUNCH_TINYGEMM_KERNEL_NN(MB_SIZE, NB_SIZE) \ + tinygemm_kernel_nn::apply( \ + A + mb_start * lda, B + nb_start * 4, C + mb_start * ldc + nb_start, \ + As + mb_start, Bs + nb_start, Bcomp + nb_start, \ + has_bias ? 
bias + nb_start : nullptr, K, lda, ldb, ldc); + +template +void tinygemm_kernel( + const uint8_t* __restrict__ A, + const int8_t* __restrict__ B, + scalar_t* __restrict__ C, + int32_t* __restrict__ Ctmp, + const float* __restrict__ As, + const float* __restrict__ Bs, + const float* __restrict__ bias, + int64_t M, + int64_t N, + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc, + bool brg) { + + // B compensation + const int32_t* Bcomp = reinterpret_cast(B + block_size_n() * K); + + // pattern: 1-4-16 + constexpr int64_t BLOCK_M = 4; + constexpr int64_t BLOCK_N = 64; + const int64_t MB = div_up(M, BLOCK_M); + const int64_t NB = div_up(N, BLOCK_N); + for (int64_t mb = 0; mb < MB; ++mb) { + int64_t mb_start = mb * BLOCK_M; + int64_t mb_size = std::min(BLOCK_M, M - mb_start); + for (int64_t nb = 0; nb < NB; ++nb) { + int64_t nb_start = nb * BLOCK_N; + int64_t nb_size = std::min(BLOCK_N, N - nb_start); + + switch(mb_size << 4 | nb_size >> 4) { + // mb_size = 1 + case 0x12: LAUNCH_TINYGEMM_KERNEL_NN(1, 32); break; + case 0x14: LAUNCH_TINYGEMM_KERNEL_NN(1, 64); break; + // mb_size = 2 + case 0x22: LAUNCH_TINYGEMM_KERNEL_NN(2, 32); break; + case 0x24: LAUNCH_TINYGEMM_KERNEL_NN(2, 64); break; + // mb_size = 3 + case 0x32: LAUNCH_TINYGEMM_KERNEL_NN(3, 32); break; + case 0x34: LAUNCH_TINYGEMM_KERNEL_NN(3, 64); break; + // mb_size = 4 + case 0x42: LAUNCH_TINYGEMM_KERNEL_NN(4, 32); break; + case 0x44: LAUNCH_TINYGEMM_KERNEL_NN(4, 64); break; + default: TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", "nb_size"); + } + } + } +} + +template +void int8_scaled_mm_kernel_impl( + scalar_t* __restrict__ out, + const uint8_t* __restrict__ mat1, + const int8_t* __restrict__ mat2, + const float* __restrict__ scales1, + const float* __restrict__ scales2, + const float* __restrict__ bias, + int64_t M, + int64_t N, + int64_t K) { + + constexpr int64_t BLOCK_M = block_size_m(); + constexpr int64_t BLOCK_N = block_size_n(); + const int64_t MB = div_up(M, BLOCK_M); + const int64_t NB = div_up(N, BLOCK_N); + + // TODO: brgemm u8s8 depends on PyTorch 2.7 release. 
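+  // With use_brgemm fixed to false, every [BLOCK_M, BLOCK_N] tile is computed by the
+  // AVX512-VNNI tinygemm path (u8 activation x s8 weight via _mm512_dpbusd_epi32,
+  // see the s8s8 compensation note above).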
+ const bool use_brgemm = false; + + // K + 4 after compensation + const int64_t packed_row_size = get_row_size(K); + + AT_DISPATCH_BOOL(bias != nullptr, has_bias, [&] { + at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) { + int64_t mb{0}, nb{0}; + data_index_init(begin, mb, MB, nb, NB); + + // for brgemm, use int32_t for accumulate + alignas(64) int32_t Ctmp[BLOCK_M * BLOCK_N]; + + for (int i = begin; i < end; ++i) { + UNUSED(i); + int mb_start = mb * BLOCK_M; + int mb_size = std::min(M - mb_start, BLOCK_M); + int nb_start = nb * BLOCK_N; + int nb_size = std::min(N - nb_start, BLOCK_N); + + tinygemm_kernel( + /* A */ mat1 + mb_start * K, + /* B */ mat2 + nb_start * packed_row_size /* nb * BLOCK_N * (K + 4) */, + /* C */ out + mb_start * N + nb_start, + /* Ctmp*/ Ctmp, + /* As */ scales1 + mb_start, + /* Bs */ scales2 + nb_start, + /* bias*/ bias + nb_start, + /* M */ mb_size, + /* N */ nb_size, + /* K */ K, + /* lda */ K, + /* ldb */ nb_size, + /* ldc */ N, + /* brg */ use_brgemm); + + // move to the next index + data_index_step(mb, MB, nb, NB); + } + + if (use_brgemm) { + at::native::cpublas::brgemm_release(); + } + }); + }); +} + +} // anonymous namespace + +// tinygemm interface +template +void tinygemm_kernel(const uint8_t* __restrict__ A, const int8_t* __restrict__ B, scalar_t* __restrict__ C, + int32_t* __restrict__ Ctmp, const float* __restrict__ As, const float* __restrict__ Bs, + int64_t M, int64_t N, int64_t K, int64_t lda, int64_t ldb, int64_t ldc, bool brg) { + tinygemm_kernel(A, B, C, Ctmp, As, Bs, nullptr, M, N, K, lda, ldb, ldc, brg); +} + +#define INSTANTIATE_TINYGEMM_TEMPLATE(TYPE) \ + template void tinygemm_kernel( \ + const uint8_t* __restrict__ A, const int8_t* __restrict__ B, TYPE* __restrict__ C, \ + int32_t* __restrict__ Ctmp, const float* __restrict__ As, const float* __restrict__ Bs, \ + int64_t M, int64_t N, int64_t K, int64_t lda, int64_t ldb, int64_t ldc, bool brg) + +INSTANTIATE_TINYGEMM_TEMPLATE(at::BFloat16); +INSTANTIATE_TINYGEMM_TEMPLATE(at::Half); + +std::tuple per_token_quant_int8_cpu(at::Tensor& A) { + RECORD_FUNCTION("sgl-kernel::per_token_quant_int8_cpu", std::vector({A})); + + CHECK_LAST_DIM_CONTIGUOUS_INPUT(A); + CHECK_DIM(2, A); + + int64_t M = A.size(0); + int64_t K = A.size(1); + int64_t lda = A.stride(0); + + const auto st = A.scalar_type(); + TORCH_CHECK(st == at::kBFloat16 || st == at::kHalf, + "per_token_quant_int8: expect A to be bfloat16 or half."); + + auto Aq = at::empty({M, K}, A.options().dtype(at::kByte)); + auto As = at::empty({M}, A.options().dtype(at::kFloat)); + + AT_DISPATCH_REDUCED_FLOATING_TYPES(st, "per_token_quant_int8", [&] { + uint8_t* __restrict__ Aq_data = Aq.data_ptr(); + float* __restrict__ As_data = As.data_ptr(); + const scalar_t* __restrict__ A_data = A.data_ptr(); + + at::parallel_for(0, M, 0, [&] (int64_t begin, int64_t end) { + for (int64_t m = begin; m < end; ++m) { + quantize_row_int8( + Aq_data + m * K, + As_data[m], + A_data + m * lda, + K); + } + }); + }); + return std::make_tuple(Aq, As); +} + +// weight : static, per-channel, symmetric +// activation : dynamic, per-token, symmetric +// +// mat1 : [M, K] +// mat2 : [N, K] +// scales1 : [M] +// scales2 : [N] +// bias : [N] +// out : [M, N] +// +at::Tensor int8_scaled_mm_cpu(at::Tensor& mat1, at::Tensor& mat2, + at::Tensor& scales1, at::Tensor& scales2, + std::optional& bias, at::ScalarType out_dtype, bool is_vnni) { + RECORD_FUNCTION("sgl-kernel::int8_scaled_mm_cpu", std::vector({mat1, mat2, scales1, scales2, bias})); + + auto packed_w = is_vnni 
? mat2 : convert_weight_packed(mat2); + + CHECK_INPUT(mat1); + CHECK_INPUT(mat2); + CHECK_INPUT(scales1); + CHECK_INPUT(scales2); + CHECK_DIM(2, mat1); + CHECK_DIM(2, mat2); + + int64_t M = mat1.size(0); + int64_t N = mat2.size(0); + int64_t K = mat1.size(1); + + // see [NOTE]: s8s8 igemm compensation in avx512-vnni + CHECK_EQ(mat2.size(1), (int64_t)(is_vnni ? K + sizeof(int32_t) : K)); + CHECK_EQ(scales1.numel(), M); + CHECK_EQ(scales2.numel(), N); + + TORCH_CHECK(mat1.scalar_type() == at::kByte, "int8_scaled_mm: expect mat1 to be uint8."); + TORCH_CHECK(mat2.scalar_type() == at::kChar, "int8_scaled_mm: expect mat2 to be int8."); + TORCH_CHECK(scales1.scalar_type() == at::kFloat && scales2.scalar_type() == at::kFloat, + "int8_scaled_mm: expect scales to be float32."); + + auto out = at::empty({M, N}, mat1.options().dtype(out_dtype)); + + const bool has_bias = bias.has_value(); + const float* bias_data = nullptr; + if (has_bias) { + CHECK_EQ(bias.value().size(0), N); + bias_data = bias.value().data_ptr(); + } + + AT_DISPATCH_REDUCED_FLOATING_TYPES(out_dtype, "int8_scaled_mm_kernel_impl", [&] { + int8_scaled_mm_kernel_impl( + out.data_ptr(), + mat1.data_ptr(), + packed_w.data_ptr(), + scales1.data_ptr(), + scales2.data_ptr(), + bias_data, + M, + N, + K); + }); + return out; +} + +// fused `per_token_quant_int8_cpu` and `int8_scaled_mm_cpu` +at::Tensor int8_scaled_mm_with_quant(at::Tensor& mat1, at::Tensor& mat2, at::Tensor& scales2, + const std::optional& bias, at::ScalarType out_dtype, bool is_vnni) { + RECORD_FUNCTION("sgl-kernel::int8_scaled_mm_cpu", std::vector({mat1, mat2, scales2, bias})); + + auto packed_w = is_vnni ? mat2 : convert_weight_packed(mat2); + + CHECK_LAST_DIM_CONTIGUOUS_INPUT(mat1); + CHECK_INPUT(mat2); + CHECK_INPUT(scales2); + CHECK_DIM(2, mat1); + CHECK_DIM(2, mat2); + + int64_t M = mat1.size(0); + int64_t N = mat2.size(0); + int64_t K = mat1.size(1); + int64_t lda = mat1.stride(0); + + // see [NOTE]: s8s8 igemm compensation in avx512-vnni + CHECK_EQ(mat2.size(1), (int64_t)(is_vnni ? 
+  CHECK_EQ(scales2.numel(), N);
+
+  const auto st = mat1.scalar_type();
+  TORCH_CHECK(st == at::kBFloat16 || st == at::kHalf,
+      "int8_scaled_mm_with_quant: expect A to be bfloat16 or half.");
+  TORCH_CHECK(st == out_dtype,
+      "int8_scaled_mm_with_quant: expect A to have the same dtype as out_dtype.");
+  TORCH_CHECK(mat2.scalar_type() == at::kChar,
+      "int8_scaled_mm_with_quant: expect mat2 to be int8.");
+  TORCH_CHECK(scales2.scalar_type() == at::kFloat,
+      "int8_scaled_mm_with_quant: expect scales to be float32.");
+
+  const int64_t buffer_size = M * K + M * sizeof(float);
+  auto buffer = at::empty({buffer_size}, mat1.options().dtype(at::kByte));
+  auto out = at::empty({M, N}, mat1.options().dtype(out_dtype));
+
+  const bool has_bias = bias.has_value();
+  const float* bias_data = nullptr;
+  if (has_bias) {
+    CHECK_EQ(bias.value().size(0), N);
+    bias_data = bias.value().data_ptr<float>();
+  }
+
+  AT_DISPATCH_REDUCED_FLOATING_TYPES(out_dtype, "int8_scaled_mm_with_quant_kernel_impl", [&] {
+    uint8_t* __restrict__ Aq_data = buffer.data_ptr<uint8_t>();
+    float* __restrict__ As_data = (float*)((void*)(Aq_data + M * K));
+    const scalar_t* __restrict__ A_data = mat1.data_ptr<scalar_t>();
+
+    at::parallel_for(0, M, 0, [&] (int64_t begin, int64_t end) {
+      for (int64_t m = begin; m < end; ++m) {
+        quantize_row_int8(
+            Aq_data + m * K,
+            As_data[m],
+            A_data + m * lda,
+            K);
+      }
+    });
+
+    int8_scaled_mm_kernel_impl<scalar_t>(
+        out.data_ptr<scalar_t>(),
+        Aq_data,
+        packed_w.data_ptr<int8_t>(),
+        As_data,
+        scales2.data_ptr<float>(),
+        bias_data,
+        M,
+        N,
+        K);
+  });
+  return out;
+}
diff --git a/csrc/cpu/sgl-kernels/moe.cpp b/csrc/cpu/sgl-kernels/moe.cpp
new file mode 100644
index 000000000000..beeccff783ea
--- /dev/null
+++ b/csrc/cpu/sgl-kernels/moe.cpp
@@ -0,0 +1,1330 @@
+// Adapted from
+// https://github.com/sgl-project/sglang/tree/main/sgl-kernel/csrc/cpu
+
+#include "common.h"
+#include "vec.h"
+#include "gemm.h"
+
+// clang-format off
+
+namespace {
+
+// [NOTE]: Fused MoE kernel with AMX
+//
+// This file contains implementations for
+//   * `moe_align_block_size`
+//   * `fused_moe`
+//
+// The functionality is identical to the triton kernel, except that:
+//   * silu_and_mul is fused with gemm1, so this kernel allocates
+//     2 intermediate_caches instead of 3
+//   * `moe_align_block_size` additionally computes `offsets`, which keeps
+//     track of the starting offset for each M block. This keeps the output
+//     of silu_and_mul in sorted order, so loads of A for the 2nd gemm are
+//     contiguous and A can be read directly from intermediate_cache1.
+//
+// TODO:
+//   1. tune BLOCK_M and BLOCK_N (BLOCK_N * K should fit in L2)
+//   2. add prefetch for load A, which is an indexed access
+//   3. 
abstract at::native::cpublas::brgemm with WoQ gemm (M = 1 & M != 1) +// + +template +inline void fill_stub(scalar_t* __restrict__ out, scalar_t val, int64_t size) { + using Vec = at::vec::Vectorized; + const Vec data_vec(val); + at::vec::map([data_vec](Vec out) { return out = data_vec; }, out, out, size); +} + +template +inline void copy_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ input, int64_t size) { + using Vec = at::vec::Vectorized; + // no remainder + #pragma GCC unroll 4 + for (int64_t d = 0; d < size; d += Vec::size()) { + Vec data = Vec::loadu(input + d); + data.store(out + d); + } +} + +template +inline void copy_mul_stub(scalar_t* __restrict__ out, const float* __restrict__ input, float weight, int64_t size) { + using bVec = at::vec::Vectorized; + using fVec = at::vec::Vectorized; + constexpr int kVecSize = bVec::size(); + const fVec weight_vec = fVec(weight); + int64_t d; + #pragma GCC unroll 4 + for (d = 0; d <= size - kVecSize; d += kVecSize) { + fVec data0 = fVec::loadu(input + d) * weight_vec; + fVec data1 = fVec::loadu(input + d + fVec::size()) * weight_vec; + bVec out_vec = convert_from_float_ext(data0, data1); + out_vec.store(out + d); + } + for (; d < size; ++d) { + out[d] = static_cast(input[d] * weight); + } +} + +// acc from [topk, K] to [K] +template +inline void sum_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ input, int64_t topk, int64_t K) { + using bVec = at::vec::Vectorized; + using fVec = at::vec::Vectorized; + constexpr int kVecSize = bVec::size(); + if (topk == 1) { + // do copy for topk = 1 + copy_stub(out, input, K); + } else { + // do sum for topk != 1 + int64_t d; + #pragma GCC unroll 4 + for (d = 0; d <= K - kVecSize; d += kVecSize) { + fVec sum_fvec0 = fVec(0.f); + fVec sum_fvec1 = fVec(0.f); + for (int t = 0; t < topk; ++t) { + bVec x_bvec = bVec::loadu(input + t * K + d); + fVec x_fvec0, x_fvec1; + std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec); + + sum_fvec0 += x_fvec0; + sum_fvec1 += x_fvec1; + } + bVec out_bvec = convert_from_float_ext(sum_fvec0, sum_fvec1); + out_bvec.store(out + d); + } + for (; d < K; ++d) { + float sum_val = 0.f; + for (int t = 0; t < topk; ++t) { + sum_val += static_cast(input[t * K + d]); + } + out[d] = static_cast(sum_val); + } + } +} + +// out = input + input2 * scale +template +inline void add_mul_stub(scalar_t* __restrict__ out, const float* __restrict__ input, + const scalar_t* __restrict__ input2, float scale, int64_t size) { + + using bVec = at::vec::Vectorized; + using fVec = at::vec::Vectorized; + constexpr int kVecSize = bVec::size(); + const fVec s_vec = fVec(scale); + int64_t d; + #pragma GCC unroll 4 + for (d = 0; d <= size - kVecSize; d += kVecSize) { + fVec x0 = fVec::loadu(input + d); + fVec x1 = fVec::loadu(input + d + fVec::size()); + + bVec y_bvec = bVec::loadu(input2 + d); + fVec y0, y1; + std::tie(y0, y1) = at::vec::convert_to_float(y_bvec); + + x0 = x0 + y0 * s_vec; + x1 = x1 + y1 * s_vec; + bVec out_vec = convert_from_float_ext(x0, x1); + out_vec.store(out + d); + } + for (; d < size; ++d) { + out[d] = static_cast(input[d] + float(input2[d]) * scale); + } +} + +template +int moe_align_block_size( + int32_t* __restrict__ sorted_ids, + int32_t* __restrict__ expert_ids, + int32_t* __restrict__ topk_ids, + int32_t* __restrict__ total_cnts, + int32_t* __restrict__ cumsums, + int32_t* __restrict__ offsets, + int num_experts, + int numel, + int num_threads) { + + #define T_INDEX(tt) total_cnts + (tt) * num_experts + + // accumulate count of expert ids locally 
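+  // Illustrative walk-through of this bookkeeping (the numbers and the thread
+  // split are only an example): with num_experts = 2, num_threads = 2 and
+  // topk_ids = [0, 1, 1, 0, 1], suppose thread 0 handles [0, 1, 1] and
+  // thread 1 handles [0, 1]. Row 0 of total_cnts (shape
+  // [num_threads + 1, num_experts]) stays zero, row 1 gets thread 0's counts
+  // {1, 2} and row 2 gets thread 1's counts {1, 1}. After the running sum
+  // across rows below, the last row {2, 3} holds the per-expert totals that
+  // drive `cumsums` and the BLOCK_M-aligned padding.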
+  at::parallel_for(0, numel, 0, [&](int begin, int end) {
+    int tid = at::get_thread_num();
+    int32_t* __restrict__ local_cnts = T_INDEX(tid + 1);
+
+    for (int i = begin; i < end; ++i) {
+      local_cnts[topk_ids[i]]++;
+    }
+  });
+
+  using iVec = at::vec::Vectorized<int32_t>;
+  for (int t = 0; t < num_threads; ++t) {
+    at::vec::map2(
+        [](iVec x, iVec y) { return x + y; },
+        T_INDEX(t + 1), T_INDEX(t + 1), T_INDEX(t), num_experts);
+  }
+
+  // the last row holds the sums for each expert
+  int32_t* total_cnts_t_1 = T_INDEX(num_threads);
+
+  cumsums[0] = 0;
+  for (int e = 0; e < num_experts; ++e) {
+    // accumulate `num_tokens_post_pad`, also as the expert offset
+    cumsums[e + 1] = cumsums[e] + div_up(total_cnts_t_1[e], BLOCK_M) * BLOCK_M;
+
+    for (int k = cumsums[e]; k < cumsums[e + 1]; k += BLOCK_M) {
+      expert_ids[k / BLOCK_M] = e;
+    }
+  }
+  int num_tokens_post_pad = cumsums[num_experts];
+
+  at::parallel_for(0, numel, 0, [&](int begin, int end) {
+    int tid = at::get_thread_num();
+    // thread tid offsets in `total_cnts`
+    int32_t* __restrict__ offsets = T_INDEX(tid);
+
+    for (int i = begin; i < end; ++i) {
+      int32_t expert_id = topk_ids[i];
+      int32_t b_offset = cumsums[expert_id];
+      int32_t t_offset = offsets[expert_id];
+      sorted_ids[b_offset + t_offset] = i;
+      offsets[expert_id]++;
+    }
+  });
+
+  // debug: after the scatter pass the last thread's offsets should equal the per-expert totals
+  int32_t* total_cnts_t_2 = T_INDEX(num_threads - 1);
+  for (int e = 0; e < num_experts; ++e) {
+    TORCH_CHECK(total_cnts_t_1[e] == total_cnts_t_2[e]);
+  }
+
+  // padding value for sorted_ids: numel
+  auto sorted_id_size = [=](const int32_t* sorted_ids_ptr) {
+    for (int d = 0; d < BLOCK_M; ++d) {
+      if (sorted_ids_ptr[d] == numel) { return d; }
+    }
+    return BLOCK_M;
+  };
+
+  // offsets holds the starting offset for each valid M block
+  // shape : [num_token_blocks + 1]
+  offsets[0] = 0;
+  const int num_token_blocks = num_tokens_post_pad / BLOCK_M;
+  at::parallel_for(0, num_token_blocks, GRAIN_SIZE / BLOCK_M, [&](int begin, int end) {
+    for (int mb = begin; mb < end; ++mb) {
+      offsets[mb + 1] = sorted_id_size(sorted_ids + mb * BLOCK_M);
+    }
+  });
+  // TODO: do we need to vectorize this?
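+  // (a serial inclusive prefix sum over num_token_blocks entries; afterwards
+  // offsets[mb + 1] is the starting position of block mb + 1 among the valid rows)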
+ for (int mb = 0; mb < num_token_blocks; ++mb) { + offsets[mb + 1] += offsets[mb]; + } + // debug: the last value of offsets should be `numel` + TORCH_CHECK(offsets[num_token_blocks] == numel); + + return num_tokens_post_pad; +} + +// silu : shape leading dimension +// input0 [m_size, BLOCK_N] BLOCK_N +// input1 [m_size, BLOCK_N] BLOCK_N +// output [M * topk, N] N +template +inline void silu_and_mul( + scalar_t* __restrict__ output, + const float* __restrict__ input0, // x: x0, x1 + const float* __restrict__ input1, // y: y0, y1 + int64_t m_size, + int64_t N) { + + using bVec = at::vec::Vectorized; + using fVec = at::vec::Vectorized; + + const fVec one = fVec(1.f); + + // no remainder + for (int64_t m = 0; m < m_size; ++m) { + scalar_t* __restrict__ out = output + m * N; + const float* __restrict__ x = input0 + m * BLOCK_N; + const float* __restrict__ y = input1 + m * BLOCK_N; + + for (int64_t d = 0; d < BLOCK_N; d += bVec::size()) { + fVec x0 = fVec::loadu(x + d); + fVec x1 = fVec::loadu(x + d + fVec::size()); + fVec y0 = fVec::loadu(y + d); + fVec y1 = fVec::loadu(y + d + fVec::size()); + // silu + x0 = x0 / (one + x0.neg().exp_u20()); + x1 = x1 / (one + x1.neg().exp_u20()); + // mul + x0 = x0 * y0; + x1 = x1 * y1; + // convert + bVec out_vec = convert_from_float_ext(x0, x1); + out_vec.store(out + d); + } + } +} + +template +struct tinygemm_kernel_nn2 { + static inline void apply( + const scalar_t* __restrict__ A, const scalar_t* __restrict__ B0, const scalar_t* __restrict__ B1, + scalar_t* __restrict__ C, int64_t K, int64_t lda, int64_t ldb, int64_t ldc) { + TORCH_CHECK(false, "tinygemm_kernel_nn: scalar path not implemented!"); + } +}; + +#if defined(CPU_CAPABILITY_AVX512) +template +struct tinygemm_kernel_nn2 { + static inline void apply( + const at::BFloat16* __restrict__ A, const at::BFloat16* __restrict__ B0, const at::BFloat16* __restrict__ B1, + at::BFloat16* __restrict__ C, int64_t K, int64_t lda, int64_t ldb, int64_t ldc) { + + constexpr int ROWS = BLOCK_M; + constexpr int COLS = BLOCK_N / 16; + + static_assert(COLS % 2 == 0); + + // prefetch distance + constexpr int PREFETCH_SIZE_K = 0; + + __m512bh va; + __m512bh vb0[COLS]; + __m512bh vb1[COLS]; + __m512 vc0[ROWS * COLS]; + __m512 vc1[ROWS * COLS]; + + auto loadc = [&](auto i) { + vc0[i] = _mm512_set1_ps(0.f); + vc1[i] = _mm512_set1_ps(0.f); + }; + Unroll{}(loadc); + + const int64_t K2 = K >> 1; + const int64_t lda2 = lda >> 1; + const int64_t ldb2 = ldb; // ldb * 2 >> 1; + const float* a_ptr = reinterpret_cast(A); + const float* b0_ptr = reinterpret_cast(B0); + const float* b1_ptr = reinterpret_cast(B1); + + auto compute = [&](auto i, int64_t k) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + + if constexpr (col == 0) { + va = (__m512bh)(_mm512_set1_ps(a_ptr[row * lda2 + k])); + } + if constexpr (row == 0) { + vb0[col] = (__m512bh)(_mm512_loadu_si512(b0_ptr + k * ldb2 + col * 16)); + vb1[col] = (__m512bh)(_mm512_loadu_si512(b1_ptr + k * ldb2 + col * 16)); + if constexpr (PREFETCH_SIZE_K > 0) { + _mm_prefetch(b0_ptr + (k + PREFETCH_SIZE_K) * ldb2 + col * 16, _MM_HINT_T0); + _mm_prefetch(b1_ptr + (k + PREFETCH_SIZE_K) * ldb2 + col * 16, _MM_HINT_T0); + } + } + vc0[i] = _mm512_dpbf16_ps(vc0[i], va, vb0[col]); + vc1[i] = _mm512_dpbf16_ps(vc1[i], va, vb1[col]); + }; + for (int64_t k = 0; k < K2; ++k) { + Unroll{}(compute, k); + } + + using Vec = at::vec::Vectorized; + const Vec one = Vec(1.f); + auto storec = [&](auto i) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + // for COLS = 2, 4 
use 512bit store + if constexpr (col % 2 == 0) { + Vec x0 = vc0[row * COLS + col + 0]; + Vec x1 = vc0[row * COLS + col + 1]; + Vec y0 = vc1[row * COLS + col + 0]; + Vec y1 = vc1[row * COLS + col + 1]; + // silu + x0 = x0 / (one + x0.neg().exp_u20()); + x1 = x1 / (one + x1.neg().exp_u20()); + // mul + x0 = x0 * y0; + x1 = x1 * y1; + + _mm512_storeu_si512( + reinterpret_cast<__m512i*>((C + row * ldc + col * 16)), + (__m512i)(_mm512_cvtne2ps_pbh(__m512(x1), __m512(x0)))); + } + }; + Unroll{}(storec); + } +}; +#endif + +#define LAUNCH_TINYGEMM_KERNEL_NN(MB_SIZE, NB_SIZE) \ + tinygemm_kernel_nn2::apply( \ + A + mb_start * lda, B0 + nb_start * 2, B1 + nb_start * 2, \ + C + mb_start * ldc + nb_start, K, lda, ldb, ldc); + +template +void tinygemm_kernel( + const scalar_t* __restrict__ A, + const scalar_t* __restrict__ B0, + const scalar_t* __restrict__ B1, + scalar_t* __restrict__ C, + int64_t M, + int64_t N, + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc) { + + // pattern: 1-(2+2)-(8+8) + constexpr int64_t BLOCK_M = 4; + constexpr int64_t BLOCK_N = 32; + const int64_t MB = div_up(M, BLOCK_M); + const int64_t NB = div_up(N, BLOCK_N); + for (int mb = 0; mb < MB; ++mb) { + int64_t mb_start = mb * BLOCK_M; + int64_t mb_size = std::min(BLOCK_M, M - mb_start); + for (int64_t nb = 0; nb < NB; ++nb) { + int64_t nb_start = nb * BLOCK_N; + int64_t nb_size = std::min(BLOCK_N, N - nb_start); + + switch(mb_size << 4 | nb_size >> 4) { + // mb_size = 1 + case 0x12: LAUNCH_TINYGEMM_KERNEL_NN(1, 32); break; + // mb_size = 2 + case 0x22: LAUNCH_TINYGEMM_KERNEL_NN(2, 32); break; + // mb_size = 3 + case 0x32: LAUNCH_TINYGEMM_KERNEL_NN(3, 32); break; + // mb_size = 4 + case 0x42: LAUNCH_TINYGEMM_KERNEL_NN(4, 32); break; + default: TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", "nb_size"); + } + } + } +} + +template +struct tinygemm_kernel_nn { + static inline void apply( + const scalar_t* __restrict__ A, const scalar_t* __restrict__ B, float* __restrict__ C, + int64_t K, int64_t lda, int64_t ldb, int64_t ldc) { + TORCH_CHECK(false, "tinygemm_kernel_nn: scalar path not implemented!"); + } +}; + +#if defined(CPU_CAPABILITY_AVX512) +template +struct tinygemm_kernel_nn { + static inline void apply( + const at::BFloat16* __restrict__ A, const at::BFloat16* __restrict__ B, float* __restrict__ C, + int64_t K, int64_t lda, int64_t ldb, int64_t ldc) { + + constexpr int ROWS = BLOCK_M; + constexpr int COLS = BLOCK_N / 16; + + static_assert(COLS % 2 == 0); + + // prefetch distance + constexpr int PREFETCH_SIZE_K = 0; + + __m512bh va; + __m512bh vb[COLS]; + __m512 vc[ROWS * COLS]; + + auto loadc = [&](auto i) { + vc[i] = _mm512_set1_ps(0.f); + }; + Unroll{}(loadc); + + const int64_t K2 = K >> 1; + const int64_t lda2 = lda >> 1; + const int64_t ldb2 = ldb; // ldb * 2 >> 1; + const float* a_ptr = reinterpret_cast(A); + const float* b_ptr = reinterpret_cast(B); + + auto compute = [&](auto i, int64_t k) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + + if constexpr (col == 0) { + va = (__m512bh)(_mm512_set1_ps(a_ptr[row * lda2 + k])); + } + if constexpr (row == 0) { + vb[col] = (__m512bh)(_mm512_loadu_si512(b_ptr + k * ldb2 + col * 16)); + if constexpr (PREFETCH_SIZE_K > 0) { + _mm_prefetch(b_ptr + (k + PREFETCH_SIZE_K) * ldb2 + col * 16, _MM_HINT_T0); + } + } + vc[i] = _mm512_dpbf16_ps(vc[i], va, vb[col]); + }; + for (int64_t k = 0; k < K2; ++k) { + Unroll{}(compute, k); + } + + auto storec = [&](auto i) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + 
_mm512_storeu_ps(reinterpret_cast<__m512*>(C + row * ldc + col * 16), vc[i]); + + }; + Unroll{}(storec); + } +}; +#endif + +#define LAUNCH_TINYGEMM_KERNEL_NN2(MB_SIZE, NB_SIZE) \ + tinygemm_kernel_nn::apply( \ + A + mb_start * lda, B + nb_start * 2, C + mb_start * ldc + nb_start, \ + K, lda, ldb, ldc); + +template +void tinygemm_kernel( + const scalar_t* __restrict__ A, + const scalar_t* __restrict__ B, + float* __restrict__ C, + int64_t M, + int64_t N, + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc) { + + // pattern: 1-2-8 + constexpr int64_t BLOCK_M = 4; + constexpr int64_t BLOCK_N = 32; + const int64_t MB = div_up(M, BLOCK_M); + const int64_t NB = div_up(N, BLOCK_N); + for (int mb = 0; mb < MB; ++mb) { + int64_t mb_start = mb * BLOCK_M; + int64_t mb_size = std::min(BLOCK_M, M - mb_start); + for (int64_t nb = 0; nb < NB; ++nb) { + int64_t nb_start = nb * BLOCK_N; + int64_t nb_size = std::min(BLOCK_N, N - nb_start); + + switch(mb_size << 4 | nb_size >> 4) { + // mb_size = 1 + case 0x12: LAUNCH_TINYGEMM_KERNEL_NN2(1, 32); break; + // mb_size = 2 + case 0x22: LAUNCH_TINYGEMM_KERNEL_NN2(2, 32); break; + // mb_size = 3 + case 0x32: LAUNCH_TINYGEMM_KERNEL_NN2(3, 32); break; + // mb_size = 4 + case 0x42: LAUNCH_TINYGEMM_KERNEL_NN2(4, 32); break; + default: TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", "nb_size"); + } + } + } +} + +template +void fused_experts_kernel_impl( + scalar_t* __restrict__ output, + scalar_t* __restrict__ ic1, + scalar_t* __restrict__ ic2, + scalar_t* __restrict__ A_tmp, + float* __restrict__ C_tmp, + const scalar_t* __restrict__ input, + const scalar_t* __restrict__ packed_w1, + const scalar_t* __restrict__ packed_w2, + const float* __restrict__ topk_weights, + const int32_t* __restrict__ sorted_ids, + const int32_t* __restrict__ expert_ids, + const int32_t* __restrict__ offsets, + int64_t M, + int64_t N, + int64_t K, + int64_t E, + int64_t topk, + int64_t num_tokens_post_pad) { + + // handle 2 tiles per block + constexpr int64_t BLOCK_M = block_size_m(); + constexpr int64_t BLOCK_N = block_size_n(); + + // stage 1: intermediate_cache1 = silu(hidden_states @ w1) + const int64_t MB = div_up(num_tokens_post_pad, BLOCK_M); + const int64_t NB = div_up(N, BLOCK_N); + + // strides for w1: [E, 2N, K] + TORCH_CHECK(N % BLOCK_N == 0, "Fixme when N is not multiples of ", BLOCK_N); + + const int64_t stride_e = 2 * N * K; + const int64_t stride_n = K; + + // here we only parallel on half of 2N to fuse silu_and_mul with gemm + at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) { + // get local pointers + int tid = at::get_thread_num(); + scalar_t* __restrict__ A = A_tmp + tid * BLOCK_M * K; + float* __restrict__ C0 = C_tmp + tid * 2 * BLOCK_M * BLOCK_N; + float* __restrict__ C1 = C0 + BLOCK_M * BLOCK_N; + + bool is_brgemm_used = false; + + for (int64_t i = begin; i < end; ++i) { + int64_t mb = i / NB; + int64_t nb = i % NB; + + // nb0 from top half and nb1 from bottom half + int64_t nb0 = nb, nb1 = nb + NB; + int64_t n_size = std::min(N - nb0 * BLOCK_N, BLOCK_N); + + // B shape [K, n_size] in vnni format + int32_t expert_id = expert_ids[mb]; + const scalar_t* __restrict__ B0 = packed_w1 + expert_id * stride_e + nb0 * BLOCK_N * stride_n; + const scalar_t* __restrict__ B1 = packed_w1 + expert_id * stride_e + nb1 * BLOCK_N * stride_n; + + // 1.a load A + const int32_t* A_ids = sorted_ids + mb * BLOCK_M; + int64_t m_size = offsets[mb + 1] - offsets[mb]; + + const bool use_brgemm = can_use_brgemm(m_size); + is_brgemm_used = is_brgemm_used || 
use_brgemm; + + for (int64_t m = 0; m < m_size; ++m) { + int32_t index = A_ids[m] / topk; + copy_stub(A + m * K, input + index * K, K); + } + + if (use_brgemm) { + // 1.b gemm: C0 = A @ B0 + at::native::cpublas::brgemm( + /* M */ m_size, + /* N */ n_size, + /* K */ K, + /* lda */ K, + /* ldb */ n_size, + /* ldc */ BLOCK_N, + /* add_C */ false, + /* A */ A, + /* B */ B0, + /* C */ C0); + + // 1.c gemm: C1 = A @ B1 + at::native::cpublas::brgemm( + /* M */ m_size, + /* N */ n_size, + /* K */ K, + /* lda */ K, + /* ldb */ n_size, + /* ldc */ BLOCK_N, + /* add_C */ false, + /* A */ A, + /* B */ B1, + /* C */ C1); + + // 1.d silu and mul + const int64_t offset = offsets[mb]; + silu_and_mul( + ic1 + offset * N + nb * BLOCK_N, + C0, + C1, + m_size, + N); + } else { + // fused 1.bcd: silu_and_mul(A @ B0, A @ B1) + const int64_t offset = offsets[mb]; + tinygemm_kernel( + /* A */ A, + /* B0 */ B0, + /* B1 */ B1, + /* C */ ic1 + offset * N + nb * BLOCK_N, + /* M */ m_size, + /* N */ n_size, + /* K */ K, + /* lda */ K, + /* ldb */ n_size, + /* ldc */ N); + } + } + + if (is_brgemm_used) { + at::native::cpublas::brgemm_release(); + } + }); + + // stage 2: intermediate_cache2 = intermediate_cache1 @ w2 + // w2 : [E, K, N] as [E, OC, IC] + const int64_t OC = K; // rename K as OC + const int64_t IC = N; // rename N as IC + const int64_t MB2 = MB; + const int64_t NB2 = div_up(OC, BLOCK_N); + const int64_t stride_e2 = OC * IC; + const int64_t stride_oc = IC; + + // parallel on [MB2, NB2] + at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) { + // get local pointers + int tid = at::get_thread_num(); + // we won't be using C1 for gemm2 + float* __restrict__ C = C_tmp + tid * 2 * BLOCK_M * BLOCK_N; + + bool is_brgemm_used = false; + + for (int64_t i = begin; i < end; ++i) { + int64_t mb = i / NB2; + int64_t nb = i % NB2; + + int64_t m_size = offsets[mb + 1] - offsets[mb]; + int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N); + + const bool use_brgemm = can_use_brgemm(m_size); + is_brgemm_used = is_brgemm_used || use_brgemm; + + // A ptr from ic1 of [M * topk, N] in sorted order + // so as to avoid copy A to tmp buffer again + const scalar_t* __restrict__ A = ic1 + offsets[mb] * N; + const int32_t* A_ids = sorted_ids + mb * BLOCK_M; + + // B shape [IC, n_size] in vnni format + int32_t expert_id = expert_ids[mb]; + const scalar_t* __restrict__ B = packed_w2 + expert_id * stride_e2 + nb * BLOCK_N * stride_oc; + + // 2.a gemm: C = A @ B + if (use_brgemm) { + at::native::cpublas::brgemm( + /* M */ m_size, + /* N */ n_size, + /* K */ IC, + /* lda */ IC, + /* ldb */ n_size, + /* ldc */ BLOCK_N, + /* add_C */ false, + /* A */ A, + /* B */ B, + /* C */ C); + } else { + tinygemm_kernel( + /* A */ A, + /* B */ B, + /* C */ C, + /* M */ m_size, + /* N */ n_size, + /* K */ IC, + /* lda */ IC, + /* ldb */ n_size, + /* ldc */ BLOCK_N); + } + + // 2.b copy from C to ic2 in original order + // and also mul topk_weights in float32 + for (int64_t m = 0; m < m_size; ++m) { + int32_t index = A_ids[m]; + float weight = topk_weights[index]; + copy_mul_stub(ic2 + index * K + nb * BLOCK_N, C + m * BLOCK_N, weight, n_size); + } + } + + if (is_brgemm_used) { + at::native::cpublas::brgemm_release(); + } + }); + + // stage 3: out = intermediate_cache2.sum(dim=1) + // from [M, topk, K] to [M, K] + at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) { + for (int64_t m = begin; m < end; ++m) { + sum_stub(output + m * K, ic2 + m * topk * K, topk, K); + } + }); +} + +template +void shared_expert_kernel_impl( + scalar_t* 
__restrict__ output, + scalar_t* __restrict__ ic1, + float* __restrict__ C_tmp, + scalar_t* __restrict__ input, + const scalar_t* __restrict__ packed_w1, + const scalar_t* __restrict__ packed_w2, + const scalar_t* __restrict__ fused_experts_out, + float routed_scaling_factor, + int64_t M, + int64_t N, + int64_t K) { + + // handle 2 tiles per block + constexpr int64_t BLOCK_M = block_size_m(); + constexpr int64_t BLOCK_N = block_size_n(); + + // stage 1: intermediate_cache1 = silu(hidden_states @ w1) + const int64_t MB = div_up(M, BLOCK_M); + const int64_t NB = div_up(N, BLOCK_N); + + TORCH_CHECK(N % BLOCK_N == 0, "Fixme when N is not multiples of ", BLOCK_N); + const int64_t stride_n = K; + + // here we only parallel on half of 2N to fuse silu_and_mul with gemm + at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) { + // get local pointers + int tid = at::get_thread_num(); + float* __restrict__ C0 = C_tmp + tid * 2 * BLOCK_M * BLOCK_N; + float* __restrict__ C1 = C0 + BLOCK_M * BLOCK_N; + + bool is_brgemm_used = false; + + for (int64_t i = begin; i < end; ++i) { + int64_t mb = i / NB; + int64_t nb = i % NB; + + // nb0 from top half and nb1 from bottom half + int64_t nb0 = nb, nb1 = nb + NB; + int64_t n_size = std::min(N - nb0 * BLOCK_N, BLOCK_N); + int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M); + + //int64_t mb_start = mb * BLOCK_M; + //int64_t mb_size = std::min(M - mb_start, BLOCK_M); + + // A shape [m_size, K] + const scalar_t* A = input + mb * BLOCK_M * K; + + // B shape [K, n_size] in vnni format + const scalar_t* __restrict__ B0 = packed_w1 + nb0 * BLOCK_N * stride_n; + const scalar_t* __restrict__ B1 = packed_w1 + nb1 * BLOCK_N * stride_n; + + const bool use_brgemm = can_use_brgemm(m_size); + is_brgemm_used = is_brgemm_used || use_brgemm; + + if (use_brgemm) { + // 1.b gemm: C0 = A @ B0 + at::native::cpublas::brgemm( + /* M */ m_size, + /* N */ n_size, + /* K */ K, + /* lda */ K, + /* ldb */ n_size, + /* ldc */ BLOCK_N, + /* add_C */ false, + /* A */ A, + /* B */ B0, + /* C */ C0); + + // 1.c gemm: C1 = A @ B1 + at::native::cpublas::brgemm( + /* M */ m_size, + /* N */ n_size, + /* K */ K, + /* lda */ K, + /* ldb */ n_size, + /* ldc */ BLOCK_N, + /* add_C */ false, + /* A */ A, + /* B */ B1, + /* C */ C1); + + // 1.d silu and mul + silu_and_mul( + ic1 + mb * BLOCK_M * N + nb * BLOCK_N, + C0, + C1, + m_size, + N); + } else { + // fused 1.bcd: silu_and_mul(A @ B0, A @ B1) + tinygemm_kernel( + /* A */ A, + /* B0 */ B0, + /* B1 */ B1, + /* C */ ic1 + mb * BLOCK_M * N + nb * BLOCK_N, + /* M */ m_size, + /* N */ n_size, + /* K */ K, + /* lda */ K, + /* ldb */ n_size, + /* ldc */ N); + } + } + + if (is_brgemm_used) { + at::native::cpublas::brgemm_release(); + } + }); + + // stage 2: output = intermediate_cache1 @ w2 + // w2 : [K, N] as [OC, IC] + const int64_t OC = K; // rename K as OC + const int64_t IC = N; // rename N as IC + const int64_t MB2 = MB; + const int64_t NB2 = div_up(OC, BLOCK_N); + const int64_t stride_oc = IC; + + // parallel on [MB2, NB2] + at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) { + // get local pointers + int tid = at::get_thread_num(); + // we won't be using C1 for gemm2 + float* __restrict__ C = C_tmp + tid * 2 * BLOCK_M * BLOCK_N; + + bool is_brgemm_used = false; + + for (int64_t i = begin; i < end; ++i) { + int64_t mb = i / NB2; + int64_t nb = i % NB2; + + int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M); + int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N); + + const bool use_brgemm = can_use_brgemm(m_size); + 
is_brgemm_used = is_brgemm_used || use_brgemm; + + // A shape [m_size, IC] + const scalar_t* __restrict__ A = ic1 + mb * BLOCK_M * N; + + // B shape [IC, n_size] in vnni format + const scalar_t* __restrict__ B = packed_w2 + nb * BLOCK_N * stride_oc; + + // 2.a gemm: C = A @ B + if (use_brgemm) { + at::native::cpublas::brgemm( + /* M */ m_size, + /* N */ n_size, + /* K */ IC, + /* lda */ IC, + /* ldb */ n_size, + /* ldc */ BLOCK_N, + /* add_C */ false, + /* A */ A, + /* B */ B, + /* C */ C); + } else { + tinygemm_kernel( + /* A */ A, + /* B */ B, + /* C */ C, + /* M */ m_size, + /* N */ n_size, + /* K */ IC, + /* lda */ IC, + /* ldb */ n_size, + /* ldc */ BLOCK_N); + } + + // 2.b copy from C to output and add fused_experts_out + scalar_t* __restrict__ out = output + mb * BLOCK_M * K + nb * BLOCK_N; + const scalar_t* __restrict__ fused_out = fused_experts_out + mb * BLOCK_M * K + nb * BLOCK_N; + for (int64_t m = 0; m < m_size; ++m) { + add_mul_stub(out + m * K, C + m * BLOCK_N, fused_out + m * K, routed_scaling_factor, n_size); + } + } + + if (is_brgemm_used) { + at::native::cpublas::brgemm_release(); + } + }); +} + +} // anonymous namespace + +// common checks +static inline void check_moe_scales( + bool use_int8_w8a8, + bool use_fp8_w8a16, + const std::optional& w1_scale, + const std::optional& w2_scale, + const std::optional> block_size, + const std::optional& a1_scale, + const std::optional& a2_scale) { + if (use_int8_w8a8) { + TORCH_CHECK(w1_scale.has_value(), "missing w1_scale for int8 w8a8."); + TORCH_CHECK(w2_scale.has_value(), "missing w2_scale for int8 w8a8."); + TORCH_CHECK(!a1_scale.has_value(), "static quantization for activation not supported."); + TORCH_CHECK(!a2_scale.has_value(), "static quantization for activation not supported."); + } + if (use_fp8_w8a16) { + TORCH_CHECK(w1_scale.has_value(), "missing w1_scale for fp8 w8a16."); + TORCH_CHECK(w2_scale.has_value(), "missing w2_scale for fp8 w8a16."); + TORCH_CHECK(block_size.has_value(), "missing block_size for fp8 w8a16."); + TORCH_CHECK(block_size.value().size() == 2, "expect block_size.size() to be 2."); + } +} + +#define CHECK_MOE_SCALES_FP8(DIM0, DIM1) \ + auto w1s = w1_scale.value(); \ + auto w2s = w2_scale.value(); \ + auto block_size_val = block_size.value(); \ + int64_t block_size_N = block_size_val[0]; \ + int64_t block_size_K = block_size_val[1]; \ + TORCH_CHECK(w1s.size(DIM0) == 2 * N / block_size_N); \ + TORCH_CHECK(w1s.size(DIM1) == K / block_size_K); \ + TORCH_CHECK(w2s.size(DIM0) == K / block_size_N); \ + TORCH_CHECK(w2s.size(DIM1) == N / block_size_K) + +// hidden_states: [M, K] +// w1: [E, 2N, K] +// w2: [E, K, N] +// topk_weights: [M, topk] +// topk_ids: [M, topk] (int32_t) +// +at::Tensor fused_experts_cpu( + at::Tensor& hidden_states, + at::Tensor& w1, + at::Tensor& w2, + at::Tensor& topk_weights, + at::Tensor& topk_ids, + bool inplace, + bool use_int8_w8a8, + bool use_fp8_w8a16, + const std::optional& w1_scale, + const std::optional& w2_scale, + const std::optional> block_size, + const std::optional& a1_scale, + const std::optional& a2_scale, + bool is_vnni) { + RECORD_FUNCTION("sgl-kernel::fused_experts_cpu", std::vector({hidden_states, w1, w2, topk_weights, topk_ids})); + + auto packed_w1 = is_vnni ? w1 : convert_weight_packed(w1); + auto packed_w2 = is_vnni ? 
w2 : convert_weight_packed(w2); + + constexpr int64_t BLOCK_M = block_size_m(); + constexpr int64_t BLOCK_N = block_size_n(); + + const auto st = hidden_states.scalar_type(); + CHECK_INPUT(hidden_states); + CHECK_INPUT(w1); + CHECK_INPUT(w2); + CHECK_EQ(topk_weights.sizes(), topk_ids.sizes()); + CHECK_DIM(2, hidden_states); + CHECK_DIM(3, w1); + CHECK_DIM(3, w2); + CHECK_DIM(2, topk_weights); + CHECK_DIM(2, topk_ids); + + CHECK_EQ(topk_ids.scalar_type(), at::kInt); + CHECK_EQ(topk_weights.scalar_type(), at::kFloat); + + int64_t M = hidden_states.size(0); + int64_t K = hidden_states.size(1); + int64_t N = w1.size(1) / 2; + int64_t E = w1.size(0); + int64_t topk = topk_weights.size(1); + + // we use int32_t compensation for int8 w8a8 + int64_t packed_K = get_row_size(K, use_int8_w8a8); + int64_t packed_N = get_row_size(N, use_int8_w8a8); + + // check weight shapes + CHECK_EQ(w2.size(0), E); + CHECK_EQ(w2.size(1), K); + CHECK_EQ(packed_w1.size(2), packed_K); + CHECK_EQ(packed_w2.size(2), packed_N); + + // check scales + check_moe_scales(use_int8_w8a8, use_fp8_w8a16, w1_scale, w2_scale, block_size, a1_scale, a2_scale); + + at::Tensor out_hidden_states = inplace ? hidden_states : at::empty_like(hidden_states); + + // NB: worst case is each expert holds a block with remainder of 1 + // 1. sorted_ids : [M * topk + E * (BLOCK_M - 1)] + // 2. expert_ids : [max_num_blocks] + // 3. total_cnts : [T + 1, E] + // 4. cumsums : [E + 1] + // 5. offsets : [max_num_blocks + 1] + // + int num_threads = at::get_num_threads(); + int64_t max_num_tokens_padded = M * topk + E * (BLOCK_M - 1); + int64_t max_num_blocks = div_up(max_num_tokens_padded, BLOCK_M); + auto buffer = at::empty( + {max_num_tokens_padded + max_num_blocks + (num_threads + 1) * E + (E + 1) + (max_num_blocks + 1)}, + topk_ids.options()); + + int32_t* __restrict__ sorted_ids = buffer.data_ptr(); + int32_t* __restrict__ expert_ids = sorted_ids + max_num_tokens_padded; + int32_t* __restrict__ total_cnts = expert_ids + max_num_blocks; + int32_t* __restrict__ cumsums = total_cnts + (num_threads + 1) * E; + int32_t* __restrict__ offsets = cumsums + (E + 1); + + // init sorted_ids with `numel` as the padding number + // init expert_ids with `num_experts` + int64_t numel = M * topk; + at::parallel_for(0, max_num_blocks, GRAIN_SIZE / BLOCK_M, [&](int64_t begin, int64_t end) { + int64_t m_start = begin * BLOCK_M; + int64_t m_size = std::min((end - begin) * BLOCK_M, max_num_tokens_padded - m_start); + fill_stub(sorted_ids + m_start, (int32_t)numel, m_size); + fill_stub(expert_ids + begin, (int32_t)E, end - begin); + }); + // zero total_cnts and cumsums + at::parallel_for(0, (num_threads + 1) * E + (E + 1), GRAIN_SIZE, [&](int64_t begin, int64_t end) { + fill_stub(total_cnts + begin, 0, end - begin); + }); + + // align experts index + int64_t num_tokens_post_pad = moe_align_block_size( + sorted_ids, expert_ids, topk_ids.data_ptr(), total_cnts, cumsums, offsets, E, numel, num_threads); + + // unlike triton kernel, we fuse silu with gemm1 so only need 2 intermediate_caches: + // 1. intermediate_cache1 : [M * topk, N] + // 2. intermediate_cache2 : [M * topk, K] + // 3. A_tmp : [T, BLOCK_M * K] + // 4. C_tmp : [T, 2 * BLOCK_M * BLOCK_N] + // + // for int8 w8a8: + // 5. Aq_tmp : [M, K] or [M * topk, N] + // 6. As_tmp : [M * topk] + // + // for fp8 w8a16: + // 7. intermediate_cache0 : [M * topk, 2N] + // 8. B_tmp : [T, BLOCK_N, std::max(K, N)] + // + int64_t buffer_size_nbytes = M * topk * N * 2 + M * topk * K * 2 + + num_threads * BLOCK_M * K * (use_int8_w8a8 ? 
1 : 2) + + num_threads * 2 * BLOCK_M * BLOCK_N * sizeof(float); + + if (use_int8_w8a8) { + buffer_size_nbytes += std::max(M * K, M * topk * N) + M * topk * sizeof(float); + } + if (use_fp8_w8a16) { + buffer_size_nbytes += M * topk * 2 * N * 2 + num_threads * BLOCK_N * std::max(K, N) * 2; + } + + auto buffer2 = at::empty({buffer_size_nbytes}, hidden_states.options().dtype(at::kChar)); + + AT_DISPATCH_REDUCED_FLOATING_TYPES(st, "fused_experts_kernel_impl", [&] { + scalar_t* __restrict__ intermediate_cache1 = (scalar_t*)((void*)(buffer2.data_ptr())); + scalar_t* __restrict__ intermediate_cache2 = intermediate_cache1 + M * topk * N; + + if (use_int8_w8a8) { + uint8_t* __restrict__ A_tmp = (uint8_t*)((void*)(intermediate_cache2 + M * topk * K)); + float* __restrict__ C_tmp = (float*)((void*)(A_tmp + num_threads * BLOCK_M * K)); + uint8_t* __restrict__ Aq_tmp = (uint8_t*)((void*)(C_tmp + num_threads * 2 * BLOCK_M * BLOCK_N)); + float* __restrict__ As_tmp = (float*)((void*)(Aq_tmp + std::max(M * K, M * topk * N))); + + auto w1s = w1_scale.value(); + auto w2s = w2_scale.value(); + TORCH_CHECK(w1s.numel() == E * 2 * N); + TORCH_CHECK(w2s.numel() == E * K); + + fused_experts_int8_kernel_impl( + out_hidden_states.data_ptr(), + intermediate_cache1, + intermediate_cache2, + A_tmp, + C_tmp, + Aq_tmp, + As_tmp, + hidden_states.data_ptr(), + packed_w1.data_ptr(), + packed_w2.data_ptr(), + w1s.data_ptr(), + w2s.data_ptr(), + topk_weights.data_ptr(), + sorted_ids, + expert_ids, + offsets, + M, + N, + K, + E, + topk, + num_tokens_post_pad); + } else if (use_fp8_w8a16) { + // here we just ignore C_tmp as it is not used + scalar_t* __restrict__ A_tmp = (scalar_t*)((void*)(intermediate_cache2 + M * topk * K)); + float* __restrict__ C_tmp = (float*)((void*)(A_tmp + num_threads * BLOCK_M * K)); + scalar_t* __restrict__ intermediate_cache0 = (scalar_t*)((void*)(C_tmp + num_threads * 2 * BLOCK_M * BLOCK_N)); + scalar_t* __restrict__ B_tmp = (scalar_t*)((void*)(intermediate_cache0 + M * topk * 2 * N)); + + CHECK_MOE_SCALES_FP8(1, 2); + fused_experts_fp8_kernel_impl( + out_hidden_states.data_ptr(), + intermediate_cache0, + intermediate_cache1, + intermediate_cache2, + A_tmp, + B_tmp, + C_tmp, + hidden_states.data_ptr(), + packed_w1.data_ptr(), + packed_w2.data_ptr(), + w1s.data_ptr(), + w2s.data_ptr(), + block_size_N, + block_size_K, + topk_weights.data_ptr(), + sorted_ids, + expert_ids, + offsets, + M, + N, + K, + E, + topk, + num_tokens_post_pad); + } else { + scalar_t* __restrict__ A_tmp = intermediate_cache2 + M * topk * K; + float* __restrict__ C_tmp = (float*)((void*)(A_tmp + num_threads * BLOCK_M * K)); + + fused_experts_kernel_impl( + out_hidden_states.data_ptr(), + intermediate_cache1, + intermediate_cache2, + A_tmp, + C_tmp, + hidden_states.data_ptr(), + packed_w1.data_ptr(), + packed_w2.data_ptr(), + topk_weights.data_ptr(), + sorted_ids, + expert_ids, + offsets, + M, + N, + K, + E, + topk, + num_tokens_post_pad); + } + }); + return out_hidden_states; +} + +// shared expert kernel +// +// hidden_states: [M, K] +// w1: [2N, K] +// w2: [K, N] +// fused_experts_out +at::Tensor shared_expert_cpu( + at::Tensor& hidden_states, + at::Tensor& w1, + at::Tensor& w2, + at::Tensor& fused_experts_out, + double routed_scaling_factor, + bool inplace, + bool use_int8_w8a8, + bool use_fp8_w8a16, + std::optional& w1_scale, + std::optional& w2_scale, + std::optional> block_size, + std::optional& a1_scale, + std::optional& a2_scale, + bool is_vnni) { + RECORD_FUNCTION("sgl-kernel::shared_expert_cpu", 
std::vector({hidden_states, w1, w2})); + + auto packed_w1 = is_vnni ? w1 : convert_weight_packed(w1); + auto packed_w2 = is_vnni ? w2 : convert_weight_packed(w2); + + constexpr int64_t BLOCK_M = block_size_m(); + constexpr int64_t BLOCK_N = block_size_n(); + + const auto st = hidden_states.scalar_type(); + CHECK_INPUT(hidden_states); + CHECK_INPUT(fused_experts_out); + CHECK_INPUT(w1); + CHECK_INPUT(w2); + CHECK_DIM(2, hidden_states); + CHECK_DIM(2, w1); + CHECK_DIM(2, w2); + CHECK_EQ(hidden_states.sizes(), fused_experts_out.sizes()); + CHECK_EQ(hidden_states.scalar_type(), st); + + int64_t M = hidden_states.size(0); + int64_t K = hidden_states.size(1); + int64_t N = w1.size(0) / 2; + + // we use int32_t compensation for int8 w8a8 + int64_t packed_K = get_row_size(K, use_int8_w8a8); + int64_t packed_N = get_row_size(N, use_int8_w8a8); + + // check weight shapes + CHECK_EQ(w2.size(0), K); + CHECK_EQ(packed_w1.size(1), packed_K); + CHECK_EQ(packed_w2.size(1), packed_N); + + // check scales + check_moe_scales(use_int8_w8a8, use_fp8_w8a16, w1_scale, w2_scale, block_size, a1_scale, a2_scale); + + at::Tensor out_hidden_states = inplace ? hidden_states : at::empty_like(hidden_states); + + // unlike triton kernel, we fuse silu with gemm1 so only need 2 intermediate_caches: + // 1. intermediate_cache1 : [M, N] + // 2. C_tmp : [T, 2 * BLOCK_M * BLOCK_N] + // + // for int8 w8a8: + // 3. Aq_tmp : [M, K] or [M, N] + // 4. As_tmp : [M] + // + // for fp8 w8a16: + // 5. intermediate_cache0 : [M, 2N] + // 6. B_tmp: [T, BLOCK_M, max(K, N)] + // + int num_threads = at::get_num_threads(); + int64_t buffer_size_nbytes = M * N * 2 + num_threads * 2 * BLOCK_M * BLOCK_N * sizeof(float); + + if (use_int8_w8a8) { + buffer_size_nbytes += std::max(M * K, M * N) + M * sizeof(float); + } + if (use_fp8_w8a16) { + buffer_size_nbytes += M * 2 * N * 2 + num_threads * BLOCK_M * std::max(K, N) * 2; + } + + auto buffer = at::empty({buffer_size_nbytes}, hidden_states.options().dtype(at::kChar)); + AT_DISPATCH_REDUCED_FLOATING_TYPES(st, "share_experts_kernel_impl", [&] { + scalar_t* __restrict__ intermediate_cache1 = (scalar_t*)((void*)(buffer.data_ptr())); + float* __restrict__ C_tmp = (float*)((void*)(intermediate_cache1 + M * N)); + + if (use_int8_w8a8) { + uint8_t* __restrict__ Aq_tmp = (uint8_t*)((void*)(C_tmp + num_threads * 2 * BLOCK_M * BLOCK_N)); + float* __restrict__ As_tmp = (float*)((void*)(Aq_tmp + std::max(M * K, M * N))); + + auto w1s = w1_scale.value(); + auto w2s = w2_scale.value(); + TORCH_CHECK(w1s.numel() == 2 * N); + TORCH_CHECK(w2s.numel() == K); + + shared_expert_int8_kernel_impl( + out_hidden_states.data_ptr(), + intermediate_cache1, + C_tmp, + Aq_tmp, + As_tmp, + hidden_states.data_ptr(), + packed_w1.data_ptr(), + packed_w2.data_ptr(), + w1s.data_ptr(), + w2s.data_ptr(), + fused_experts_out.data_ptr(), + routed_scaling_factor, + M, + N, + K); + } else if (use_fp8_w8a16) { + scalar_t* __restrict__ intermediate_cache0 = (scalar_t*)((void*)(C_tmp + num_threads * 2 * BLOCK_M * BLOCK_N)); + scalar_t* __restrict__ B_tmp = (scalar_t*)((void*)(intermediate_cache0 + M * 2 * N)); + + CHECK_MOE_SCALES_FP8(0, 1); + shared_expert_fp8_kernel_impl( + out_hidden_states.data_ptr(), + intermediate_cache0, + intermediate_cache1, + B_tmp, + C_tmp, + hidden_states.data_ptr(), + packed_w1.data_ptr(), + packed_w2.data_ptr(), + w1s.data_ptr(), + w2s.data_ptr(), + block_size_N, + block_size_K, + fused_experts_out.data_ptr(), + routed_scaling_factor, + M, + N, + K); + } else { + shared_expert_kernel_impl( + 
out_hidden_states.data_ptr(), + intermediate_cache1, + C_tmp, + hidden_states.data_ptr(), + packed_w1.data_ptr(), + packed_w2.data_ptr(), + fused_experts_out.data_ptr(), + routed_scaling_factor, + M, + N, + K); + } + }); + return out_hidden_states; +} diff --git a/csrc/cpu/sgl-kernels/moe_fp8.cpp b/csrc/cpu/sgl-kernels/moe_fp8.cpp new file mode 100644 index 000000000000..84a6af267740 --- /dev/null +++ b/csrc/cpu/sgl-kernels/moe_fp8.cpp @@ -0,0 +1,502 @@ +// Adapted from +// https://github.com/sgl-project/sglang/tree/main/sgl-kernel/csrc/cpu + +#include "common.h" +#include "gemm.h" +#include "vec.h" + +// clang-format off + +namespace { + +template +inline void copy_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ input, int64_t size) { + using Vec = at::vec::Vectorized; + // no remainder + #pragma GCC unroll 4 + for (int64_t d = 0; d < size; d += Vec::size()) { + Vec data = Vec::loadu(input + d); + data.store(out + d); + } +} + +template +inline void copy_mul_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ input, float weight, int64_t size) { + using bVec = at::vec::Vectorized; + using fVec = at::vec::Vectorized; + constexpr int kVecSize = bVec::size(); + const fVec weight_vec = fVec(weight); + int64_t d; + #pragma GCC unroll 4 + for (d = 0; d <= size - kVecSize; d += kVecSize) { + bVec x = bVec::loadu(input + d); + fVec x0, x1; + std::tie(x0, x1) = at::vec::convert_to_float(x); + x0 = x0 * weight_vec; + x1 = x1 * weight_vec; + bVec out_vec = convert_from_float_ext(x0, x1); + out_vec.store(out + d); + } + for (; d < size; ++d) { + out[d] = static_cast(input[d] * weight); + } +} + +// acc from [topk, K] to [K] +template +inline void sum_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ input, int64_t topk, int64_t K) { + using bVec = at::vec::Vectorized; + using fVec = at::vec::Vectorized; + constexpr int kVecSize = bVec::size(); + if (topk == 1) { + // do copy for topk = 1 + copy_stub(out, input, K); + } else { + // do sum for topk != 1 + int64_t d; + #pragma GCC unroll 4 + for (d = 0; d <= K - kVecSize; d += kVecSize) { + fVec sum_fvec0 = fVec(0.f); + fVec sum_fvec1 = fVec(0.f); + for (int t = 0; t < topk; ++t) { + bVec x_bvec = bVec::loadu(input + t * K + d); + fVec x_fvec0, x_fvec1; + std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec); + + sum_fvec0 += x_fvec0; + sum_fvec1 += x_fvec1; + } + bVec out_bvec = convert_from_float_ext(sum_fvec0, sum_fvec1); + out_bvec.store(out + d); + } + for (; d < K; ++d) { + float sum_val = 0.f; + for (int t = 0; t < topk; ++t) { + sum_val += static_cast(input[t * K + d]); + } + out[d] = static_cast(sum_val); + } + } +} + +// out = input + input2 * scale +template +inline void add_mul_stub( + scalar_t* __restrict__ out, + const scalar_t* __restrict__ input, + const scalar_t* __restrict__ input2, + float scale, + int64_t size) { + using bVec = at::vec::Vectorized; + using fVec = at::vec::Vectorized; + constexpr int kVecSize = bVec::size(); + const fVec s_vec = fVec(scale); + + int64_t d; +#pragma GCC unroll 4 + for (d = 0; d <= size - kVecSize; d += kVecSize) { + bVec x_bvec = bVec::loadu(input + d); + fVec x0, x1; + std::tie(x0, x1) = at::vec::convert_to_float(x_bvec); + + bVec y_bvec = bVec::loadu(input2 + d); + fVec y0, y1; + std::tie(y0, y1) = at::vec::convert_to_float(y_bvec); + + x0 = x0 + y0 * s_vec; + x1 = x1 + y1 * s_vec; + bVec out_vec = convert_from_float_ext(x0, x1); + out_vec.store(out + d); + } + for (; d < size; ++d) { + out[d] = static_cast(input[d] + float(input2[d]) * scale); + } +} + 
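+// The stub below fuses SiLU with the elementwise product, computed in fp32:
+//   out[d] = silu(input[d]) * input2[d],   where silu(x) = x / (1 + exp(-x)).
+// A scalar reference of the same math (illustrative only, would need <cmath>):
+//   for (int64_t d = 0; d < size; ++d) {
+//     float x = float(input[d]), y = float(input2[d]);
+//     out[d] = static_cast<scalar_t>(x / (1.f + std::exp(-x)) * y);
+//   }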
+template +inline void silu_and_mul_stub( + scalar_t* __restrict__ out, + const scalar_t* __restrict__ input, + const scalar_t* __restrict__ input2, + int64_t size) { + using bVec = at::vec::Vectorized; + using fVec = at::vec::Vectorized; + const fVec one = fVec(1.f); + + // no remainder +#pragma GCC unroll 4 + for (int64_t d = 0; d < size; d += bVec::size()) { + bVec x = bVec::loadu(input + d); + fVec x0, x1; + std::tie(x0, x1) = at::vec::convert_to_float(x); + bVec y = bVec::loadu(input2 + d); + fVec y0, y1; + std::tie(y0, y1) = at::vec::convert_to_float(y); + x0 = x0 / (one + x0.neg().exp_u20()); + x1 = x1 / (one + x1.neg().exp_u20()); + x0 = x0 * y0; + x1 = x1 * y1; + bVec out_vec = convert_from_float_ext(x0, x1); + out_vec.store(out + d); + } +} + +} // anonymous namespace + +template +void fused_experts_fp8_kernel_impl( + scalar_t* __restrict__ output, + scalar_t* __restrict__ ic0, + scalar_t* __restrict__ ic1, + scalar_t* __restrict__ ic2, + scalar_t* __restrict__ A_tmp, + scalar_t* __restrict__ B_tmp, + float* __restrict__ C_tmp, + const scalar_t* __restrict__ input, + const at::Float8_e4m3fn* __restrict__ packed_w1, + const at::Float8_e4m3fn* __restrict__ packed_w2, + const float* __restrict__ w1s, + const float* __restrict__ w2s, + int64_t block_size_N, + int64_t block_size_K, + const float* __restrict__ topk_weights, + const int32_t* __restrict__ sorted_ids, + const int32_t* __restrict__ expert_ids, + const int32_t* __restrict__ offsets, + int64_t M, + int64_t N, + int64_t K, + int64_t E, + int64_t topk, + int64_t num_tokens_post_pad) { + + constexpr int64_t BLOCK_M = block_size_m(); + constexpr int64_t BLOCK_N = block_size_n(); + + // stage 1: intermediate_cache0 = hidden_states @ w1 + const int64_t MB = div_up(num_tokens_post_pad, BLOCK_M); + const int64_t NB = div_up(2 * N, BLOCK_N); + int64_t scale_size_N = div_up(2 * N, block_size_N); + int64_t scale_size_K = div_up(K, block_size_K); + int64_t blocks_n_per_group = block_size_N / BLOCK_N; + + const int64_t stride_e = 2 * N * K; + const int64_t stride_n = K; + + // here we only parallel on half of 2N to fuse silu_and_mul with gemm + at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) { + // get local pointers + int tid = at::get_thread_num(); + scalar_t* __restrict__ A = A_tmp + tid * BLOCK_M * K; + + bool is_brgemm_used = false; + + for (int64_t i = begin; i < end; ++i) { + int64_t mb = i / NB; + int64_t nb = i % NB; + + int64_t n_size = std::min(2 * N - nb * BLOCK_N, BLOCK_N); + + // B shape [K, n_size] in vnni format + int32_t expert_id = expert_ids[mb]; + const at::Float8_e4m3fn* __restrict__ B = packed_w1 + expert_id * stride_e + nb * BLOCK_N * stride_n; + const float* __restrict__ Bs = w1s + expert_id * scale_size_N * scale_size_K + (nb / blocks_n_per_group) * scale_size_K; + + // 1.a load A + const int32_t* A_ids = sorted_ids + mb * BLOCK_M; + int64_t m_size = offsets[mb + 1] - offsets[mb]; + + const bool use_brgemm = can_use_brgemm(m_size); + is_brgemm_used = is_brgemm_used || use_brgemm; + + for (int64_t m = 0; m < m_size; ++m) { + int32_t index = A_ids[m] / topk; + copy_stub(A + m * K, input + index * K, K); + } + + const int64_t offset = offsets[mb]; + tinygemm_kernel( + /* A */ A, + /* B */ B, + /* C */ ic0 + offset * 2 * N + nb * BLOCK_N, + /* Btmp */ B_tmp + tid * BLOCK_N * std::max(K, N), + /* Ctmp */ C_tmp + tid * 2 * BLOCK_M * BLOCK_N, + /* scale */ Bs, + /* M */ m_size, + /* N */ n_size, + /* K */ K, + /* lda */ K, + /* ldb */ n_size, + /* ldc */ 2 * N, + /* brg */ use_brgemm, + /* block_size_K 
*/ block_size_K); + } + + if (is_brgemm_used) { + at::native::cpublas::brgemm_release(); + } + }); + + // stage 1.5: intermediate_cache1 = silu(intermediate_cache0) + at::parallel_for(0, M * topk, 0, [&](int64_t begin, int64_t end) { + for (int64_t m = begin; m < end; ++m) { + silu_and_mul_stub( + ic1 + m * N, + ic0 + m * 2 * N, + ic0 + m * 2 * N + N, + N); + } + }); + + // stage 2: intermediate_cache2 = intermediate_cache1 @ w2 + // w2 : [E, K, N] as [E, OC, IC] + const int64_t OC = K; // rename K as OC + const int64_t IC = N; // rename N as IC + const int64_t MB2 = MB; + const int64_t NB2 = div_up(OC, BLOCK_N); + scale_size_N = div_up(K, block_size_N); + scale_size_K = div_up(N, block_size_K); + const int64_t stride_e2 = OC * IC; + const int64_t stride_oc = IC; + + // parallel on [MB2, NB2] + at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) { + int tid = at::get_thread_num(); + alignas(64) scalar_t C[BLOCK_M * BLOCK_K]; + + bool is_brgemm_used = false; + + for (int64_t i = begin; i < end; ++i) { + int64_t mb = i / NB2; + int64_t nb = i % NB2; + + int64_t m_size = offsets[mb + 1] - offsets[mb]; + int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N); + + const bool use_brgemm = can_use_brgemm(m_size); + is_brgemm_used = is_brgemm_used || use_brgemm; + + // A ptr from ic1 of [M * topk, N] in sorted order + // so as to avoid copy A to tmp buffer again + const scalar_t* __restrict__ A = ic1 + offsets[mb] * N; + const int32_t* A_ids = sorted_ids + mb * BLOCK_M; + + // B shape [IC, n_size] in vnni format + int32_t expert_id = expert_ids[mb]; + const at::Float8_e4m3fn* __restrict__ B = packed_w2 + expert_id * stride_e2 + nb * BLOCK_N * stride_oc; + const float* __restrict__ Bs = w2s + expert_id * scale_size_N * scale_size_K + (nb / blocks_n_per_group) * scale_size_K; + + tinygemm_kernel( + /* A */ A, + /* B */ B, + /* C */ C, + /* Btmp */ B_tmp + tid * BLOCK_N * std::max(K, N), + /* Ctmp */ C_tmp + tid * 2 * BLOCK_M * BLOCK_N, + /* scale */ Bs, + /* M */ m_size, + /* N */ n_size, + /* K */ IC, + /* lda */ IC, + /* ldb */ n_size, + /* ldc */ BLOCK_N, + /* brg */ use_brgemm, + /* block_size_K */ block_size_K); + + // 2.b copy from C to ic2 in original order + // and also mul topk_weights in float32 + for (int64_t m = 0; m < m_size; ++m) { + int32_t index = A_ids[m]; + float weight = topk_weights[index]; + copy_mul_stub(ic2 + index * K + nb * BLOCK_N, C + m * BLOCK_N, weight, n_size); + } + } + + if (is_brgemm_used) { + at::native::cpublas::brgemm_release(); + } + }); + + // stage 3: out = intermediate_cache2.sum(dim=1) + // from [M, topk, K] to [M, K] + at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) { + for (int64_t m = begin; m < end; ++m) { + sum_stub(output + m * K, ic2 + m * topk * K, topk, K); + } + }); +} + +#define INSTANTIATE_MOE_FP8_TEMPLATE(TYPE) \ + template void fused_experts_fp8_kernel_impl( \ + TYPE* __restrict__ output, \ + TYPE* __restrict__ ic0, \ + TYPE* __restrict__ ic1, \ + TYPE* __restrict__ ic2, \ + TYPE* __restrict__ A_tmp, \ + TYPE* __restrict__ B_tmp, \ + float* __restrict__ C_tmp, \ + const TYPE* __restrict__ input, \ + const at::Float8_e4m3fn* __restrict__ packed_w1, \ + const at::Float8_e4m3fn* __restrict__ packed_w2, \ + const float* __restrict__ w1s, \ + const float* __restrict__ w2s, \ + int64_t block_size_N, \ + int64_t block_size_K, \ + const float* __restrict__ topk_weights, \ + const int32_t* __restrict__ sorted_ids, \ + const int32_t* __restrict__ expert_ids, \ + const int32_t* __restrict__ offsets, \ + int64_t M, \ + int64_t N, \ + 
int64_t K, \ + int64_t E, \ + int64_t topk, \ + int64_t num_tokens_post_pad) + +INSTANTIATE_MOE_FP8_TEMPLATE(at::BFloat16); +INSTANTIATE_MOE_FP8_TEMPLATE(at::Half); + +template +void shared_expert_fp8_kernel_impl( + scalar_t* __restrict__ output, + scalar_t* __restrict__ ic0, + scalar_t* __restrict__ ic1, + scalar_t* __restrict__ B_tmp, + float* __restrict__ C_tmp, + const scalar_t* __restrict__ input, + const at::Float8_e4m3fn* __restrict__ packed_w1, + const at::Float8_e4m3fn* __restrict__ packed_w2, + const float* __restrict__ w1s, + const float* __restrict__ w2s, + int64_t block_size_N, + int64_t block_size_K, + const scalar_t* __restrict__ fused_experts_out, + float routed_scaling_factor, + int64_t M, + int64_t N, + int64_t K) { + + constexpr int64_t BLOCK_M = block_size_m(); + constexpr int64_t BLOCK_N = block_size_n(); + + // stage 1: intermediate_cache0 = hidden_states @ w1 + const int64_t MB = div_up(M, BLOCK_M); + const int64_t NB = div_up(2 * N, BLOCK_N); + int64_t scale_size_K = div_up(K, block_size_K); + int64_t blocks_n_per_group = block_size_N / BLOCK_N; + + const bool use_brgemm = can_use_brgemm(M); + + at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) { + int tid = at::get_thread_num(); + + for (int64_t i = begin; i < end; ++i) { + int64_t mb = i / NB; + int64_t nb = i % NB; + int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M); + int64_t n_size = std::min(2 * N - nb * BLOCK_N, BLOCK_N); + + tinygemm_kernel( + /* A */ input + mb * BLOCK_M * K, + /* B */ packed_w1 + nb * BLOCK_N * K, + /* C */ ic0 + mb * BLOCK_M * 2 * N + nb * BLOCK_N, + /* Btmp */ B_tmp + tid * BLOCK_N * std::max(K, N), + /* Ctmp */ C_tmp + tid * 2 * BLOCK_M * BLOCK_N, + /* scale */ w1s + (nb / blocks_n_per_group) * scale_size_K, + /* M */ m_size, + /* N */ n_size, + /* K */ K, + /* lda */ K, + /* ldb */ n_size, + /* ldc */ 2 * N, + /* brg */ use_brgemm, + /* block_size_K */ block_size_K); + } + + if (use_brgemm) { + at::native::cpublas::brgemm_release(); + } + }); + + // stage 1.5: intermediate_cache1 = silu(intermediate_cache0) + at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) { + for (int64_t m = begin; m < end; ++m) { + silu_and_mul_stub( + ic1 + m * N, + ic0 + m * 2 * N, + ic0 + m * 2 * N + N, + N); + } + }); + + // stage 2: intermediate_cache2 = intermediate_cache1 @ w2 + // w2 : [K, N] as [OC, IC] + const int64_t OC = K; // rename K as OC + const int64_t IC = N; // rename N as IC + const int64_t MB2 = MB; + const int64_t NB2 = div_up(K, BLOCK_N); + scale_size_K = div_up(N, block_size_K); + + // parallel on [MB2, NB2] + at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) { + int tid = at::get_thread_num(); + alignas(64) scalar_t C[BLOCK_M * BLOCK_K]; + + for (int64_t i = begin; i < end; ++i) { + int64_t mb = i / NB2; + int64_t nb = i % NB2; + int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M); + int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N); + + // 2.a gemm: C = A @ B + tinygemm_kernel( + /* A */ ic1 + mb * BLOCK_M * N, + /* B */ packed_w2 + nb * BLOCK_N * N, + /* C */ C, + /* Btmp */ B_tmp + tid * BLOCK_N * std::max(K, N), + /* Ctmp */ C_tmp + tid * 2 * BLOCK_M * BLOCK_N, + /* scale */ w2s + (nb / blocks_n_per_group) * scale_size_K, + /* M */ m_size, + /* N */ n_size, + /* K */ IC, + /* lda */ IC, + /* ldb */ n_size, + /* ldc */ BLOCK_N, + /* brg */ use_brgemm, + /* block_size_K */ block_size_K); + + // 2.b copy from C to output and add fused_experts_out + scalar_t* __restrict__ out = output + mb * BLOCK_M * K + nb * BLOCK_N; + const scalar_t* 
__restrict__ fused_out = fused_experts_out + mb * BLOCK_M * K + nb * BLOCK_N; + for (int64_t m = 0; m < m_size; ++m) { + add_mul_stub(out + m * K, C + m * BLOCK_N, fused_out + m * K, routed_scaling_factor, n_size); + } + } + }); + + if (use_brgemm) { + at::native::cpublas::brgemm_release(); + } +} + +#define INSTANTIATE_SHARED_EXPERT_FP8_TEMPLATE(TYPE) \ + template void shared_expert_fp8_kernel_impl( \ + TYPE* __restrict__ output, \ + TYPE* __restrict__ ic0, \ + TYPE* __restrict__ ic1, \ + TYPE* __restrict__ B_tmp, \ + float* __restrict__ C_tmp, \ + const TYPE* __restrict__ input, \ + const at::Float8_e4m3fn* __restrict__ packed_w1, \ + const at::Float8_e4m3fn* __restrict__ packed_w2, \ + const float* __restrict__ w1s, \ + const float* __restrict__ w2s, \ + int64_t block_size_N, \ + int64_t block_size_K, \ + const TYPE* __restrict__ fused_experts_out, \ + float routed_scaling_factor, \ + int64_t M, \ + int64_t N, \ + int64_t K) + +INSTANTIATE_SHARED_EXPERT_FP8_TEMPLATE(at::BFloat16); +INSTANTIATE_SHARED_EXPERT_FP8_TEMPLATE(at::Half); diff --git a/csrc/cpu/sgl-kernels/moe_int8.cpp b/csrc/cpu/sgl-kernels/moe_int8.cpp new file mode 100644 index 000000000000..89d0fb5d9f3b --- /dev/null +++ b/csrc/cpu/sgl-kernels/moe_int8.cpp @@ -0,0 +1,769 @@ +// Adapted from +// https://github.com/sgl-project/sglang/tree/main/sgl-kernel/csrc/cpu + +#include "common.h" +#include "vec.h" +#include "gemm.h" + +// clang-format off + +namespace { + +template +inline void copy_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ input, int64_t size) { + using Vec = at::vec::Vectorized; + // no remainder + #pragma GCC unroll 4 + for (int64_t d = 0; d < size; d += Vec::size()) { + Vec data = Vec::loadu(input + d); + data.store(out + d); + } +} + +template <> +inline void copy_stub(uint8_t* __restrict__ out, const uint8_t* __restrict__ input, int64_t size) { + // size might be 64x + 32 + std::memcpy(out, input, size * sizeof(uint8_t)); +} + +template +inline void copy_mul_stub(scalar_t* __restrict__ out, const float* __restrict__ input, float weight, int64_t size) { + using bVec = at::vec::Vectorized; + using fVec = at::vec::Vectorized; + constexpr int kVecSize = bVec::size(); + const fVec weight_vec = fVec(weight); + int64_t d; + #pragma GCC unroll 4 + for (d = 0; d <= size - kVecSize; d += kVecSize) { + fVec data0 = fVec::loadu(input + d) * weight_vec; + fVec data1 = fVec::loadu(input + d + fVec::size()) * weight_vec; + bVec out_vec = convert_from_float_ext(data0, data1); + out_vec.store(out + d); + } + for (; d < size; ++d) { + out[d] = static_cast(input[d] * weight); + } +} + +// acc from [topk, K] to [K] +template +inline void sum_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ input, int64_t topk, int64_t K) { + using bVec = at::vec::Vectorized; + using fVec = at::vec::Vectorized; + constexpr int kVecSize = bVec::size(); + if (topk == 1) { + // do copy for topk = 1 + copy_stub(out, input, K); + } else { + // do sum for topk != 1 + int64_t d; + #pragma GCC unroll 4 + for (d = 0; d <= K - kVecSize; d += kVecSize) { + fVec sum_fvec0 = fVec(0.f); + fVec sum_fvec1 = fVec(0.f); + for (int t = 0; t < topk; ++t) { + bVec x_bvec = bVec::loadu(input + t * K + d); + fVec x_fvec0, x_fvec1; + std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec); + + sum_fvec0 += x_fvec0; + sum_fvec1 += x_fvec1; + } + bVec out_bvec = convert_from_float_ext(sum_fvec0, sum_fvec1); + out_bvec.store(out + d); + } + for (; d < K; ++d) { + float sum_val = 0.f; + for (int t = 0; t < topk; ++t) { + sum_val += 
static_cast(input[t * K + d]); + } + out[d] = static_cast(sum_val); + } + } +} + +// out = input + input2 * scale +template +inline void add_mul_stub(scalar_t* __restrict__ out, const float* __restrict__ input, + const scalar_t* __restrict__ input2, float scale, int64_t size) { + + using bVec = at::vec::Vectorized; + using fVec = at::vec::Vectorized; + constexpr int kVecSize = bVec::size(); + const fVec s_vec = fVec(scale); + int64_t d; + #pragma GCC unroll 4 + for (d = 0; d <= size - kVecSize; d += kVecSize) { + fVec x0 = fVec::loadu(input + d); + fVec x1 = fVec::loadu(input + d + fVec::size()); + + bVec y_bvec = bVec::loadu(input2 + d); + fVec y0, y1; + std::tie(y0, y1) = at::vec::convert_to_float(y_bvec); + + x0 = x0 + y0 * s_vec; + x1 = x1 + y1 * s_vec; + bVec out_vec = convert_from_float_ext(x0, x1); + out_vec.store(out + d); + } + for (; d < size; ++d) { + out[d] = static_cast(input[d] + float(input2[d]) * scale); + } +} + +/// gemm for w13 +template +struct tinygemm_kernel_vnni { + static inline void apply( + const uint8_t* __restrict__ A, const int8_t* __restrict__ B0, const int8_t* __restrict__ B1, scalar_t* __restrict__ C, + const float* __restrict__ As, const float* __restrict__ Bs0, const float* __restrict__ Bs1, + const int32_t* __restrict__ Bcomp0, const int32_t* __restrict__ Bcomp1, + int64_t K, int64_t lda, int64_t ldb, int64_t ldc) { + TORCH_CHECK(false, "tinygemm_kernel_nn: scalar path not implemented!"); + } +}; + +#if defined(CPU_CAPABILITY_AVX512) +template +struct tinygemm_kernel_vnni { + static inline void apply( + const uint8_t* __restrict__ A, const int8_t* __restrict__ B0, const int8_t* __restrict__ B1, at::BFloat16* __restrict__ C, + const float* __restrict__ As, const float* __restrict__ Bs0, const float* __restrict__ Bs1, + const int32_t* __restrict__ Bcomp0, const int32_t* __restrict__ Bcomp1, + int64_t K, int64_t lda, int64_t ldb, int64_t ldc) { + + constexpr int ROWS = BLOCK_M; + constexpr int COLS = BLOCK_N / 16; + static_assert(COLS % 2 == 0); + + __m512i va; + __m512i vb0[COLS]; + __m512i vb1[COLS]; + __m512i vc0[ROWS * COLS]; + __m512i vc1[ROWS * COLS]; + __m512i vcomp0[COLS]; + __m512i vcomp1[COLS]; + __m512 was; + __m512 vbs0[COLS]; + __m512 vbs1[COLS]; + + auto loadc = [&](auto i) { + vc0[i] = _mm512_set1_epi32(0); + vc1[i] = _mm512_set1_epi32(0); + }; + Unroll{}(loadc); + + const int64_t K4 = K >> 2; + const int64_t lda4 = lda >> 2; + const int64_t ldb4 = ldb; // ldb * 4 >> 2; + const int32_t* a_ptr = reinterpret_cast(A); + const int32_t* b0_ptr = reinterpret_cast(B0); + const int32_t* b1_ptr = reinterpret_cast(B1); + + auto compute = [&](auto i, int64_t k) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + + if constexpr (col == 0) { + va = _mm512_set1_epi32(a_ptr[row * lda4 + k]); + } + if constexpr (row == 0) { + vb0[col] = _mm512_loadu_si512(b0_ptr + k * ldb4 + col * 16); + vb1[col] = _mm512_loadu_si512(b1_ptr + k * ldb4 + col * 16); + } + vc0[i] = _mm512_dpbusd_epi32(vc0[i], va, vb0[col]); + vc1[i] = _mm512_dpbusd_epi32(vc1[i], va, vb1[col]); + }; + for (int64_t k = 0; k < K4; ++k) { + Unroll{}(compute, k); + } + + auto scalec = [&](auto i) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + + // load a scale + if constexpr(col == 0) { + was = _mm512_set1_ps(As[row]); + } + // load b scale and vcomp + if constexpr (row == 0) { + vbs0[col] = _mm512_loadu_ps(Bs0 + col * 16); + vbs1[col] = _mm512_loadu_ps(Bs1 + col * 16); + vcomp0[col] = _mm512_loadu_si512(Bcomp0 + col * 16); + vcomp1[col] = 
_mm512_loadu_si512(Bcomp1 + col * 16); + } + __m512 c0 = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc0[i], vcomp0[col])); + __m512 c1 = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc1[i], vcomp1[col])); + vc0[i] = _mm512_castps_si512(_mm512_mul_ps(_mm512_mul_ps(c0, was), vbs0[col])); + vc1[i] = _mm512_castps_si512(_mm512_mul_ps(_mm512_mul_ps(c1, was), vbs1[col])); + }; + Unroll{}(scalec); + + using Vec = at::vec::Vectorized; + const Vec one = Vec(1.f); + auto storec = [&](auto i) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + // for COLS = 2, 4 use 512bit store + if constexpr (col % 2 == 0) { + Vec x0 = _mm512_castsi512_ps(vc0[row * COLS + col + 0]); + Vec x1 = _mm512_castsi512_ps(vc0[row * COLS + col + 1]); + Vec y0 = _mm512_castsi512_ps(vc1[row * COLS + col + 0]); + Vec y1 = _mm512_castsi512_ps(vc1[row * COLS + col + 1]); + // silu + x0 = x0 / (one + x0.neg().exp_u20()); + x1 = x1 / (one + x1.neg().exp_u20()); + // mul + x0 = x0 * y0; + x1 = x1 * y1; + + _mm512_storeu_si512( + reinterpret_cast<__m512i*>((C + row * ldc + col * 16)), + (__m512i)(_mm512_cvtne2ps_pbh(__m512(x1), __m512(x0)))); + } + }; + Unroll{}(storec); + } +}; +#endif + +#define LAUNCH_TINYGEMM_KERNEL_VNNI(MB_SIZE, NB_SIZE) \ + tinygemm_kernel_vnni::apply( \ + A + mb_start * lda, B0 + nb_start * 4, B1 + nb_start * 4, \ + C + mb_start * ldc + nb_start, As + mb_start, \ + Bs0 + nb_start, Bs1 + nb_start, Bcomp0 + nb_start, Bcomp1 + nb_start,\ + K, lda, ldb, ldc); + +template +void tinygemm_kernel( + const uint8_t* __restrict__ A, + const int8_t* __restrict__ B0, + const int8_t* __restrict__ B1, + scalar_t* __restrict__ C, + const float* __restrict__ As, + const float* __restrict__ Bs0, + const float* __restrict__ Bs1, + int64_t M, + int64_t N, + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc) { + + const int32_t* Bcomp0 = reinterpret_cast(B0 + block_size_n() * K); + const int32_t* Bcomp1 = reinterpret_cast(B1 + block_size_n() * K); + + // pattern: 1-(2+2)-(8+8) + constexpr int64_t BLOCK_M = 4; + constexpr int64_t BLOCK_N = 32; + const int64_t MB = div_up(M, BLOCK_M); + const int64_t NB = div_up(N, BLOCK_N); + for (int mb = 0; mb < MB; ++mb) { + int64_t mb_start = mb * BLOCK_M; + int64_t mb_size = std::min(BLOCK_M, M - mb_start); + for (int64_t nb = 0; nb < NB; ++nb) { + int64_t nb_start = nb * BLOCK_N; + int64_t nb_size = std::min(BLOCK_N, N - nb_start); + + switch(mb_size << 4 | nb_size >> 4) { + case 0x12: LAUNCH_TINYGEMM_KERNEL_VNNI(1, 32); break; + case 0x22: LAUNCH_TINYGEMM_KERNEL_VNNI(2, 32); break; + case 0x32: LAUNCH_TINYGEMM_KERNEL_VNNI(3, 32); break; + case 0x42: LAUNCH_TINYGEMM_KERNEL_VNNI(4, 32); break; + default: TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", "nb_size"); + } + } + } +} + +/// gemm for w2 +template +struct tinygemm_kernel_vnni2 { + static inline void apply( + const uint8_t* __restrict__ A, const int8_t* __restrict__ B, float* __restrict__ C, + const float* __restrict__ As, const float* __restrict__ Bs, const int32_t* __restrict__ Bcomp, + int64_t K, int64_t lda, int64_t ldb, int64_t ldc) { + TORCH_CHECK(false, "tinygemm_kernel_nn: scalar path not implemented!"); + } +}; + +#if defined(CPU_CAPABILITY_AVX512) +template +struct tinygemm_kernel_vnni2 { + static inline void apply( + const uint8_t* __restrict__ A, const int8_t* __restrict__ B, float* __restrict__ C, + const float* __restrict__ As, const float* __restrict__ Bs, const int32_t* __restrict__ Bcomp, + int64_t K, int64_t lda, int64_t ldb, int64_t ldc) { + + constexpr int ROWS = BLOCK_M; + constexpr int COLS = 
BLOCK_N / 16; + static_assert(COLS % 2 == 0); + + __m512i va; + __m512i vb[COLS]; + __m512i vc[ROWS * COLS]; + __m512i vcomp[COLS]; + __m512 was; + __m512 vbs[COLS]; + + auto loadc = [&](auto i) { + vc[i] = _mm512_set1_epi32(0); + }; + Unroll{}(loadc); + + const int64_t K4 = K >> 2; + const int64_t lda4 = lda >> 2; + const int64_t ldb4 = ldb; // ldb * 4 >> 2; + const int32_t* a_ptr = reinterpret_cast(A); + const int32_t* b_ptr = reinterpret_cast(B); + + auto compute = [&](auto i, int64_t k) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + + if constexpr (col == 0) { + va = _mm512_set1_epi32(a_ptr[row * lda4 + k]); + } + if constexpr (row == 0) { + vb[col] = _mm512_loadu_si512(b_ptr + k * ldb4 + col * 16); + } + vc[i] = _mm512_dpbusd_epi32(vc[i], va, vb[col]); + }; + for (int64_t k = 0; k < K4; ++k) { + Unroll{}(compute, k); + } + + auto storec = [&](auto i) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + + // load a scale + if constexpr(col == 0) { + was = _mm512_set1_ps(As[row]); + } + // load b scale and vcomp per 2 vectors + // also load bias if any + if constexpr (row == 0) { + if constexpr (col % 2 == 0) { + vbs[col + 0] = _mm512_loadu_ps(Bs + col * 16); + vbs[col + 1] = _mm512_loadu_ps(Bs + col * 16 + 16); + vcomp[col + 0] = _mm512_loadu_si512(Bcomp + col * 16); + vcomp[col + 1] = _mm512_loadu_si512(Bcomp + col * 16 + 16); + } + } + __m512 x = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc[i], vcomp[col])); + x = _mm512_mul_ps(_mm512_mul_ps(x, was), vbs[col]); + _mm512_storeu_ps(reinterpret_cast<__m512*>(C + row * ldc + col * 16), x); + }; + Unroll{}(storec); + } +}; +#endif + +#define LAUNCH_TINYGEMM_KERNEL_VNNI2(MB_SIZE, NB_SIZE) \ + tinygemm_kernel_vnni2::apply( \ + A + mb_start * lda, B + nb_start * 4, C + mb_start * ldc + nb_start, \ + As + mb_start, Bs + nb_start, Bcomp + nb_start, \ + K, lda, ldb, ldc); + +template +void tinygemm_kernel( + const uint8_t* __restrict__ A, + const int8_t* __restrict__ B, + float* __restrict__ C, + const float* __restrict__ As, + const float* __restrict__ Bs, + int64_t M, + int64_t N, + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc) { + + // B compensation + const int32_t* Bcomp = reinterpret_cast(B + block_size_n() * K); + + // pattern: 1-4-16 + constexpr int64_t BLOCK_M = 4; + constexpr int64_t BLOCK_N = 64; + const int64_t MB = div_up(M, BLOCK_M); + const int64_t NB = div_up(N, BLOCK_N); + for (int64_t mb = 0; mb < MB; ++mb) { + int64_t mb_start = mb * BLOCK_M; + int64_t mb_size = std::min(BLOCK_M, M - mb_start); + for (int64_t nb = 0; nb < NB; ++nb) { + int64_t nb_start = nb * BLOCK_N; + int64_t nb_size = std::min(BLOCK_N, N - nb_start); + + switch(mb_size << 4 | nb_size >> 4) { + case 0x12: LAUNCH_TINYGEMM_KERNEL_VNNI2(1, 32); break; + case 0x22: LAUNCH_TINYGEMM_KERNEL_VNNI2(2, 32); break; + case 0x32: LAUNCH_TINYGEMM_KERNEL_VNNI2(3, 32); break; + case 0x42: LAUNCH_TINYGEMM_KERNEL_VNNI2(4, 32); break; + default: TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", "nb_size"); + } + } + } +} + +} // anonymous namespace + +template +void fused_experts_int8_kernel_impl( + scalar_t* __restrict__ output, + scalar_t* __restrict__ ic1, + scalar_t* __restrict__ ic2, + uint8_t* __restrict__ A_tmp, + float* __restrict__ C_tmp, + uint8_t* __restrict__ Aq_tmp, + float* __restrict__ As_tmp, + const scalar_t* __restrict__ input, + const int8_t* __restrict__ packed_w1, + const int8_t* __restrict__ packed_w2, + const float* __restrict__ w1s, + const float* __restrict__ w2s, + const float* __restrict__ 
topk_weights, + const int32_t* __restrict__ sorted_ids, + const int32_t* __restrict__ expert_ids, + const int32_t* __restrict__ offsets, + int64_t M, + int64_t N, + int64_t K, + int64_t E, + int64_t topk, + int64_t num_tokens_post_pad) { + + // handle 2 tiles per block + constexpr int64_t BLOCK_M = block_size_m(); + constexpr int64_t BLOCK_N = block_size_n(); + + // stage 0: quantize input to uint8, [M, K] + at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) { + for (int64_t m = begin; m < end; ++m) { + quantize_row_int8( + Aq_tmp + m * K, + As_tmp[m], + input + m * K, + K); + } + }); + + // stage 1: intermediate_cache1 = silu(hidden_states @ w1) + const int64_t MB = div_up(num_tokens_post_pad, BLOCK_M); + const int64_t NB = div_up(N, BLOCK_N); + + // strides for w1: [E, 2N, K] + TORCH_CHECK(N % BLOCK_N == 0, "Fixme when N is not multiples of ", BLOCK_N); + + // K and N are packed for int8 + const int64_t packed_K = get_row_size(K); + const int64_t packed_N = get_row_size(N); + + const int64_t stride_e = 2 * N * packed_K; + const int64_t stride_n = packed_K; + // here we only parallel on half of 2N to fuse silu_and_mul with gemm + at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) { + // get local pointers + int tid = at::get_thread_num(); + uint8_t* __restrict__ A = A_tmp + tid * BLOCK_M * K; + + alignas(64) float As[BLOCK_M]; + + for (int64_t i = begin; i < end; ++i) { + int64_t mb = i / NB; + int64_t nb = i % NB; + + // nb0 from top half and nb1 from bottom half + int64_t nb0 = nb, nb1 = nb + NB; + int64_t n_size = std::min(N - nb0 * BLOCK_N, BLOCK_N); + + // B shape [K, n_size] in vnni format + int32_t expert_id = expert_ids[mb]; + const int8_t* __restrict__ B0 = packed_w1 + expert_id * stride_e + nb0 * BLOCK_N * stride_n; + const int8_t* __restrict__ B1 = packed_w1 + expert_id * stride_e + nb1 * BLOCK_N * stride_n; + const float* __restrict__ Bs0 = w1s + expert_id * 2 * N + nb0 * BLOCK_N; + const float* __restrict__ Bs1 = w1s + expert_id * 2 * N + nb1 * BLOCK_N; + + // 1.a load A + const int32_t* A_ids = sorted_ids + mb * BLOCK_M; + int64_t m_size = offsets[mb + 1] - offsets[mb]; + + for (int64_t m = 0; m < m_size; ++m) { + int32_t index = A_ids[m] / topk; + copy_stub(A + m * K, Aq_tmp + index * K, K); + As[m] = As_tmp[index]; + } + + // fused 1.b: silu_and_mul(A @ B0, A @ B1) + const int64_t offset = offsets[mb]; + tinygemm_kernel( + /* A */ A, + /* B0 */ B0, + /* B1 */ B1, + /* C */ ic1 + offset * N + nb * BLOCK_N, + /* As */ As, + /* Bs0 */ Bs0, + /* Bs1 */ Bs1, + /* M */ m_size, + /* N */ n_size, + /* K */ K, + /* lda */ K, + /* ldb */ n_size, + /* ldc */ N); + } + }); + + // stage 1.5: quantize ic1 to uint8, [M * topk, N] + at::parallel_for(0, M * topk, 0, [&](int64_t begin, int64_t end) { + for (int64_t m = begin; m < end; ++m) { + quantize_row_int8( + Aq_tmp + m * N, + As_tmp[m], + ic1 + m * N, + N); + } + }); + + // stage 2: intermediate_cache2 = intermediate_cache1 @ w2 + // w2 : [E, K, N] as [E, OC, IC] + const int64_t OC = K; // rename K as OC + const int64_t IC = N; // rename N as IC + const int64_t MB2 = MB; + const int64_t NB2 = div_up(OC, BLOCK_N); + const int64_t stride_e2 = OC * packed_N; + const int64_t stride_oc = packed_N; + + // parallel on [MB2, NB2] + at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) { + // get local pointers + int tid = at::get_thread_num(); + // we won't be using C1 for gemm2 + float* __restrict__ C = C_tmp + tid * 2 * BLOCK_M * BLOCK_N; + + for (int64_t i = begin; i < end; ++i) { + int64_t mb = i / NB2; + 
int64_t nb = i % NB2; + + int64_t m_size = offsets[mb + 1] - offsets[mb]; + int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N); + + // A ptr from ic1 of [M * topk, N] in sorted order + // so as to avoid copy A to tmp buffer again + const uint8_t* __restrict__ A = Aq_tmp + offsets[mb] * N; + const float* __restrict__ As = As_tmp + offsets[mb]; + const int32_t* A_ids = sorted_ids + mb * BLOCK_M; + + // B shape [IC, n_size] in vnni format + int32_t expert_id = expert_ids[mb]; + const int8_t* __restrict__ B = packed_w2 + expert_id * stride_e2 + nb * BLOCK_N * stride_oc; + const float* __restrict__ Bs = w2s + expert_id * K + nb * BLOCK_N; + + // 2.a gemm: C = A @ B + tinygemm_kernel( + /* A */ A, + /* B */ B, + /* C */ C, + /* As */ As, + /* Bs */ Bs, + /* M */ m_size, + /* N */ n_size, + /* K */ IC, + /* lda */ IC, + /* ldb */ n_size, + /* ldc */ BLOCK_N); + + // 2.b copy from C to ic2 in original order + // and also mul topk_weights in float32 + for (int64_t m = 0; m < m_size; ++m) { + int32_t index = A_ids[m]; + float weight = topk_weights[index]; + copy_mul_stub(ic2 + index * K + nb * BLOCK_N, C + m * BLOCK_N, weight, n_size); + } + } + }); + + // stage 3: out = intermediate_cache2.sum(dim=1) + // from [M, topk, K] to [M, K] + at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) { + for (int64_t m = begin; m < end; ++m) { + sum_stub(output + m * K, ic2 + m * topk * K, topk, K); + } + }); +} + +#define INSTANTIATE_MOE_INT8_TEMPLATE(TYPE) \ + template void fused_experts_int8_kernel_impl ( \ + TYPE* __restrict__ output, TYPE* __restrict__ ic1, \ + TYPE* __restrict__ ic2, uint8_t* __restrict__ A_tmp, \ + float* __restrict__ C_tmp, uint8_t* __restrict__ Aq_tmp, \ + float* __restrict__ As_tmp, const TYPE* __restrict__ input, \ + const int8_t* __restrict__ packed_w1, const int8_t* __restrict__ packed_w2, \ + const float* __restrict__ w1s, const float* __restrict__ w2s, \ + const float* __restrict__ topk_weights, const int32_t* __restrict__ sorted_ids, \ + const int32_t* __restrict__ expert_ids, const int32_t* __restrict__ offsets, \ + int64_t M, int64_t N, int64_t K, int64_t E, int64_t topk, int64_t num_tokens_post_pad) + +INSTANTIATE_MOE_INT8_TEMPLATE(at::BFloat16); +INSTANTIATE_MOE_INT8_TEMPLATE(at::Half); + +template +void shared_expert_int8_kernel_impl( + scalar_t* __restrict__ output, + scalar_t* __restrict__ ic1, + float* __restrict__ C_tmp, + uint8_t* __restrict__ Aq_tmp, + float* __restrict__ As_tmp, + const scalar_t* __restrict__ input, + const int8_t* __restrict__ packed_w1, + const int8_t* __restrict__ packed_w2, + const float* __restrict__ w1s, + const float* __restrict__ w2s, + const scalar_t* __restrict__ fused_experts_out, + float routed_scaling_factor, + int64_t M, + int64_t N, + int64_t K) { + + // handle 2 tiles per block + constexpr int64_t BLOCK_M = block_size_m(); + constexpr int64_t BLOCK_N = block_size_n(); + + // stage 0: quantize input to uint8, [M, K] + at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) { + for (int64_t m = begin; m < end; ++m) { + quantize_row_int8( + Aq_tmp + m * K, + As_tmp[m], + input + m * K, + K); + } + }); + + // stage 1: intermediate_cache1 = silu(hidden_states @ w1) + const int64_t MB = div_up(M, BLOCK_M); + const int64_t NB = div_up(N, BLOCK_N); + + TORCH_CHECK(N % BLOCK_N == 0, "Fixme when N is not multiples of ", BLOCK_N); + + // K and N are packed for int8 + const int64_t packed_K = get_row_size(K); + const int64_t packed_N = get_row_size(N); + const int64_t stride_n = packed_K; + + // here we only parallel on half of 2N to 
fuse silu_and_mul with gemm + at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) { + for (int64_t i = begin; i < end; ++i) { + int64_t mb = i / NB; + int64_t nb = i % NB; + + // nb0 from top half and nb1 from bottom half + int64_t nb0 = nb, nb1 = nb + NB; + int64_t n_size = std::min(N - nb0 * BLOCK_N, BLOCK_N); + int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M); + + // A shape [m_size, K] + const uint8_t* A = Aq_tmp + mb * BLOCK_M * K; + const float* As = As_tmp + mb * BLOCK_M; + + // B shape [K, n_size] in vnni format + const int8_t* __restrict__ B0 = packed_w1 + nb0 * BLOCK_N * stride_n; + const int8_t* __restrict__ B1 = packed_w1 + nb1 * BLOCK_N * stride_n; + const float* __restrict__ Bs0 = w1s + nb0 * BLOCK_N; + const float* __restrict__ Bs1 = w1s + nb1 * BLOCK_N; + + // fused 1.b: silu_and_mul(A @ B0, A @ B1) + tinygemm_kernel( + /* A */ A, + /* B0 */ B0, + /* B1 */ B1, + /* C */ ic1 + mb * BLOCK_M * N + nb * BLOCK_N, + /* As */ As, + /* Bs0 */ Bs0, + /* Bs1 */ Bs1, + /* M */ m_size, + /* N */ n_size, + /* K */ K, + /* lda */ K, + /* ldb */ n_size, + /* ldc */ N); + } + }); + + // stage 1.5: quantize ic1 to uint8, [M * topk, N] + at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) { + for (int64_t m = begin; m < end; ++m) { + quantize_row_int8( + Aq_tmp + m * N, + As_tmp[m], + ic1 + m * N, + N); + } + }); + + // stage 2: intermediate_cache2 = intermediate_cache1 @ w2 + // w2 : [K, N] as [OC, IC] + const int64_t OC = K; // rename K as OC + const int64_t IC = N; // rename N as IC + const int64_t MB2 = MB; + const int64_t NB2 = div_up(OC, BLOCK_N); + const int64_t stride_oc = packed_N; + + // parallel on [MB2, NB2] + at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) { + // get local pointers + int tid = at::get_thread_num(); + // we won't be using C1 for gemm2 + float* __restrict__ C = C_tmp + tid * 2 * BLOCK_M * BLOCK_N; + + for (int64_t i = begin; i < end; ++i) { + int64_t mb = i / NB2; + int64_t nb = i % NB2; + + int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M); + int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N); + + // A shape [m_size, IC] + const uint8_t* __restrict__ A = Aq_tmp + mb * BLOCK_M * N; + const float* __restrict__ As = As_tmp + mb * BLOCK_M; + + // B shape [IC, n_size] in vnni format + const int8_t* __restrict__ B = packed_w2 + nb * BLOCK_N * stride_oc; + const float* __restrict__ Bs = w2s + nb * BLOCK_N; + + // 2.a gemm: C = A @ B + tinygemm_kernel( + /* A */ A, + /* B */ B, + /* C */ C, + /* As */ As, + /* Bs */ Bs, + /* M */ m_size, + /* N */ n_size, + /* K */ IC, + /* lda */ IC, + /* ldb */ n_size, + /* ldc */ BLOCK_N); + + // 2.b copy from C to output and add fused_experts_out + scalar_t* __restrict__ out = output + mb * BLOCK_M * K + nb * BLOCK_N; + const scalar_t* __restrict__ fused_out = fused_experts_out + mb * BLOCK_M * K + nb * BLOCK_N; + for (int64_t m = 0; m < m_size; ++m) { + add_mul_stub(out + m * K, C + m * BLOCK_N, fused_out + m * K, routed_scaling_factor, n_size); + } + } + }); +} + +#define INSTANTIATE_SHARED_EXPERT_INT8_TEMPLATE(TYPE) \ + template void shared_expert_int8_kernel_impl ( \ + TYPE* __restrict__ output, TYPE* __restrict__ ic1, \ + float* __restrict__ C_tmp, uint8_t* __restrict__ Aq_tmp, \ + float* __restrict__ As_tmp, const TYPE* __restrict__ input, \ + const int8_t* __restrict__ packed_w1, const int8_t* __restrict__ packed_w2, \ + const float* __restrict__ w1s, const float* __restrict__ w2s, \ + const TYPE* __restrict__ fused_experts_out, float routed_scaling_factor, \ + int64_t M, int64_t 
N, int64_t K) + +INSTANTIATE_SHARED_EXPERT_INT8_TEMPLATE(at::BFloat16); +INSTANTIATE_SHARED_EXPERT_INT8_TEMPLATE(at::Half); diff --git a/csrc/cpu/sgl-kernels/vec.h b/csrc/cpu/sgl-kernels/vec.h new file mode 100644 index 000000000000..87955cfb2922 --- /dev/null +++ b/csrc/cpu/sgl-kernels/vec.h @@ -0,0 +1,308 @@ +// Adapted from +// https://github.com/sgl-project/sglang/tree/main/sgl-kernel/csrc/cpu + +#pragma once + +// clang-format off + +#if defined(__AVX512F__) && defined(__AVX512BF16__) && defined(__AMX_BF16__) +#define CPU_CAPABILITY_AVX512 +#endif + +#include +#include + +namespace { + +using namespace at::vec; + +template , int> = 0> +inline Vectorized convert_from_float_ext(const Vectorized& a, const Vectorized& b) { + return at::vec::convert_from_float(a, b); +} + +#if defined(CPU_CAPABILITY_AVX512) + +// `at::vec::convert_from_float<>` from PyTorch doesn't have avx512-bf16 intrinsics +// use native instruction for bfloat16->float32 conversion +template <> +inline Vectorized convert_from_float_ext(const Vectorized& a, const Vectorized& b) { + return (__m512i)(_mm512_cvtne2ps_pbh(__m512(b), __m512(a))); +} + +#define CVT_BF16_TO_FP32(a) \ + _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(a), 16)) + +#define CVT_FP16_TO_FP32(a) \ + _mm512_cvtps_ph(a, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)) + +// this doesn't hanel NaN. +inline __m512bh cvt_e4m3_bf16_intrinsic_no_nan(__m256i fp8_vec) { + const __m512i x = _mm512_cvtepu8_epi16(fp8_vec); + + const __m512i mant = _mm512_slli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(0x07)), 4); + const __m512i raw_exp = _mm512_srli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(0x78)), 3); + const __m512i exp = _mm512_slli_epi16(_mm512_add_epi16(raw_exp, _mm512_set1_epi16(120)), 7); + const __m512i nonsign = _mm512_or_si512(exp, mant); + + const __m512i sign = _mm512_slli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(0x80)), 8); + const __m512i combined = _mm512_or_si512(nonsign, sign); + + const __mmask32 is_nonzero = _mm512_cmpneq_epi16_mask(x, _mm512_setzero_si512()); + return (__m512bh)_mm512_maskz_mov_epi16(is_nonzero, combined); +} + +inline __m512bh cvt_e4m3_bf16_intrinsic_without_denorm(__m256i fp8_vec) { + // The following conversion is without denorm behavior, that is to say, + // Max subnorm : S.0000.111 = 0.875 ∗ 2**(−6) + // Min subnorm : S.0000.001 = 2**(−9) + // 0.0019 ~ 0.0137 cannot be converted correctly. 
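+  // For normal e4m3 inputs the mapping below is exact: the sign bit is kept,
+  // the 3-bit mantissa is shifted left by 4 into the top of bf16's 7-bit
+  // mantissa field, and the exponent is rebiased from 7 to 127 (i.e. +120).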
+ __m512i x = _mm512_cvtepu8_epi16(fp8_vec); + auto mask = _mm512_cmpneq_epi16_mask( + _mm512_and_si512(x, _mm512_set1_epi16(127)), + _mm512_setzero_si512()); // mask = x & 0x7f + auto mask_nan = _mm512_cmpneq_epi16_mask( + _mm512_and_si512(x, _mm512_set1_epi16(127)), + _mm512_set1_epi16(127)); // mask_nan = x & 0x7f + auto mantissa = _mm512_slli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(7)), 4); // mantissa = (x & 7) << 4 + auto exponent = _mm512_add_epi16( + _mm512_srli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(120)), 3), + _mm512_set1_epi16(120)); // exponent = (((x >> 3) & 15) + 120) + auto nonsign = _mm512_maskz_mov_epi16(mask, _mm512_or_si512(mantissa, _mm512_slli_epi16(exponent, 7))); + nonsign = _mm512_mask_mov_epi16(_mm512_set1_epi16(0x7fff), mask_nan, nonsign); // deal with Nan + return (__m512bh)(_mm512_or_si512( + nonsign, + _mm512_slli_epi16( + _mm512_and_si512(x, _mm512_set1_epi16(128)), + 8))); // add sign (x & 128) << 8 +} + +inline __m512bh cvt_e4m3_bf16_intrinsic_with_denorm(__m256i fp8_vec) { + __m512i x = _mm512_cvtepu8_epi16(fp8_vec); + __m512i lg2mant = _mm512_mask_mov_epi16( + _mm512_mask_mov_epi16( + _mm512_setzero_si512(), _mm512_test_epi16_mask(x, _mm512_set1_epi16(2)), _mm512_set1_epi16(1)), + _mm512_test_epi16_mask(x, _mm512_set1_epi16(4)), + _mm512_set1_epi16(2)); + return (__m512bh)(_mm512_or_si512( + _mm512_maskz_mov_epi16( + _mm512_cmpneq_epi16_mask(_mm512_and_si512(x, _mm512_set1_epi16(127)), _mm512_setzero_si512()), + _mm512_mask_blend_epi16( + _mm512_test_epi16_mask(x, _mm512_set1_epi16(120)), + _mm512_or_si512( + _mm512_and_si512( + _mm512_sllv_epi16( + _mm512_and_si512(x, _mm512_set1_epi16(3)), _mm512_sub_epi16(_mm512_set1_epi16(7), lg2mant)), + _mm512_set1_epi16(0x007f)), + _mm512_slli_epi16(_mm512_add_epi16(lg2mant, _mm512_set1_epi16(118)), 7)), + _mm512_or_si512( + _mm512_slli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(7)), 4), + _mm512_slli_epi16( + _mm512_add_epi16( + _mm512_srli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(120)), 3), _mm512_set1_epi16(120)), + 7)))), + _mm512_slli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(128)), 8))); +} + +inline __m512bh CVT_FP8_TO_BF16(__m256i a) { +#ifdef SGLANG_CPU_FP8_CVT_FTZ + return cvt_e4m3_bf16_intrinsic_no_nan(a); +#else + return cvt_e4m3_bf16_intrinsic_with_denorm(a); +#endif +} + +#endif + +// vector to scalar reduction +#if defined(CPU_CAPABILITY_AVX512) && 0 +inline float vec_reduce_sum(const Vectorized& a) { + return _mm512_reduce_add_ps(__m512(a)); +} + +inline float vec_reduce_max(const Vectorized& a) { + return _mm512_reduce_max_ps(__m512(a)); +} +#else +inline float vec_reduce_sum(const Vectorized& a) { + return vec_reduce_all([](Vectorized& x, Vectorized& y) { return x + y; }, a); +} + +inline float vec_reduce_max(const Vectorized& a) { + return vec_reduce_all([](Vectorized& x, Vectorized& y) { return maximum(x, y); }, a); +} +#endif + +// https://github.com/InternLM/lmdeploy/blob/086481ed84b59bee3b8e4274e5fc69620040c048/lmdeploy/pytorch/kernels/cuda/w8a8_triton_kernels.py#L282 +template +inline void quantize_row_int8(uint8_t* __restrict__ Aq, float& As, + const scalar_t* __restrict__ A, int64_t K, float eps = 1e-7) { + + float amax = 0.f; // absolute max + for (int64_t k = 0; k < K; ++k) { + const float val = static_cast(A[k]); + amax = std::max(amax, std::abs(val)); + } + + amax = std::max(amax, eps); + const float scale = amax / 127; + const float inv_scale = 127 / amax; + + for (int64_t k = 0; k < K; ++k) { + const float val = static_cast(A[k]) * inv_scale; + Aq[k] = 
(uint8_t)(std::round(val)) + 128; + } + As = scale; +} + +#if defined(CPU_CAPABILITY_AVX512) +template <> +inline void quantize_row_int8(uint8_t* __restrict__ Aq, float& As, + const at::BFloat16* __restrict__ A, int64_t K, float eps) { + + const __m512 signBit = _mm512_set1_ps(-0.0f); + const __m512i off = _mm512_set1_epi32(128); + + // K is 32x, no remainder + float amax = 0.f; + __m512 vamax0 = _mm512_set1_ps(0.f); + __m512 vamax1 = _mm512_set1_ps(0.f); + for (int64_t k = 0; k < K; k += 32) { + __m512i va = _mm512_loadu_si512((void*)(A + k)); + __m512 va0 = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(va, 0)); + __m512 va1 = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(va, 1)); + vamax0 = _mm512_max_ps(vamax0, _mm512_andnot_ps(signBit, va0)); + vamax1 = _mm512_max_ps(vamax1, _mm512_andnot_ps(signBit, va1)); + } + amax = _mm512_reduce_max_ps(_mm512_max_ps(vamax0, vamax1)); + amax = std::max(amax, eps); + const float scale = amax / 127; + const float inv_scale = 127 / amax; + const __m512 vd = _mm512_set1_ps(inv_scale); + + for (int64_t k = 0; k < K; k += 32) { + __m512i va = _mm512_loadu_si512((void*)(A + k)); + __m512 va0 = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(va, 0)); + __m512 va1 = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(va, 1)); + va0 = _mm512_mul_ps(va0, vd); + va1 = _mm512_mul_ps(va1, vd); + va0 = _mm512_roundscale_ps(va0, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + va1 = _mm512_roundscale_ps(va1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + __m128i i0 = _mm512_cvtepi32_epi8(_mm512_add_epi32(_mm512_cvtps_epi32(va0), off)); + __m128i i1 = _mm512_cvtepi32_epi8(_mm512_add_epi32(_mm512_cvtps_epi32(va1), off)); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(Aq + k), _mm256_set_m128i(i1, i0)); + } + As = scale; +} +#endif + +// transpose utils +// taken from my PR in ggml: https://github.com/ggml-org/llama.cpp/pull/8998 +#if defined(CPU_CAPABILITY_AVX512) +inline void transpose_16x16_32bit(__m512i * v) { + __m512i v1[16]; + v1[0] = _mm512_unpacklo_epi32(v[0], v[1]); + v1[1] = _mm512_unpackhi_epi32(v[0], v[1]); + v1[2] = _mm512_unpacklo_epi32(v[2], v[3]); + v1[3] = _mm512_unpackhi_epi32(v[2], v[3]); + v1[4] = _mm512_unpacklo_epi32(v[4], v[5]); + v1[5] = _mm512_unpackhi_epi32(v[4], v[5]); + v1[6] = _mm512_unpacklo_epi32(v[6], v[7]); + v1[7] = _mm512_unpackhi_epi32(v[6], v[7]); + v1[8] = _mm512_unpacklo_epi32(v[8], v[9]); + v1[9] = _mm512_unpackhi_epi32(v[8], v[9]); + v1[10] = _mm512_unpacklo_epi32(v[10], v[11]); + v1[11] = _mm512_unpackhi_epi32(v[10], v[11]); + v1[12] = _mm512_unpacklo_epi32(v[12], v[13]); + v1[13] = _mm512_unpackhi_epi32(v[12], v[13]); + v1[14] = _mm512_unpacklo_epi32(v[14], v[15]); + v1[15] = _mm512_unpackhi_epi32(v[14], v[15]); + + v[0] = _mm512_unpacklo_epi64(v1[0], v1[2]); + v[1] = _mm512_unpackhi_epi64(v1[0], v1[2]); + v[2] = _mm512_unpacklo_epi64(v1[1], v1[3]); + v[3] = _mm512_unpackhi_epi64(v1[1], v1[3]); + v[4] = _mm512_unpacklo_epi64(v1[4], v1[6]); + v[5] = _mm512_unpackhi_epi64(v1[4], v1[6]); + v[6] = _mm512_unpacklo_epi64(v1[5], v1[7]); + v[7] = _mm512_unpackhi_epi64(v1[5], v1[7]); + v[8] = _mm512_unpacklo_epi64(v1[8], v1[10]); + v[9] = _mm512_unpackhi_epi64(v1[8], v1[10]); + v[10] = _mm512_unpacklo_epi64(v1[9], v1[11]); + v[11] = _mm512_unpackhi_epi64(v1[9], v1[11]); + v[12] = _mm512_unpacklo_epi64(v1[12], v1[14]); + v[13] = _mm512_unpackhi_epi64(v1[12], v1[14]); + v[14] = _mm512_unpacklo_epi64(v1[13], v1[15]); + v[15] = _mm512_unpackhi_epi64(v1[13], v1[15]); + + v1[0] = _mm512_shuffle_i32x4(v[0], v[4], 0x88); + v1[1] = 
_mm512_shuffle_i32x4(v[1], v[5], 0x88); + v1[2] = _mm512_shuffle_i32x4(v[2], v[6], 0x88); + v1[3] = _mm512_shuffle_i32x4(v[3], v[7], 0x88); + v1[4] = _mm512_shuffle_i32x4(v[0], v[4], 0xdd); + v1[5] = _mm512_shuffle_i32x4(v[1], v[5], 0xdd); + v1[6] = _mm512_shuffle_i32x4(v[2], v[6], 0xdd); + v1[7] = _mm512_shuffle_i32x4(v[3], v[7], 0xdd); + v1[8] = _mm512_shuffle_i32x4(v[8], v[12], 0x88); + v1[9] = _mm512_shuffle_i32x4(v[9], v[13], 0x88); + v1[10] = _mm512_shuffle_i32x4(v[10], v[14], 0x88); + v1[11] = _mm512_shuffle_i32x4(v[11], v[15], 0x88); + v1[12] = _mm512_shuffle_i32x4(v[8], v[12], 0xdd); + v1[13] = _mm512_shuffle_i32x4(v[9], v[13], 0xdd); + v1[14] = _mm512_shuffle_i32x4(v[10], v[14], 0xdd); + v1[15] = _mm512_shuffle_i32x4(v[11], v[15], 0xdd); + + v[0] = _mm512_shuffle_i32x4(v1[0], v1[8], 0x88); + v[1] = _mm512_shuffle_i32x4(v1[1], v1[9], 0x88); + v[2] = _mm512_shuffle_i32x4(v1[2], v1[10], 0x88); + v[3] = _mm512_shuffle_i32x4(v1[3], v1[11], 0x88); + v[4] = _mm512_shuffle_i32x4(v1[4], v1[12], 0x88); + v[5] = _mm512_shuffle_i32x4(v1[5], v1[13], 0x88); + v[6] = _mm512_shuffle_i32x4(v1[6], v1[14], 0x88); + v[7] = _mm512_shuffle_i32x4(v1[7], v1[15], 0x88); + v[8] = _mm512_shuffle_i32x4(v1[0], v1[8], 0xdd); + v[9] = _mm512_shuffle_i32x4(v1[1], v1[9], 0xdd); + v[10] = _mm512_shuffle_i32x4(v1[2], v1[10], 0xdd); + v[11] = _mm512_shuffle_i32x4(v1[3], v1[11], 0xdd); + v[12] = _mm512_shuffle_i32x4(v1[4], v1[12], 0xdd); + v[13] = _mm512_shuffle_i32x4(v1[5], v1[13], 0xdd); + v[14] = _mm512_shuffle_i32x4(v1[6], v1[14], 0xdd); + v[15] = _mm512_shuffle_i32x4(v1[7], v1[15], 0xdd); +} + +// remove warning : ignoring attributes on template argument ‘__m512i’ [-Wignored-attributes] +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wignored-attributes" + +// transpose from [2, 32] to [32, 2] +inline std::tuple<__m512i, __m512i> transpose_2x32_16bit(__m512i r0, __m512i r1) { + // r0: {a0, a1, ..., a31} + // r1: {b0, b1, ..., b31} + // + // d0: {a0, b0, ..., a15, b15} + // d1: {a16, b16, ..., a31, b31} + // + __m512i d0 = _mm512_unpacklo_epi16(r0, r1); + __m512i d1 = _mm512_unpackhi_epi16(r0, r1); + r0 = _mm512_shuffle_i32x4(d0, d1, 0x88); + r1 = _mm512_shuffle_i32x4(d0, d1, 0xdd); + d0 = _mm512_shuffle_i32x4(r0, r1, 0x88); + d1 = _mm512_shuffle_i32x4(r0, r1, 0xdd); + return std::make_tuple(d0, d1); +} +#pragma GCC diagnostic pop + +#endif + +// TODO: debug print, remove me later +template +void print_array(scalar_t* ptr, int size) { + for (int d = 0; d < size; ++d) { + if (d % 16 == 0) { std::cout << std::endl; } + std::cout << ptr[d] << " "; + } + std::cout << std::endl; +} + +} // anonymous namespace diff --git a/csrc/cpu/shm.cpp b/csrc/cpu/shm.cpp index f55e96de251d..9adb6f27ec41 100644 --- a/csrc/cpu/shm.cpp +++ b/csrc/cpu/shm.cpp @@ -7,9 +7,10 @@ namespace { #define MAX_SHM_RANK_NUM 8 -#define MAX_THREAD_NUM 12 -#define PER_THREAD_SHM_BUFFER_BYTES (4 * 1024 * 1024) -#define MIN_THREAD_PROCESS_SIZE (8 * 1024) +#define PER_THREAD_SHM_BUFFER_BYTES (2 * 1024 * 1024) +static_assert(PER_THREAD_SHM_BUFFER_BYTES % 2 == 0); +#define PER_THREAD_SHM_BUFFER_OFFSET (PER_THREAD_SHM_BUFFER_BYTES >> 1) +#define MIN_THREAD_PROCESS_SIZE (256) #define MAX_P2P_SEND_TENSOR_NUM 8 template @@ -32,10 +33,10 @@ struct KernelVecType { using scalar_vec_t = vec_op::FP16Vec16; }; -enum class ThreadSHMStat : char { THREAD_READY = 0, SHM_DATA_READY, DONE }; - struct ThreadSHMContext { - volatile ThreadSHMStat thread_stats[MAX_SHM_RANK_NUM]; + volatile char _curr_thread_stamp; + volatile char _ready_thread_stamp; + char 
_padding1[6]; int thread_id; int thread_num; int rank; @@ -44,14 +45,19 @@ struct ThreadSHMContext { int swizzled_ranks[MAX_SHM_RANK_NUM]; void* thread_shm_ptrs[MAX_SHM_RANK_NUM]; ThreadSHMContext* shm_contexts[MAX_SHM_RANK_NUM]; + size_t _thread_buffer_mask; + char _padding2[56]; ThreadSHMContext(const int thread_id, const int thread_num, const int rank, const int group_size, void* thread_shm_ptr) - : thread_id(thread_id), + : _curr_thread_stamp(1), + _ready_thread_stamp(0), + thread_id(thread_id), thread_num(thread_num), rank(rank), group_size(group_size), - _spinning_count(0) { + _spinning_count(0), + _thread_buffer_mask(0) { static_assert(sizeof(ThreadSHMContext) % 64 == 0); TORCH_CHECK(group_size <= MAX_SHM_RANK_NUM); TORCH_CHECK((size_t)this % 64 == 0); @@ -60,7 +66,6 @@ struct ThreadSHMContext { shm_contexts[i] = nullptr; thread_shm_ptrs[i] = nullptr; swizzled_ranks[i] = (i + rank) % group_size; - thread_stats[i] = ThreadSHMStat::DONE; } set_context(rank, this, thread_shm_ptr); } @@ -77,59 +82,66 @@ struct ThreadSHMContext { template T* get_thread_shm_ptr(int rank) { - return reinterpret_cast(thread_shm_ptrs[rank]); + return reinterpret_cast( + reinterpret_cast(thread_shm_ptrs[rank]) + + (PER_THREAD_SHM_BUFFER_OFFSET & _thread_buffer_mask)); + } + + void next_buffer() { _thread_buffer_mask ^= 0xFFFFFFFFFFFFFFFF; } + + char get_curr_stamp() const { return _curr_thread_stamp; } + + char get_ready_stamp() const { return _ready_thread_stamp; } + + void next_stamp() { + _mm_mfence(); + _curr_thread_stamp += 1; + } + + void commit_ready_stamp() { + _mm_mfence(); + _ready_thread_stamp = _curr_thread_stamp; } int get_swizzled_rank(int idx) { return swizzled_ranks[idx]; } - void wait_for_all(ThreadSHMStat prev_stat) { - for (int idx = 0; idx < group_size; ++idx) { + template + void wait_for_all(Cond&& cond) { + for (int idx = 1; idx < group_size; ++idx) { int rank = get_swizzled_rank(idx); - while (thread_stats[rank] == prev_stat) { - ++_spinning_count; - _mm_pause(); - } + wait_for_one(rank, std::forward(cond)); } - vec_op::mem_barrier(); } - void wait_for_one(int rank, ThreadSHMStat prev_stat) { - while (thread_stats[rank] == prev_stat) { + template + void wait_for_one(int rank, Cond&& cond) { + ThreadSHMContext* rank_ctx = shm_contexts[rank]; + for (;;) { + char local_curr_stamp = get_curr_stamp(); + char local_ready_stamp = get_ready_stamp(); + char rank_curr_stamp = rank_ctx->get_curr_stamp(); + char rank_ready_stamp = rank_ctx->get_ready_stamp(); + if (cond(local_curr_stamp, local_ready_stamp, rank_curr_stamp, + rank_ready_stamp)) { + break; + } ++_spinning_count; _mm_pause(); } - vec_op::mem_barrier(); } - void set_thread_stat(ThreadSHMStat stat) { - for (int idx = 0; idx < group_size; ++idx) { - int rank = get_swizzled_rank(idx); - shm_contexts[rank]->thread_stats[this->rank] = stat; - } + static bool check_no_buffer_conflict(char local_curr_stamp, + char local_ready_stamp, + char rank_curr_stamp, + char rank_ready_stamp) { + char temp = rank_curr_stamp + 2; + return local_curr_stamp != temp; } - void set_thread_stat(int target_rank, ThreadSHMStat stat) { - for (int idx = 0; idx < group_size; ++idx) { - int rank = get_swizzled_rank(idx); - shm_contexts[rank]->thread_stats[target_rank] = stat; - } - } - - // barrier for all ranks in the group, used for all2all ops - // DONE -> THREAD_READY -> SHM_DATA_READY -> DONE -> ... 
- void barrier(ThreadSHMStat next_stat) { - if (next_stat == ThreadSHMStat::THREAD_READY) { - set_thread_stat(ThreadSHMStat::THREAD_READY); - wait_for_all(ThreadSHMStat::DONE); - } else if (next_stat == ThreadSHMStat::SHM_DATA_READY) { - set_thread_stat(ThreadSHMStat::SHM_DATA_READY); - wait_for_all(ThreadSHMStat::THREAD_READY); - } else if (next_stat == ThreadSHMStat::DONE) { - set_thread_stat(ThreadSHMStat::DONE); - wait_for_all(ThreadSHMStat::SHM_DATA_READY); - } else { - TORCH_CHECK(false, "Invalid next_stat to barrier."); - } + static bool check_stamp_ready(char local_curr_stamp, char local_ready_stamp, + char rank_curr_stamp, char rank_ready_stamp) { + char temp = local_curr_stamp + 1; + return (local_curr_stamp == rank_ready_stamp) || (temp == rank_ready_stamp); } std::string to_string() const { @@ -164,7 +176,7 @@ class SHMManager { const int group_size) : _rank(rank), _group_size(group_size), - _thread_num(std::min(torch::get_num_threads(), MAX_THREAD_NUM)), + _thread_num(torch::get_num_threads()), _shm_names({""}), _shared_mem_ptrs({nullptr}), _shm_ctx(nullptr) { @@ -326,7 +338,8 @@ void shm_cc_loop(ThreadSHMContext* ctx, int64_t elem_num, F&& inner_func) { (total_units_num + thread_num - 1) / thread_num; int64_t per_unit_elem_num = MIN_THREAD_PROCESS_SIZE / sizeof(scalar_t); int64_t max_per_thread_iteration_elem_num = - PER_THREAD_SHM_BUFFER_BYTES / sizeof(scalar_t); + (PER_THREAD_SHM_BUFFER_BYTES >> 1) / + sizeof(scalar_t); // Note: double buffer int64_t per_thread_elem_num = per_unit_elem_num * per_thread_units_num; #pragma omp parallel for schedule(static, 1) @@ -336,10 +349,13 @@ void shm_cc_loop(ThreadSHMContext* ctx, int64_t elem_num, F&& inner_func) { int64_t curr_elem_num = std::min(max_per_thread_iteration_elem_num, end - offset); ThreadSHMContext* thread_ctx = ctx + i; + bool fast_mode = ((end - offset) <= max_per_thread_iteration_elem_num); while (curr_elem_num > 0) { - inner_func(thread_ctx, offset, curr_elem_num); + inner_func(thread_ctx, offset, curr_elem_num, fast_mode); + thread_ctx->next_stamp(); + thread_ctx->next_buffer(); offset += max_per_thread_iteration_elem_num; curr_elem_num = std::min(max_per_thread_iteration_elem_num, end - offset); } @@ -397,7 +413,7 @@ void all_reduce_sum_impl(ThreadSHMContext* ctx, scalar_t* data, shm_cc_ops::shm_cc_loop( ctx, elem_num, [&](ThreadSHMContext* thread_ctx, int64_t data_offset, - int64_t data_elem_num) { + int64_t data_elem_num, bool fast_mode) { int rank = thread_ctx->rank; scalar_t* thread_shm_ptr = thread_ctx->get_thread_shm_ptr(rank); @@ -410,16 +426,17 @@ void all_reduce_sum_impl(ThreadSHMContext* ctx, scalar_t* data, thread_ctx->get_swizzled_rank(idx + 1)); }); - thread_ctx->barrier(ThreadSHMStat::THREAD_READY); + if (!fast_mode) { + thread_ctx->wait_for_all(ThreadSHMContext::check_no_buffer_conflict); + } shm_cc_ops::memcpy_to_shm(thread_shm_ptr, thread_data_ptr, thread_data_elem_num); - - thread_ctx->barrier(ThreadSHMStat::SHM_DATA_READY); - + thread_ctx->commit_ready_stamp(); int64_t aligned_data_elem_num = (data_elem_num / vec_elem_num) * vec_elem_num; int64_t i = 0; + thread_ctx->wait_for_all(ThreadSHMContext::check_stamp_ready); #pragma GCC unroll 4 for (; i < aligned_data_elem_num; i += vec_elem_num) { vec_t local_data(thread_data_ptr + i); // load from cache @@ -447,8 +464,6 @@ void all_reduce_sum_impl(ThreadSHMContext* ctx, scalar_t* data, reduced_data.save(thread_data_ptr + i, data_elem_num - aligned_data_elem_num); } - - thread_ctx->barrier(ThreadSHMStat::DONE); }); return; @@ -488,18 +503,18 @@ void 
shm_gather_impl(ThreadSHMContext* ctx, scalar_t* data, size_t elem_num, shm_cc_ops::shm_cc_loop( ctx, elem_num, [&](ThreadSHMContext* thread_ctx, int64_t data_offset, - int64_t data_elem_num) { + int64_t data_elem_num, bool fast_mode) { int rank = thread_ctx->rank; scalar_t* thread_shm_ptr = thread_ctx->get_thread_shm_ptr(rank); - thread_ctx->barrier(ThreadSHMStat::THREAD_READY); - - shm_cc_ops::memcpy_to_shm(thread_shm_ptr, data + data_offset, - data_elem_num * sizeof(scalar_t)); - - thread_ctx->barrier(ThreadSHMStat::SHM_DATA_READY); + if (!fast_mode) { + thread_ctx->wait_for_all(ThreadSHMContext::check_no_buffer_conflict); + } + shm_cc_ops::memcpy(thread_shm_ptr, data + data_offset, + data_elem_num * sizeof(scalar_t)); + thread_ctx->commit_ready_stamp(); if (rank == dst) { shm_cc_ops::memcpy(outputs[rank] + data_offset, data + data_offset, data_elem_num * sizeof(scalar_t)); @@ -508,12 +523,12 @@ void shm_gather_impl(ThreadSHMContext* ctx, scalar_t* data, size_t elem_num, scalar_t* src_ptr = thread_ctx->get_thread_shm_ptr(src_rank); // shm scalar_t* dst_ptr = outputs[src_rank] + data_offset; - shm_cc_ops::memcpy_from_shm(dst_ptr, src_ptr, - data_elem_num * sizeof(scalar_t)); + thread_ctx->wait_for_one(src_rank, + ThreadSHMContext::check_stamp_ready); + shm_cc_ops::memcpy(dst_ptr, src_ptr, + data_elem_num * sizeof(scalar_t)); } } - - thread_ctx->barrier(ThreadSHMStat::DONE); }); return; @@ -599,7 +614,7 @@ struct TensorListMeta { int8_t _padding[40]; }; -void shm_send_tensor_list_impl(ThreadSHMContext* ctx, +void shm_send_tensor_list_impl(ThreadSHMContext* ctx, int64_t dst, const std::vector& tensor_list) { CPU_KERNEL_GUARD_IN(shm_send_tensor_list_impl) std::vector tensor_list_with_metadata; @@ -620,12 +635,11 @@ void shm_send_tensor_list_impl(ThreadSHMContext* ctx, shm_cc_ops::shm_cc_loop( ctx, metadata->total_bytes, [&](ThreadSHMContext* thread_ctx, int64_t data_offset, - int64_t data_elem_num) { + int64_t data_elem_num, bool fast_mode) { int rank = thread_ctx->rank; - // Wait until the receiver set the stat to DONE - thread_ctx->wait_for_one(rank, ThreadSHMStat::SHM_DATA_READY); - int64_t curr_shm_offset = 0; + thread_ctx->wait_for_one(dst, + ThreadSHMContext::check_no_buffer_conflict); while (curr_shm_offset < data_elem_num) { MemPiece frag = metadata->get_data(data_offset + curr_shm_offset); frag.size = std::min(frag.size, data_elem_num - curr_shm_offset); @@ -634,8 +648,7 @@ void shm_send_tensor_list_impl(ThreadSHMContext* ctx, frag.ptr, frag.size); curr_shm_offset += frag.size; } - - thread_ctx->set_thread_stat(rank, ThreadSHMStat::SHM_DATA_READY); + thread_ctx->commit_ready_stamp(); }); } @@ -646,8 +659,7 @@ std::vector shm_recv_tensor_list_impl(ThreadSHMContext* ctx, torch::Tensor metadata_tensor = torch::empty({sizeof(TensorListMeta)}, options); - // Wait until the sender set the stat of the thread 0 to SHM_DATA_READY - ctx->wait_for_one(src, ThreadSHMStat::DONE); + ctx->wait_for_one(src, ThreadSHMContext::check_stamp_ready); shm_cc_ops::memcpy(metadata_tensor.data_ptr(), ctx->get_thread_shm_ptr(src), sizeof(TensorListMeta)); @@ -664,9 +676,8 @@ std::vector shm_recv_tensor_list_impl(ThreadSHMContext* ctx, shm_cc_ops::shm_cc_loop( ctx, metadata.total_bytes, [&](ThreadSHMContext* thread_ctx, int64_t data_offset, - int64_t data_elem_num) { - // Wait until the sender set the stat to SHM_DATA_READY - thread_ctx->wait_for_one(src, ThreadSHMStat::DONE); + int64_t data_elem_num, bool fast_mode) { + ctx->wait_for_one(src, ThreadSHMContext::check_stamp_ready); int64_t curr_shm_offset = 0; 
while (curr_shm_offset < data_elem_num) { MemPiece frag = metadata.get_data(data_offset + curr_shm_offset); @@ -677,8 +688,6 @@ std::vector shm_recv_tensor_list_impl(ThreadSHMContext* ctx, frag.size); curr_shm_offset += frag.size; } - - thread_ctx->set_thread_stat(src, ThreadSHMStat::DONE); }); std::vector tensor_list; @@ -756,7 +765,8 @@ void shm_send_tensor_list(int64_t handle, int64_t dst) { CPU_KERNEL_GUARD_IN(shm_send_tensor_list) shm_send_tensor_list_impl( - SHMManager::get_singleton_instance(handle)->get_shm_ctx(), tensor_list); + SHMManager::get_singleton_instance(handle)->get_shm_ctx(), dst, + tensor_list); CPU_KERNEL_GUARD_OUT(shm_send_tensor_list) } @@ -778,4 +788,4 @@ std::string join_shm_manager(int64_t handle, const std::string& name) { TORCH_CHECK(shm_manager); shm_manager->join(name); return shm_manager->get_shm_ctx()->to_string(); -} \ No newline at end of file +} diff --git a/csrc/cpu/torch_bindings.cpp b/csrc/cpu/torch_bindings.cpp index 60304d229a8f..ebfc81f85836 100644 --- a/csrc/cpu/torch_bindings.cpp +++ b/csrc/cpu/torch_bindings.cpp @@ -50,6 +50,27 @@ void shm_send_tensor_list(int64_t handle, std::vector shm_recv_tensor_list(int64_t handle, int64_t src); +at::Tensor weight_packed_linear(at::Tensor& mat1, at::Tensor& mat2, + const std::optional& bias, + bool is_vnni); + +at::Tensor convert_weight_packed(at::Tensor& weight); + +at::Tensor fused_experts_cpu( + at::Tensor& hidden_states, at::Tensor& w1, at::Tensor& w2, + at::Tensor& topk_weights, at::Tensor& topk_ids, bool inplace, + bool use_int8_w8a8, bool use_fp8_w8a16, + const std::optional& w1_scale, + const std::optional& w2_scale, + const std::optional> block_size, + const std::optional& a1_scale, + const std::optional& a2_scale, bool is_vnni); + +at::Tensor int8_scaled_mm_with_quant(at::Tensor& mat1, at::Tensor& mat2, + at::Tensor& scales2, + const std::optional& bias, + at::ScalarType out_dtype, bool is_vnni); + TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // vLLM custom ops @@ -214,6 +235,28 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.def("shm_recv_tensor_list(int handle, int src) -> Tensor[](a)", &shm_recv_tensor_list); #endif + + // sgl-kernels +#if defined(__AVX512BF16__) && defined(__AVX512F__) && defined(__AVX512VNNI__) + ops.def( + "weight_packed_linear(Tensor(a0!) mat1, Tensor(a1!) mat2, Tensor(a2!)? " + "bias, bool is_vnni) -> Tensor"); + ops.impl("weight_packed_linear", torch::kCPU, &weight_packed_linear); + ops.def("convert_weight_packed(Tensor! weight) -> Tensor"); + ops.impl("convert_weight_packed", torch::kCPU, &convert_weight_packed); + ops.def( + "fused_experts_cpu(Tensor! hidden_states, Tensor w1, Tensor w2, Tensor " + "topk_weights, Tensor topk_ids, bool inplace, bool use_int8_w8a8, bool " + "use_fp8_w8a16, Tensor? w1_scale, Tensor? w2_scale, SymInt[]? " + "block_size, Tensor? a1_scale, Tensor? a2_scale, bool is_vnni) -> " + "Tensor"); + ops.impl("fused_experts_cpu", torch::kCPU, &fused_experts_cpu); + ops.def( + "int8_scaled_mm_with_quant(Tensor mat1, Tensor mat2, Tensor scales2, " + "Tensor? 
bias, ScalarType out_dtype, bool is_vnni) -> Tensor"); + ops.impl("int8_scaled_mm_with_quant", torch::kCPU, + &int8_scaled_mm_with_quant); +#endif } TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { diff --git a/docs/getting_started/installation/cpu.md b/docs/getting_started/installation/cpu.md index 370b854def0f..5f2d0dbe27d3 100644 --- a/docs/getting_started/installation/cpu.md +++ b/docs/getting_started/installation/cpu.md @@ -118,6 +118,7 @@ vLLM CPU backend supports the following vLLM features: - `VLLM_CPU_OMP_THREADS_BIND`: specify the CPU cores dedicated to the OpenMP threads. For example, `VLLM_CPU_OMP_THREADS_BIND=0-31` means there will be 32 OpenMP threads bound on 0-31 CPU cores. `VLLM_CPU_OMP_THREADS_BIND=0-31|32-63` means there will be 2 tensor parallel processes, 32 OpenMP threads of rank0 are bound on 0-31 CPU cores, and the OpenMP threads of rank1 are bound on 32-63 CPU cores. By setting to `auto`, the OpenMP threads of each rank are bound to the CPU cores in each NUMA node. By setting to `all`, the OpenMP threads of each rank uses all CPU cores available on the system. Default value is `auto`. - `VLLM_CPU_NUM_OF_RESERVED_CPU`: specify the number of CPU cores which are not dedicated to the OpenMP threads for each rank. The variable only takes effect when VLLM_CPU_OMP_THREADS_BIND is set to `auto`. Default value is `0`. - `VLLM_CPU_MOE_PREPACK`: whether to use prepack for MoE layer. This will be passed to `ipex.llm.modules.GatedMLPMOE`. Default is `1` (True). On unsupported CPUs, you might need to set this to `0` (False). +- `VLLM_CPU_SGL_KERNEL` (Experimental): whether to use small-batch optimized kernels for linear layer and MoE layer, especially for low-latency requirements like online serving. The kernels require AMX instruction set, BFloat16 weight type and weight shapes divisible by 32. Default is `0` (False). 
## Performance tips diff --git a/tests/models/language/generation/test_common.py b/tests/models/language/generation/test_common.py index f656f90c4bd3..7d7a62eec118 100644 --- a/tests/models/language/generation/test_common.py +++ b/tests/models/language/generation/test_common.py @@ -78,7 +78,7 @@ AITER_MODEL_LIST = [ ), pytest.param( "Qwen/Qwen2.5-0.5B-Instruct", # qwen2 - marks=[pytest.mark.core_model], + marks=[pytest.mark.core_model, pytest.mark.cpu_model], ), pytest.param( "Qwen/Qwen3-8B", # qwen (text-only) @@ -87,6 +87,7 @@ AITER_MODEL_LIST = [ pytest.param("bigcode/starcoder2-3b"), # starcoder2 pytest.param( "TitanML/tiny-mixtral", # mixtral + marks=[pytest.mark.core_model, pytest.mark.cpu_model], ) ]) @pytest.mark.parametrize("max_tokens", [32]) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 51900de1cc09..36a0395ccdc9 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1850,3 +1850,52 @@ def cutlass_mla_decode(out: torch.Tensor, q_nope: torch.Tensor, torch.ops._C.cutlass_mla_decode(out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, scale) return out + + +if hasattr(torch.ops._C, "weight_packed_linear"): + + @register_fake("_C::weight_packed_linear") + def weight_packed_linear_fake(mat1: torch.Tensor, mat2: torch.Tensor, + bias: Optional[torch.Tensor], + is_vnni: bool) -> torch.Tensor: + return torch.empty((mat1.size(0), mat2.size(0)), + dtype=mat1.dtype, + device=mat2.device) + + +if hasattr(torch.ops._C, "fused_experts_cpu"): + + @register_fake("_C::fused_experts_cpu") + def fused_experts_cpu_fake( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + inplace: bool, + use_int8_w8a8: bool, + use_fp8_w8a16: bool, + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], + block_size: Optional[list[int]], + a1_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + is_vnni: bool, + ) -> torch.Tensor: + return torch.empty_like(hidden_states) + + +if hasattr(torch.ops._C, "int8_scaled_mm_with_quant"): + + @register_fake("_C::int8_scaled_mm_with_quant") + def int8_scaled_mm_with_quant_fake( + mat1: torch.Tensor, + mat2: torch.Tensor, + scales2: torch.Tensor, + bias: Optional[torch.Tensor], + out_dtype: torch.dtype, + is_vnni: bool, + ) -> torch.Tensor: + M = mat1.size(0) + N = mat2.size(0) + return torch.empty((M, N), dtype=out_dtype) diff --git a/vllm/envs.py b/vllm/envs.py index a3f19c7ee5c7..c73dbb0a8446 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -46,6 +46,7 @@ if TYPE_CHECKING: VLLM_CPU_OMP_THREADS_BIND: str = "" VLLM_CPU_NUM_OF_RESERVED_CPU: int = 0 VLLM_CPU_MOE_PREPACK: bool = True + VLLM_CPU_SGL_KERNEL: bool = False VLLM_XLA_CACHE_PATH: str = os.path.join(VLLM_CACHE_ROOT, "xla_cache") VLLM_XLA_CHECK_RECOMPILATION: bool = False VLLM_FUSED_MOE_CHUNK_SIZE: int = 64 * 1024 @@ -447,6 +448,10 @@ environment_variables: dict[str, Callable[[], Any]] = { "VLLM_CPU_MOE_PREPACK": lambda: bool(int(os.getenv("VLLM_CPU_MOE_PREPACK", "1"))), + # (CPU backend only) whether to use SGL kernels, optimized for small batch. + "VLLM_CPU_SGL_KERNEL": + lambda: bool(int(os.getenv("VLLM_CPU_SGL_KERNEL", "0"))), + # If the env var is set, then all workers will execute as separate # processes from the engine, and we use the same mechanism to trigger # execution on all workers. 
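Before the next file, note how the flag is consumed: the weight-processing hooks added to `linear.py` and `layer.py` later in this patch only pack weights for the SGL path when AMX is available, the weights are bfloat16, and both weight dimensions are divisible by 32. A condensed sketch of that eligibility check (illustrative only; the helper name is not part of the patch):

```python
import torch


def cpu_sgl_linear_eligible(weight: torch.Tensor, use_sgl_kernel: bool) -> bool:
    """Mirror the gating used before torch.ops._C.convert_weight_packed is called."""
    out_features, in_features = weight.size()
    return (use_sgl_kernel
            and torch._C._cpu._is_amx_tile_supported()
            and weight.dtype == torch.bfloat16
            and out_features % 32 == 0
            and in_features % 32 == 0)
```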
diff --git a/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py b/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py new file mode 100644 index 000000000000..68ce6bcccb5d --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py @@ -0,0 +1,214 @@ +# SPDX-License-Identifier: Apache-2.0 +from typing import Callable, Optional + +import torch + +from vllm import envs + + +class IPEXFusedMOE: + + def __init__(self, layer: torch.nn.Module) -> None: + import intel_extension_for_pytorch as ipex + layer.ipex_fusion = ipex.llm.modules.GatedMLPMOE( + layer.w13_weight, + layer.w2_weight, + use_prepack=envs.VLLM_CPU_MOE_PREPACK, + ) + + def __call__( + self, + layer: torch.nn.Module, + x: torch.Tensor, + use_grouped_topk: bool, + top_k: int, + router_logits: torch.Tensor, + renormalize: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + ) -> torch.Tensor: + assert activation == "silu", f"{activation} is not supported." + assert not apply_router_weight_on_input + return layer.ipex_fusion( + x, + use_grouped_topk, + top_k, + router_logits, + renormalize, + topk_group, + num_expert_group, + custom_routing_function, + scoring_func, + e_score_correction_bias, + ) + + +class SGLFusedMOE: + + def __init__(self, layer: torch.nn.Module) -> None: + pass + + @staticmethod + def _grouped_topk( + hidden_states: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + num_expert_group: int = 0, + topk_group: int = 0, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None + ) -> tuple[torch.Tensor, torch.Tensor]: + assert hidden_states.shape[0] == gating_output.shape[0], ( + "Number of tokens mismatch") + + gating_output = gating_output.float() + if scoring_func == "softmax": + scores = torch.softmax(gating_output, dim=-1) + elif scoring_func == "sigmoid": + scores = gating_output.sigmoid() + else: + raise ValueError(f"Unsupported scoring function: {scoring_func}") + + num_token = scores.shape[0] + if e_score_correction_bias is not None: + # Store original scores before applying correction bias. 
We use + # biased scores for expert selection but original scores for + # routing weights + original_scores = scores + scores = scores + e_score_correction_bias.unsqueeze(0) + group_scores = (scores.view(num_token, num_expert_group, + -1).topk(2, dim=-1)[0].sum(dim=-1)) + else: + group_scores = scores.view(num_token, num_expert_group, + -1).max(dim=-1).values # [n, n_group] + group_idx = torch.topk(group_scores, + k=topk_group, + dim=-1, + sorted=False)[1] # [n, top_k_group] + group_mask = torch.zeros_like(group_scores) # [n, n_group] + group_mask.scatter_(1, group_idx, 1) # [n, n_group] + score_mask = group_mask.unsqueeze(-1).expand( + num_token, num_expert_group, + scores.shape[-1] // num_expert_group).reshape(num_token, + -1) # [n, e] + tmp_scores = scores.masked_fill(~score_mask.bool(), + float("-inf")) # [n, e] + + if e_score_correction_bias is not None: + topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1] + # Use original unbiased scores for the routing weights + topk_weights = original_scores.gather(1, topk_ids) + else: + topk_weights, topk_ids = torch.topk(tmp_scores, + k=topk, + dim=-1, + sorted=False) + + if renormalize: + topk_weights = topk_weights / topk_weights.sum(dim=-1, + keepdim=True) + + return topk_weights, topk_ids.to(torch.int32) + + @staticmethod + def _select_experts( + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + use_grouped_topk: bool, + renormalize: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + # DeekSeekv2 uses grouped_top_k + if use_grouped_topk: + assert topk_group is not None + assert num_expert_group is not None + topk_weights, topk_ids = SGLFusedMOE._grouped_topk( + hidden_states=hidden_states, + gating_output=router_logits, + topk=top_k, + renormalize=renormalize, + num_expert_group=num_expert_group, + topk_group=topk_group, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias) + elif custom_routing_function is None: + assert scoring_func == "softmax" + topk_weights = torch.nn.functional.softmax(router_logits, + dim=1, + dtype=torch.float32) + topk_weights, topk_ids = torch.topk(topk_weights, top_k, dim=-1) + if renormalize: + topk_weights /= topk_weights.sum(dim=-1, keepdim=True) + topk_ids = topk_ids.to(torch.int32) + else: + topk_weights, topk_ids = custom_routing_function( + hidden_states=hidden_states, + gating_output=router_logits, + topk=top_k, + renormalize=renormalize) + + return topk_weights, topk_ids + + def __call__( + self, + layer: torch.nn.Module, + x: torch.Tensor, + use_grouped_topk: bool, + top_k: int, + router_logits: torch.Tensor, + renormalize: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + ) -> torch.Tensor: + assert activation == "silu", f"{activation} is not supported." 
+ assert not apply_router_weight_on_input + topk_weights, topk_ids = SGLFusedMOE._select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias, + ) + + torch.ops._C.fused_experts_cpu( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights, + topk_ids, + True, + False, + False, + None, + None, + None, + None, + None, + True, + ) + return x diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index e6f555d315d8..d6ead084af99 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -550,12 +550,23 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): if current_platform.is_cpu(): if current_platform.get_cpu_architecture() == CpuArchEnum.X86: - import intel_extension_for_pytorch as ipex - layer.ipex_fusion = ipex.llm.modules.GatedMLPMOE( - layer.w13_weight, - layer.w2_weight, - use_prepack=envs.VLLM_CPU_MOE_PREPACK, - ) + from vllm.model_executor.layers.fused_moe import cpu_fused_moe + dtype = layer.w13_weight.dtype + if (envs.VLLM_CPU_SGL_KERNEL + and torch._C._cpu._is_amx_tile_supported() + and dtype == torch.bfloat16): + packed_w13_weight = torch.ops._C.convert_weight_packed( + layer.w13_weight) + assert packed_w13_weight.size() == layer.w13_weight.size() + layer.w13_weight.copy_(packed_w13_weight) + del packed_w13_weight + packed_w2_weight = torch.ops._C.convert_weight_packed( + layer.w2_weight) + assert packed_w2_weight.size() == layer.w2_weight.size() + layer.w2_weight.copy_(packed_w2_weight) + layer.cpu_fused_moe = cpu_fused_moe.SGLFusedMOE(layer) + else: + layer.cpu_fused_moe = cpu_fused_moe.IPEXFusedMOE(layer) else: raise NotImplementedError("CPU MOE only supports x86 arch.") @@ -673,13 +684,12 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", e_score_correction_bias: Optional[torch.Tensor] = None, - activation: str = "silu", apply_router_weight_on_input: bool = False, + activation: str = "silu", **kwargs, ): - assert activation == "silu", f"{activation} is not supported." 
- assert apply_router_weight_on_input is False - return layer.ipex_fusion( + return layer.cpu_fused_moe( + layer, x, use_grouped_topk, top_k, @@ -687,9 +697,13 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): renormalize, topk_group, num_expert_group, + global_num_experts, + expert_map, custom_routing_function, scoring_func, e_score_correction_bias, + apply_router_weight_on_input, + activation, ) def forward_hpu( @@ -764,7 +778,12 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): expert_map=expert_map, renormalize=renormalize) - forward_native = forward_tpu if current_platform.is_tpu() else forward_cuda + if current_platform.is_tpu(): + forward_native = forward_tpu + elif current_platform.is_cpu(): + forward_native = forward_cpu + else: + forward_native = forward_cuda def determine_expert_map( diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 588aa8deb183..a05ae0edbd77 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -9,6 +9,7 @@ import torch import torch.nn as nn from torch.nn.parameter import Parameter, UninitializedParameter +from vllm import envs from vllm.distributed import (divide, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, split_tensor_along_last_dim, @@ -27,6 +28,7 @@ from vllm.model_executor.parameter import (BasevLLMParameter, RowvLLMParameter) # yapf: enable from vllm.model_executor.utils import set_weight_attrs +from vllm.platforms import current_platform logger = init_logger(__name__) @@ -195,12 +197,33 @@ class UnquantizedLinearMethod(LinearMethodBase): layer.register_parameter("weight", weight) set_weight_attrs(weight, extra_weight_attrs) + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + if current_platform.is_cpu() and envs.VLLM_CPU_SGL_KERNEL: + N, K = layer.weight.size() + dtype = layer.weight.dtype + if (torch._C._cpu._is_amx_tile_supported() + and dtype == torch.bfloat16 and N % 32 == 0 + and K % 32 == 0): + packed_weight = torch.ops._C.convert_weight_packed( + layer.weight) + assert packed_weight.size() == layer.weight.size() + layer.weight.copy_(packed_weight) + if layer.bias is not None: + layer.bias = Parameter(layer.bias.to(torch.float32), + requires_grad=False) + layer.use_cpu_sgl = True + else: + logger.warning( + "CPU SGL kernels require Intel AMX support," + " bfloat16 weight, IC and OC are divisible by 32.") + layer.use_cpu_sgl = False + def apply(self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: - return dispatch_unquantized_gemm()(x, layer.weight, bias) + return dispatch_unquantized_gemm()(layer, x, layer.weight, bias) class LinearBase(torch.nn.Module): diff --git a/vllm/model_executor/layers/utils.py b/vllm/model_executor/layers/utils.py index 41b5253dca04..ad4ba9c0b827 100644 --- a/vllm/model_executor/layers/utils.py +++ b/vllm/model_executor/layers/utils.py @@ -63,7 +63,15 @@ def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor, return logits -def rocm_unquantized_gemm(x: torch.Tensor, +def default_unquantized_gemm(layer: torch.nn.Module, + x: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None): + return torch.nn.functional.linear(x, weight, bias) + + +def rocm_unquantized_gemm(layer: torch.nn.Module, + x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None): from vllm.platforms.rocm import on_gfx9 @@ -89,7 +97,20 @@ def rocm_unquantized_gemm(x: torch.Tensor, 
return torch.nn.functional.linear(x, weight, bias) +def cpu_unquantized_gemm(layer: torch.nn.Module, + x: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None): + if getattr(layer, "use_cpu_sgl", False): + return torch.ops._C.weight_packed_linear(x, weight, bias, True) + else: + return torch.nn.functional.linear(x, weight, bias) + + def dispatch_unquantized_gemm() -> Callable[..., torch.Tensor]: if current_platform.is_rocm(): return rocm_unquantized_gemm - return torch.nn.functional.linear + elif current_platform.is_cpu(): + return cpu_unquantized_gemm + else: + return default_unquantized_gemm diff --git a/vllm/model_executor/layers/vocab_parallel_embedding.py b/vllm/model_executor/layers/vocab_parallel_embedding.py index 9ff3a7a7327d..f35f969781bd 100644 --- a/vllm/model_executor/layers/vocab_parallel_embedding.py +++ b/vllm/model_executor/layers/vocab_parallel_embedding.py @@ -43,7 +43,7 @@ class UnquantizedEmbeddingMethod(QuantizeMethodBase): layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: - return dispatch_unquantized_gemm()(x, layer.weight, bias) + return dispatch_unquantized_gemm()(layer, x, layer.weight, bias) def embedding(self, layer: torch.nn.Module, input_: torch.Tensor) -> torch.Tensor: diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py index 106bce162003..dccd60f4463a 100644 --- a/vllm/platforms/cpu.py +++ b/vllm/platforms/cpu.py @@ -194,6 +194,8 @@ class CpuPlatform(Platform): "epilogue_fusion": True, }) + if compilation_config.use_inductor: + compilation_config.custom_ops = ["none"] if vllm_config.lora_config is not None: compilation_config.level = CompilationLevel.NO_COMPILATION
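
Note: the grouped top-k routing implemented by SGLFusedMOE._grouped_topk can be exercised on its own with plain PyTorch. The snippet below is a simplified, self-contained sketch of the same scheme (softmax scoring, no correction bias), with hypothetical shapes: 4 tokens, 8 experts split into 4 groups, the top 2 groups kept per token, and 2 experts selected within them.

import torch

def grouped_topk_demo(scores: torch.Tensor, num_expert_group: int,
                      topk_group: int, topk: int):
    # scores: [num_tokens, num_experts], already normalized (e.g. softmax)
    num_token, num_experts = scores.shape
    # Score each group by its best expert and keep the top `topk_group` groups.
    group_scores = scores.view(num_token, num_expert_group, -1).max(dim=-1).values
    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[1]
    group_mask = torch.zeros_like(group_scores)
    group_mask.scatter_(1, group_idx, 1)
    # Mask out every expert outside the selected groups, then take the top-k.
    score_mask = group_mask.unsqueeze(-1).expand(
        num_token, num_expert_group,
        num_experts // num_expert_group).reshape(num_token, -1)
    tmp_scores = scores.masked_fill(~score_mask.bool(), float("-inf"))
    topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
    topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids.to(torch.int32)

scores = torch.softmax(torch.randn(4, 8), dim=-1)  # 4 tokens, 8 experts
weights, ids = grouped_topk_demo(scores, num_expert_group=4, topk_group=2, topk=2)
print(ids)  # expert ids are confined to the 2 best groups of each token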
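Note: at the user level the new kernels are opt-in via VLLM_CPU_SGL_KERNEL. Below is a minimal usage sketch, assuming an AMX-capable x86 CPU and a CPU build of vLLM that compiled the sgl-kernels sources added by this patch; the model name is only a placeholder. When the prerequisites are not met (no AMX, non-bf16 weights, or shapes not divisible by 32), the layers fall back to the IPEX/default implementations as in the dispatch code above.

import os
os.environ["VLLM_CPU_SGL_KERNEL"] = "1"  # vllm.envs reads this lazily, so set it before use

from vllm import LLM, SamplingParams

# Placeholder model; any bfloat16 checkpoint can be used here.
llm = LLM(model="Qwen/Qwen2.5-0.5B-Instruct", dtype="bfloat16")
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)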
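Note: the two custom ops behind the linear path can also be exercised directly. The following is a rough sketch, not reference usage: it assumes the CPU extension with the sgl-kernels is importable, relies on convert_weight_packed preserving the tensor size (as asserted in process_weights_after_loading above), and simply mirrors the weight_packed_linear(x, weight, bias, True) call from cpu_unquantized_gemm; the meaning of the trailing boolean is taken from that call, not from documentation.

import torch
import vllm._custom_ops  # noqa: F401 -- loads the _C extension and registers the ops

N, K = 64, 128  # output/input features, both divisible by 32 as the SGL path requires
x = torch.randn(8, K, dtype=torch.bfloat16)
w = torch.randn(N, K, dtype=torch.bfloat16)
bias = torch.randn(N, dtype=torch.float32)  # bias is kept in fp32 on this path

packed_w = torch.ops._C.convert_weight_packed(w)
assert packed_w.size() == w.size()  # repacked layout, same shape

y = torch.ops._C.weight_packed_linear(x, packed_w, bias, True)
y_ref = torch.nn.functional.linear(x.float(), w.float(), bias)
print((y.float() - y_ref).abs().max())  # expect only bfloat16-level error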