From 9705fba7b727a3b9c275b012258608531e2223d1 Mon Sep 17 00:00:00 2001
From: Fadi Arafeh <115173828+fadara01@users.noreply.github.com>
Date: Sat, 4 Oct 2025 05:16:38 +0100
Subject: [PATCH] [cpu][perf] Accelerate unquantized-linear for AArch64 through
 oneDNN/ACL and weight prepack (#25948)

Signed-off-by: Fadi Arafeh
Co-authored-by: Li, Jiang
---
 cmake/cpu_extension.cmake           |  3 +-
 csrc/cpu/dnnl_helper.cpp            | 80 +++++++++++++++++++++++++----
 csrc/cpu/dnnl_helper.h              |  2 +-
 csrc/cpu/dnnl_kernels.cpp           | 23 ++++++++-
 csrc/cpu/torch_bindings.cpp         |  5 ++
 setup.py                            |  5 ++
 vllm/_custom_ops.py                 |  4 ++
 vllm/model_executor/layers/utils.py |  5 +-
 8 files changed, 111 insertions(+), 16 deletions(-)

diff --git a/cmake/cpu_extension.cmake b/cmake/cpu_extension.cmake
index e6d0012c1a4b..c962564c8da0 100644
--- a/cmake/cpu_extension.cmake
+++ b/cmake/cpu_extension.cmake
@@ -213,6 +213,7 @@ if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR (ASIMD_FOUND AND NOT APPLE_SILICON
     endif()
     set(ONEDNN_AARCH64_USE_ACL "ON")
     set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wl,-rpath,$ENV{ACL_ROOT_DIR}/build/")
+    add_compile_definitions(VLLM_USE_ACL)
   endif()
 
   set(ONEDNN_LIBRARY_TYPE "STATIC")
@@ -226,7 +227,7 @@ if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR (ASIMD_FOUND AND NOT APPLE_SILICON
   set(ONEDNN_ENABLE_ITT_TASKS "OFF")
   set(ONEDNN_ENABLE_MAX_CPU_ISA "OFF")
   set(ONEDNN_ENABLE_CPU_ISA_HINTS "OFF")
-  set(ONEDNN_VERBOSE "OFF")
+  set(ONEDNN_VERBOSE "ON")
   set(CMAKE_POLICY_DEFAULT_CMP0077 NEW)
 
   FetchContent_MakeAvailable(oneDNN)
diff --git a/csrc/cpu/dnnl_helper.cpp b/csrc/cpu/dnnl_helper.cpp
index 6def0e061fa9..0f0cc34602b3 100644
--- a/csrc/cpu/dnnl_helper.cpp
+++ b/csrc/cpu/dnnl_helper.cpp
@@ -137,9 +137,8 @@ DNNLMatMulPrimitiveHandler::DNNLMatMulPrimitiveHandler(
 }
 
 void DNNLMatMulPrimitiveHandler::prepack_weight(
-    void* original_b_ptr, dnnl::memory::desc b_target_mem_desc) {
-  dnnl::memory::desc original_b_md({b_k_size_, b_n_size_}, b_type_,
-                                   {b_k_stride_, b_n_stride_});
+    void* original_b_ptr, dnnl::memory::desc original_b_md,
+    dnnl::memory::desc b_target_mem_desc) {
   dnnl::memory original_weight(original_b_md, default_engine(), original_b_ptr);
   dnnl::memory packed_weight(b_target_mem_desc, default_engine());
   {
@@ -250,7 +249,9 @@ W8A8MatMulPrimitiveHandler::W8A8MatMulPrimitiveHandler(const Args& args)
   if (a_qs_ == QuantizationStrategy::PER_TOKEN) {
     assert(!use_azp_);
   };
-  prepack_weight(args.b_ptr,
+  dnnl::memory::desc original_b_md({b_k_size_, b_n_size_}, b_type_,
+                                   {b_k_stride_, b_n_stride_});
+  prepack_weight(args.b_ptr, original_b_md,
                  create_primitive_desc(
                      MSizeCacheKey{.a_m_size = DNNL_RUNTIME_DIM_VAL,
                                    .use_bias = false,
@@ -412,12 +413,25 @@ MatMulPrimitiveHandler::MatMulPrimitiveHandler(const Args& args)
   assert(ab_type_ == dnnl::memory::data_type::f32 ||
          ab_type_ == dnnl::memory::data_type::bf16 ||
          ab_type_ == dnnl::memory::data_type::f16);
-  prepack_weight(args.b_ptr,
+
+  dnnl::memory::desc original_b_md({b_k_size_, b_n_size_}, b_type_,
+                                   {b_k_stride_, b_n_stride_});
+
+  prepack_weight(args.b_ptr, original_b_md,
                  create_primitive_desc(
-                     MSizeCacheKey{.a_m_size = DNNL_RUNTIME_DIM_VAL,
-                                   .a_m_stride = DNNL_RUNTIME_DIM_VAL,
-                                   .use_bias = false,
-                                   .bias_type = dnnl::memory::data_type::undef},
+                     MSizeCacheKey{
+#ifdef VLLM_USE_ACL
+                         // The Arm Compute Library (ACL) backend for oneDNN
+                         // does not support runtime dimensions, so we set M
+                         // to a default value
+                         .a_m_size = 128,
+                         .a_m_stride = b_k_size_,
+#else
+                         .a_m_size = DNNL_RUNTIME_DIM_VAL,
+                         .a_m_stride = DNNL_RUNTIME_DIM_VAL,
+#endif
+                         .use_bias = false,
+                         .bias_type = dnnl::memory::data_type::undef},
                      true)
                      .weights_desc());
   init_runtime_memory_cache(args);
@@ -443,13 +457,31 @@ void MatMulPrimitiveHandler::execute(ExecArgs& args) {
   c_storage->set_data_handle((void*)args.c_ptr);
   c_mem_desc->dims[0] = args.a_m_size;
 
+#ifndef VLLM_USE_ACL
+  // Bias is not supported by the ACL backend of oneDNN; with ACL we handle
+  // bias by: 1. copying it into the result tensor, and
+  // 2. attaching a fused-sum post-op to the matmul primitive
   if (args.use_bias) {
     auto&& [bias_storage, bias_mem_desc] = get_runtime_memory_ptr(2);
     bias_storage->set_data_handle((void*)args.bias_ptr);
   }
-
+#endif
   dnnl::matmul matmul = get_matmul_cache(args);
 
+// With the ACL backend of oneDNN, the required weight format might change
+// when the source tensor dims change. This rarely happens in practice, so it
+// isn't a performance hit, but we need to support it as the API allows it.
+#ifdef VLLM_USE_ACL
+  auto new_expected_wei_desc =
+      dnnl::matmul::primitive_desc(
+          const_cast<dnnl_primitive_desc_t>(matmul.get_primitive_desc()))
+          .weights_desc();
+  if (new_expected_wei_desc != b_target_mem_desc_) {
+    prepack_weight(memory_cache_[DNNL_ARG_WEIGHTS].get_data_handle(),
+                   b_target_mem_desc_, new_expected_wei_desc);
+  }
+#endif
+
   auto&& [scratchpad_storage, scratchpad_mem_desc] = get_runtime_memory_ptr(3);
   scratchpad_storage->set_data_handle(
       DNNLScratchPadManager::get_dnnl_scratchpad_manager()->get_data());
@@ -484,7 +516,13 @@ dnnl::matmul::primitive_desc MatMulPrimitiveHandler::create_primitive_desc(
   } else {
     a_md = dnnl::memory::desc({key.a_m_size, b_k_size_}, b_type_,
                               {key.a_m_stride, 1});
+#ifdef VLLM_USE_ACL
+    // The ACL backend of oneDNN always expects the weight format to be "any"
+    b_md = dnnl::memory::desc({b_k_size_, b_n_size_}, b_type_,
+                              dnnl::memory::format_tag::any);
+#else
     b_md = b_target_mem_desc_;
+#endif
   }
   dnnl::memory::desc c_md({key.a_m_size, b_n_size_}, c_type_,
                           dnnl::memory::format_tag::ab);
@@ -494,8 +532,18 @@ dnnl::matmul::primitive_desc MatMulPrimitiveHandler::create_primitive_desc(
 
   if (key.use_bias) {
     dnnl::memory::desc bias_md({1, b_n_size_}, key.bias_type, {b_n_size_, 1});
+// Since ACL matmuls don't support passing a bias_md, we apply the bias through
+// a fused-sum post-op
+#ifdef VLLM_USE_ACL
+    dnnl::post_ops post_ops;
+    post_ops.append_sum();
+    attr.set_post_ops(post_ops);
+    return dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, c_md,
+                                        attr);
+#else
     return dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, bias_md,
                                         c_md, attr);
+#endif
   } else {
     return dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, c_md,
                                         attr);
@@ -511,13 +559,23 @@ void MatMulPrimitiveHandler::init_runtime_memory_cache(const Args& args) {
                                         default_engine(), nullptr);
   set_runtime_memory_ptr(1, memory_cache_[DNNL_ARG_DST].get());
 
+// ACL matmuls don't support a bias_md, so we don't need these
+#ifndef VLLM_USE_ACL
   memory_cache_[DNNL_ARG_BIAS] =
       dnnl::memory({{b_n_size_}, dnnl::memory::data_type::f32, {1}},
                    default_engine(), nullptr);
   set_runtime_memory_ptr(2, memory_cache_[DNNL_ARG_BIAS].get());
-
+#endif
   memory_cache_[DNNL_ARG_SCRATCHPAD] =
       dnnl::memory({{b_n_size_}, dnnl::memory::data_type::f32, {1}},
                    default_engine(), nullptr);
   set_runtime_memory_ptr(3, memory_cache_[DNNL_ARG_SCRATCHPAD].get());
 }
+
+bool is_onednn_acl_supported() {
+#ifdef VLLM_USE_ACL
+  return true;
+#else
+  return false;
+#endif
+}
diff --git a/csrc/cpu/dnnl_helper.h b/csrc/cpu/dnnl_helper.h
index ad6773d2b9fd..f0cb197d81a3 100644
--- a/csrc/cpu/dnnl_helper.h
+++ b/csrc/cpu/dnnl_helper.h
@@ -101,7 +101,7 @@ class DNNLMatMulPrimitiveHandler {
  protected:
   DNNLMatMulPrimitiveHandler(const Args& args, dnnl::memory::data_type b_type);
 
-  void prepack_weight(void* original_b_ptr,
+  void prepack_weight(void* original_b_ptr, dnnl::memory::desc original_b_md,
                       dnnl::memory::desc b_target_mem_desc);
 
   void set_runtime_memory_ptr(size_t index, dnnl_memory* memory_ptr);
diff --git a/csrc/cpu/dnnl_kernels.cpp b/csrc/cpu/dnnl_kernels.cpp
index 1c42a75bc2d6..6d062c71e767 100644
--- a/csrc/cpu/dnnl_kernels.cpp
+++ b/csrc/cpu/dnnl_kernels.cpp
@@ -527,21 +527,42 @@ void onednn_mm(torch::Tensor& c,  // [M, OC], row-major
   MatMulPrimitiveHandler* ptr =
       reinterpret_cast<MatMulPrimitiveHandler*>(handler);
 
+// ACL matmuls expect contiguous source tensors
+#ifdef VLLM_USE_ACL
+  torch::Tensor a_contig = a.contiguous();
+#endif
+
   MatMulPrimitiveHandler::ExecArgs exec_args;
+
+#ifdef VLLM_USE_ACL
+  exec_args.a_m_size = a_contig.size(0);
+  exec_args.a_m_stride = a_contig.stride(0);
+#else
   exec_args.a_m_size = a.size(0);
   exec_args.a_m_stride = a.stride(0);
-
+#endif
   VLLM_DISPATCH_FLOATING_TYPES(a.scalar_type(), "onednn_mm", [&] {
     if (bias.has_value()) {
       exec_args.use_bias = true;
       exec_args.bias_type = get_dnnl_type<scalar_t>();
+#ifdef VLLM_USE_ACL
+      // ACL matmuls in oneDNN do not support a bias.
+      // We handle a matmul with bias by doing: c = bias; c += matmul(a, b)
+      c.copy_(bias.value());
+#else
       exec_args.bias_ptr = bias->data_ptr();
+#endif
     } else {
       exec_args.use_bias = false;
       exec_args.bias_type = get_dnnl_type<scalar_t>();
       exec_args.bias_ptr = nullptr;
     }
+#ifdef VLLM_USE_ACL
+    exec_args.a_ptr = a_contig.data_ptr();
+#else
     exec_args.a_ptr = a.data_ptr();
+
+#endif
     exec_args.c_ptr = c.data_ptr();
 
     ptr->execute(exec_args);
diff --git a/csrc/cpu/torch_bindings.cpp b/csrc/cpu/torch_bindings.cpp
index d279c03e0b59..9df19d1ac392 100644
--- a/csrc/cpu/torch_bindings.cpp
+++ b/csrc/cpu/torch_bindings.cpp
@@ -27,6 +27,8 @@ int64_t create_onednn_mm_handler(const torch::Tensor& b,
 void onednn_mm(torch::Tensor& c, const torch::Tensor& a,
                const std::optional<torch::Tensor>& bias, int64_t handler);
 
+bool is_onednn_acl_supported();
+
 void mla_decode_kvcache(torch::Tensor& out, torch::Tensor& query,
                         torch::Tensor& kv_cache, double scale,
                         torch::Tensor& block_tables, torch::Tensor& seq_lens);
@@ -181,6 +183,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "int handler) -> ()");
   ops.impl("onednn_mm", torch::kCPU, &onednn_mm);
 
+  // Check if oneDNN was built with the ACL backend
+  ops.def("is_onednn_acl_supported() -> bool", &is_onednn_acl_supported);
+
   // Create oneDNN W8A8 handler
   ops.def(
       "create_onednn_scaled_mm_handler(Tensor b, Tensor b_scales, ScalarType "
diff --git a/setup.py b/setup.py
index 5491046991ca..fcd9570beae1 100644
--- a/setup.py
+++ b/setup.py
@@ -205,6 +205,11 @@ class cmake_build_ext(build_ext):
         # Make sure we use the nvcc from CUDA_HOME
         if _is_cuda():
             cmake_args += [f'-DCMAKE_CUDA_COMPILER={CUDA_HOME}/bin/nvcc']
+
+        other_cmake_args = os.environ.get("CMAKE_ARGS")
+        if other_cmake_args:
+            cmake_args += other_cmake_args.split()
+
         subprocess.check_call(
             ['cmake', ext.cmake_lists_dir, *build_tool, *cmake_args],
             cwd=self.build_temp)
diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index f07fa1e4e7be..84d96ee3a84d 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -1926,6 +1926,10 @@ else:
     _supports_onednn = False
 
 
+def is_onednn_acl_supported():
+    return torch.ops._C.is_onednn_acl_supported()
+
+
 def create_onednn_mm(
         weight: torch.Tensor,  # [K, N]
         primitive_cache_size: int = 128,
diff --git a/vllm/model_executor/layers/utils.py b/vllm/model_executor/layers/utils.py
index 96dd58c0e4d2..ac3a604a5a3b 100644
--- a/vllm/model_executor/layers/utils.py
+++ b/vllm/model_executor/layers/utils.py
@@ -165,8 +165,9 @@ def dispatch_cpu_unquantized_gemm(
         if remove_weight:
             layer.weight = torch.nn.Parameter(torch.empty(0),
                                               requires_grad=False)
-    elif (ops._supports_onednn
-          and current_platform.get_cpu_architecture() == CpuArchEnum.X86):
+    elif ops._supports_onednn and (current_platform.get_cpu_architecture()
+                                   == CpuArchEnum.X86
+                                   or ops.is_onednn_acl_supported()):
         origin_weight = layer.weight
         if remove_weight:
             layer.weight = torch.nn.Parameter(torch.empty(0),
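
For illustration, below is a minimal standalone sketch (not part of the patch) of the bias handling used on the ACL path: copy the bias into the destination tensor, then let the matmul accumulate into it through a fused-sum post-op, so that c = a * b + bias. It uses only the public oneDNN v3 C++ API (the same matmul::primitive_desc overload and post_ops::append_sum() the patch relies on) with plain f32 buffers; the file name, dimensions, and values are made up, and the code runs on any oneDNN CPU build, with or without the ACL backend, since the post-op mechanism itself is backend-independent.

// bias_via_fused_sum.cpp -- illustrative sketch only; names and dims are made up.
#include <vector>

#include "oneapi/dnnl/dnnl.hpp"

int main() {
  using namespace dnnl;
  const memory::dim M = 4, K = 8, N = 16;

  engine eng(engine::kind::cpu, 0);
  stream strm(eng);

  std::vector<float> a(M * K, 1.0f), b(K * N, 0.5f), bias(N, 2.0f);
  std::vector<float> c(M * N, 0.0f);

  // Row-major ("ab") descriptors for source, weights and destination.
  memory::desc a_md({M, K}, memory::data_type::f32, memory::format_tag::ab);
  memory::desc b_md({K, N}, memory::data_type::f32, memory::format_tag::ab);
  memory::desc c_md({M, N}, memory::data_type::f32, memory::format_tag::ab);

  memory a_mem(a_md, eng, a.data());
  memory b_mem(b_md, eng, b.data());
  memory c_mem(c_md, eng, c.data());

  // Step 1: copy the bias into the result tensor (broadcast over rows),
  // the equivalent of c.copy_(bias.value()) in the patch.
  for (memory::dim m = 0; m < M; ++m)
    for (memory::dim n = 0; n < N; ++n) c[m * N + n] = bias[n];

  // Step 2: attach a fused-sum post-op, so the primitive computes
  // c = a * b + c instead of overwriting c.
  post_ops po;
  po.append_sum(1.0f);
  primitive_attr attr;
  attr.set_post_ops(po);

  matmul::primitive_desc pd(eng, a_md, b_md, c_md, attr);
  matmul(pd).execute(strm, {{DNNL_ARG_SRC, a_mem},
                            {DNNL_ARG_WEIGHTS, b_mem},
                            {DNNL_ARG_DST, c_mem}});
  strm.wait();

  // Every element of c is now 8 * (1.0 * 0.5) + 2.0 = 6.0.
  return 0;
}

This mirrors the two steps the patch performs when a bias is present: onednn_mm() does c.copy_(bias.value()), and create_primitive_desc() builds the matmul with a primitive_attr carrying post_ops.append_sum().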