[cpu][perf] Accelerate unquantized-linear for AArch64 through oneDNN/ACL and weight prepack (#25948)

Signed-off-by: Fadi Arafeh <fadi.arafeh@arm.com>
Co-authored-by: Li, Jiang <jiang1.li@intel.com>

commit 9705fba7b7 (parent 2f7dbc9b42)

@@ -213,6 +213,7 @@ if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR (ASIMD_FOUND AND NOT APPLE_SILICON
     endif()
     set(ONEDNN_AARCH64_USE_ACL "ON")
     set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wl,-rpath,$ENV{ACL_ROOT_DIR}/build/")
+    add_compile_definitions(VLLM_USE_ACL)
   endif()
 
   set(ONEDNN_LIBRARY_TYPE "STATIC")

@@ -226,7 +227,7 @@ if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR (ASIMD_FOUND AND NOT APPLE_SILICON
   set(ONEDNN_ENABLE_ITT_TASKS "OFF")
   set(ONEDNN_ENABLE_MAX_CPU_ISA "OFF")
   set(ONEDNN_ENABLE_CPU_ISA_HINTS "OFF")
-  set(ONEDNN_VERBOSE "OFF")
+  set(ONEDNN_VERBOSE "ON")
   set(CMAKE_POLICY_DEFAULT_CMP0077 NEW)
 
   FetchContent_MakeAvailable(oneDNN)

@@ -137,9 +137,8 @@ DNNLMatMulPrimitiveHandler::DNNLMatMulPrimitiveHandler(
 }
 
 void DNNLMatMulPrimitiveHandler::prepack_weight(
-    void* original_b_ptr, dnnl::memory::desc b_target_mem_desc) {
-  dnnl::memory::desc original_b_md({b_k_size_, b_n_size_}, b_type_,
-                                   {b_k_stride_, b_n_stride_});
+    void* original_b_ptr, dnnl::memory::desc original_b_md,
+    dnnl::memory::desc b_target_mem_desc) {
   dnnl::memory original_weight(original_b_md, default_engine(), original_b_ptr);
   dnnl::memory packed_weight(b_target_mem_desc, default_engine());
   {

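The prepack now takes the source layout (original_b_md) explicitly instead of rebuilding it from member strides. For orientation, here is a minimal, self-contained sketch of the same prepack pattern; it is not vLLM code and assumes the oneDNN 3.x C++ API with made-up shapes: describe the weights as format_tag::any, ask the matmul primitive descriptor which layout it actually wants, then reorder the plain row-major weights into that layout once.

    // Sketch only (not vLLM code): prepack a [K, N] f32 weight into the layout a
    // oneDNN matmul primitive prefers, using dnnl::reorder (oneDNN 3.x API).
    #include <vector>
    #include <oneapi/dnnl/dnnl.hpp>

    int main() {
      using namespace dnnl;
      const memory::dim M = 128, K = 256, N = 512;  // hypothetical shapes
      engine eng(engine::kind::cpu, 0);
      stream strm(eng);

      // Plain row-major activations/output; weights use format_tag::any so the
      // primitive can pick its preferred (possibly blocked) layout.
      memory::desc a_md({M, K}, memory::data_type::f32, memory::format_tag::ab);
      memory::desc b_any({K, N}, memory::data_type::f32, memory::format_tag::any);
      memory::desc c_md({M, N}, memory::data_type::f32, memory::format_tag::ab);
      auto pd = matmul::primitive_desc(eng, a_md, b_any, c_md);

      // This plays the role of b_target_mem_desc in the handler.
      memory::desc b_target = pd.weights_desc();

      // Original weights in plain [K, N] row-major layout (original_b_md).
      std::vector<float> b_data(K * N, 1.0f);
      memory::desc b_plain({K, N}, memory::data_type::f32, memory::format_tag::ab);
      memory b_orig(b_plain, eng, b_data.data());
      memory b_packed(b_target, eng);

      // One-time reorder: this is the weight prepack.
      reorder(b_orig, b_packed).execute(strm, b_orig, b_packed);
      strm.wait();
      return 0;
    }
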
@@ -250,7 +249,9 @@ W8A8MatMulPrimitiveHandler::W8A8MatMulPrimitiveHandler(const Args& args)
   if (a_qs_ == QuantizationStrategy::PER_TOKEN) {
     assert(!use_azp_);
   };
-  prepack_weight(args.b_ptr,
+  dnnl::memory::desc original_b_md({b_k_size_, b_n_size_}, b_type_,
+                                   {b_k_stride_, b_n_stride_});
+  prepack_weight(args.b_ptr, original_b_md,
                  create_primitive_desc(
                      MSizeCacheKey{.a_m_size = DNNL_RUNTIME_DIM_VAL,
                                    .use_bias = false,

@@ -412,12 +413,25 @@ MatMulPrimitiveHandler::MatMulPrimitiveHandler(const Args& args)
   assert(ab_type_ == dnnl::memory::data_type::f32 ||
          ab_type_ == dnnl::memory::data_type::bf16 ||
          ab_type_ == dnnl::memory::data_type::f16);
-  prepack_weight(args.b_ptr,
+
+  dnnl::memory::desc original_b_md({b_k_size_, b_n_size_}, b_type_,
+                                   {b_k_stride_, b_n_stride_});
+
+  prepack_weight(args.b_ptr, original_b_md,
                  create_primitive_desc(
-                     MSizeCacheKey{.a_m_size = DNNL_RUNTIME_DIM_VAL,
-                                   .a_m_stride = DNNL_RUNTIME_DIM_VAL,
-                                   .use_bias = false,
-                                   .bias_type = dnnl::memory::data_type::undef},
+                     MSizeCacheKey{
+#ifdef VLLM_USE_ACL
+                         // Arm Compute Library (ACL) backend for oneDNN does
+                         // not support runtime dimensions, so we set M to a
+                         // default value
+                         .a_m_size = 128,
+                         .a_m_stride = b_k_size_,
+#else
+                         .a_m_size = DNNL_RUNTIME_DIM_VAL,
+                         .a_m_stride = DNNL_RUNTIME_DIM_VAL,
+#endif
+                         .use_bias = false,
+                         .bias_type = dnnl::memory::data_type::undef},
                      true)
                      .weights_desc());
   init_runtime_memory_cache(args);

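The #ifdef above exists because the ACL backend cannot build a primitive with runtime dimensions, so the prepack-time descriptor uses a fixed placeholder M (128 here) and a concrete stride. A hedged sketch of the two variants, again assuming oneDNN 3.x and illustrative shapes; make_pd is a hypothetical helper, not vLLM code:

    // Sketch (oneDNN 3.x): a matmul primitive_desc with a runtime M versus a
    // fixed placeholder M. Backends that cannot specialize on runtime dims
    // (e.g. ACL) need concrete dimensions when the primitive is created.
    #include <oneapi/dnnl/dnnl.hpp>

    dnnl::matmul::primitive_desc make_pd(const dnnl::engine& eng,
                                         dnnl::memory::dim K, dnnl::memory::dim N,
                                         bool runtime_m) {
      using namespace dnnl;
      memory::dim M = runtime_m ? DNNL_RUNTIME_DIM_VAL : 128;
      memory::dim a_stride = runtime_m ? DNNL_RUNTIME_DIM_VAL : K;
      memory::desc a_md({M, K}, memory::data_type::f32, {a_stride, 1});
      memory::desc b_md({K, N}, memory::data_type::f32, memory::format_tag::any);
      memory::desc c_md({M, N}, memory::data_type::f32, memory::format_tag::ab);
      return matmul::primitive_desc(eng, a_md, b_md, c_md);
    }

    int main() {
      dnnl::engine eng(dnnl::engine::kind::cpu, 0);
      // Fixed-M variant: what the ACL path builds at prepack time.
      auto pd_fixed = make_pd(eng, 256, 512, /*runtime_m=*/false);
      // Runtime-M variant: the default path. A backend without runtime-dim
      // support would reject this, which is why the ACL path avoids it.
      auto pd_runtime = make_pd(eng, 256, 512, /*runtime_m=*/true);
      (void)pd_fixed;
      (void)pd_runtime;
      return 0;
    }
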
@@ -443,13 +457,31 @@ void MatMulPrimitiveHandler::execute(ExecArgs& args) {
   c_storage->set_data_handle((void*)args.c_ptr);
   c_mem_desc->dims[0] = args.a_m_size;
 
+#ifndef VLLM_USE_ACL
+  // We do not support bias in the ACL backend of oneDNN; we handle bias by:
+  // 1. copying it into the result tensor
+  // 2. attaching a fused-sum post-op to the matmul primitive
   if (args.use_bias) {
     auto&& [bias_storage, bias_mem_desc] = get_runtime_memory_ptr(2);
     bias_storage->set_data_handle((void*)args.bias_ptr);
   }
+#endif
   dnnl::matmul matmul = get_matmul_cache(args);
 
+  // With ACL backend of oneDNN, the required memory format might change when
+  // the source tensor dims change. This rarely happens in practice, so it isn't
+  // a performance hit, but we need to support it because the API allows for it.
+#ifdef VLLM_USE_ACL
+  auto new_expected_wei_desc =
+      dnnl::matmul::primitive_desc(
+          const_cast<dnnl_primitive_desc_t>(matmul.get_primitive_desc()))
+          .weights_desc();
+  if (new_expected_wei_desc != b_target_mem_desc_) {
+    prepack_weight(memory_cache_[DNNL_ARG_WEIGHTS].get_data_handle(),
+                   b_target_mem_desc_, new_expected_wei_desc);
+  }
+#endif
 
   auto&& [scratchpad_storage, scratchpad_mem_desc] = get_runtime_memory_ptr(3);
   scratchpad_storage->set_data_handle(
       DNNLScratchPadManager::get_dnnl_scratchpad_manager()->get_data<void>());

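Because the ACL backend may prefer a different weight layout once the source shape changes, the execute path above re-queries the expected weights descriptor from the cached primitive and re-packs only on mismatch. A simplified sketch of that check, with a hypothetical helper standing in for prepack_weight plus the handler's weight cache (oneDNN 3.x assumed):

    // Sketch (oneDNN 3.x): when a matmul primitive is re-created for a new source
    // shape, the backend may report a different preferred weight layout. Compare
    // the new weights_desc() against the layout the weights were packed into and
    // reorder again only if it changed. Hypothetical helper, not vLLM code.
    #include <oneapi/dnnl/dnnl.hpp>

    void repack_if_needed(const dnnl::engine& eng, dnnl::stream& strm,
                          const dnnl::matmul::primitive_desc& new_pd,
                          dnnl::memory& packed_weights,       // current packed copy
                          dnnl::memory::desc& packed_desc) {  // its current layout
      dnnl::memory::desc wanted = new_pd.weights_desc();
      if (wanted == packed_desc) return;  // layout unchanged: nothing to do

      // Reorder from the old packed layout into the newly required one.
      dnnl::memory repacked(wanted, eng);
      dnnl::reorder(packed_weights, repacked)
          .execute(strm, packed_weights, repacked);
      strm.wait();

      packed_weights = repacked;
      packed_desc = wanted;
    }
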
@@ -484,7 +516,13 @@ dnnl::matmul::primitive_desc MatMulPrimitiveHandler::create_primitive_desc(
   } else {
     a_md = dnnl::memory::desc({key.a_m_size, b_k_size_}, b_type_,
                               {key.a_m_stride, 1});
+#ifdef VLLM_USE_ACL
+    // ACL's backend of oneDNN always expects the weight format to be "any"
+    b_md = dnnl::memory::desc({b_k_size_, b_n_size_}, b_type_,
+                              dnnl::memory::format_tag::any);
+#else
     b_md = b_target_mem_desc_;
+#endif
   }
   dnnl::memory::desc c_md({key.a_m_size, b_n_size_}, c_type_,
                           dnnl::memory::format_tag::ab);

@ -494,8 +532,18 @@ dnnl::matmul::primitive_desc MatMulPrimitiveHandler::create_primitive_desc(
|
|||||||
|
|
||||||
if (key.use_bias) {
|
if (key.use_bias) {
|
||||||
dnnl::memory::desc bias_md({1, b_n_size_}, key.bias_type, {b_n_size_, 1});
|
dnnl::memory::desc bias_md({1, b_n_size_}, key.bias_type, {b_n_size_, 1});
|
||||||
|
// Since ACL's matmuls don't support passing a bias_md, we apply the bias
|
||||||
|
// through a fused-sum post-op
|
||||||
|
#ifdef VLLM_USE_ACL
|
||||||
|
dnnl::post_ops post_ops;
|
||||||
|
post_ops.append_sum();
|
||||||
|
attr.set_post_ops(post_ops);
|
||||||
|
return dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, c_md,
|
||||||
|
attr);
|
||||||
|
#else
|
||||||
return dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, bias_md,
|
return dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, bias_md,
|
||||||
c_md, attr);
|
c_md, attr);
|
||||||
|
#endif
|
||||||
} else {
|
} else {
|
||||||
return dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, c_md,
|
return dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, c_md,
|
||||||
attr);
|
attr);
|
||||||
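The fused-sum branch can be exercised in isolation: build the matmul without a bias descriptor, attach append_sum(), and seed the destination with the bias so the post-op accumulates A×B on top of it. A minimal sketch, assuming oneDNN 3.x, a plain (not prepacked) weight layout for brevity, and made-up sizes:

    // Sketch (oneDNN 3.x): emulate "C = A*B + bias" without a bias descriptor by
    // seeding C with the bias and fusing a sum post-op into the matmul.
    #include <vector>
    #include <oneapi/dnnl/dnnl.hpp>

    int main() {
      using namespace dnnl;
      const memory::dim M = 4, K = 8, N = 16;  // made-up sizes
      engine eng(engine::kind::cpu, 0);
      stream strm(eng);

      memory::desc a_md({M, K}, memory::data_type::f32, memory::format_tag::ab);
      memory::desc b_md({K, N}, memory::data_type::f32, memory::format_tag::ab);
      memory::desc c_md({M, N}, memory::data_type::f32, memory::format_tag::ab);

      // dst += matmul result, i.e. C = bias + A*B once C is seeded with the bias.
      primitive_attr attr;
      post_ops ops;
      ops.append_sum();
      attr.set_post_ops(ops);
      auto pd = matmul::primitive_desc(eng, a_md, b_md, c_md, attr);

      // "c = bias" (a per-row broadcast is omitted in this toy example).
      std::vector<float> a(M * K, 1.0f), b(K * N, 1.0f), c(M * N, 0.5f);

      memory a_mem(a_md, eng, a.data()), b_mem(b_md, eng, b.data()),
          c_mem(c_md, eng, c.data());
      matmul(pd).execute(strm, {{DNNL_ARG_SRC, a_mem},
                                {DNNL_ARG_WEIGHTS, b_mem},
                                {DNNL_ARG_DST, c_mem}});
      strm.wait();  // every element of c is now 8 * 1.0 + 0.5 = 8.5
      return 0;
    }
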
@@ -511,13 +559,23 @@ void MatMulPrimitiveHandler::init_runtime_memory_cache(const Args& args) {
                    default_engine(), nullptr);
   set_runtime_memory_ptr(1, memory_cache_[DNNL_ARG_DST].get());
 
+  // ACL matmuls don't support bias_md, so we don't need these
+#ifndef VLLM_USE_ACL
   memory_cache_[DNNL_ARG_BIAS] =
       dnnl::memory({{b_n_size_}, dnnl::memory::data_type::f32, {1}},
                    default_engine(), nullptr);
   set_runtime_memory_ptr(2, memory_cache_[DNNL_ARG_BIAS].get());
+#endif
   memory_cache_[DNNL_ARG_SCRATCHPAD] =
       dnnl::memory({{b_n_size_}, dnnl::memory::data_type::f32, {1}},
                    default_engine(), nullptr);
   set_runtime_memory_ptr(3, memory_cache_[DNNL_ARG_SCRATCHPAD].get());
 }
 
+bool is_onednn_acl_supported() {
+#ifdef VLLM_USE_ACL
+  return true;
+#else
+  return false;
+#endif
+}

@@ -101,7 +101,7 @@ class DNNLMatMulPrimitiveHandler {
  protected:
   DNNLMatMulPrimitiveHandler(const Args& args, dnnl::memory::data_type b_type);
 
-  void prepack_weight(void* original_b_ptr,
+  void prepack_weight(void* original_b_ptr, dnnl::memory::desc original_b_md,
                       dnnl::memory::desc b_target_mem_desc);
 
   void set_runtime_memory_ptr(size_t index, dnnl_memory* memory_ptr);

@@ -527,21 +527,42 @@ void onednn_mm(torch::Tensor& c, // [M, OC], row-major
   MatMulPrimitiveHandler* ptr =
       reinterpret_cast<MatMulPrimitiveHandler*>(handler);
 
+  // ACL matmuls expect contiguous source tensors
+#ifdef VLLM_USE_ACL
+  torch::Tensor a_contig = a.contiguous();
+#endif
+
   MatMulPrimitiveHandler::ExecArgs exec_args;
 
+#ifdef VLLM_USE_ACL
+  exec_args.a_m_size = a_contig.size(0);
+  exec_args.a_m_stride = a_contig.stride(0);
+#else
   exec_args.a_m_size = a.size(0);
   exec_args.a_m_stride = a.stride(0);
+#endif
   VLLM_DISPATCH_FLOATING_TYPES(a.scalar_type(), "onednn_mm", [&] {
     if (bias.has_value()) {
       exec_args.use_bias = true;
       exec_args.bias_type = get_dnnl_type<scalar_t>();
+#ifdef VLLM_USE_ACL
+      // ACL matmuls in oneDNN do not support a bias.
+      // We handle a matmul with bias by doing: c = bias; c += matmul(a, b)
+      c.copy_(bias.value());
+#else
       exec_args.bias_ptr = bias->data_ptr<scalar_t>();
+#endif
     } else {
       exec_args.use_bias = false;
       exec_args.bias_type = get_dnnl_type<void>();
       exec_args.bias_ptr = nullptr;
     }
+#ifdef VLLM_USE_ACL
+    exec_args.a_ptr = a_contig.data_ptr<scalar_t>();
+#else
     exec_args.a_ptr = a.data_ptr<scalar_t>();
+#endif
+
     exec_args.c_ptr = c.data_ptr<scalar_t>();
 
     ptr->execute(exec_args);

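At the binding level the same trick is plain tensor arithmetic: copy the (broadcast) bias into the output, then accumulate the matmul into it. A small libtorch sketch, with illustrative shapes, checking that this matches a @ b + bias:

    // Sketch (libtorch): the "c = bias; c += a @ b" pattern used on the ACL path
    // is numerically the same as c = a @ b + bias.
    #include <torch/torch.h>

    int main() {
      torch::manual_seed(0);
      auto a = torch::randn({4, 8});
      auto b = torch::randn({8, 16});
      auto bias = torch::randn({16});

      // Reference: ordinary matmul plus broadcast bias.
      auto ref = a.matmul(b) + bias;

      // ACL-style: seed the output with the broadcast bias, then accumulate the
      // matmul into it, which is what the fused sum post-op does in oneDNN.
      auto c = torch::empty({4, 16});
      c.copy_(bias.expand({4, 16}));
      c.addmm_(a, b);  // c = c + a @ b

      TORCH_CHECK(torch::allclose(ref, c, 1e-5, 1e-5));
      return 0;
    }
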
@@ -27,6 +27,8 @@ int64_t create_onednn_mm_handler(const torch::Tensor& b,
 void onednn_mm(torch::Tensor& c, const torch::Tensor& a,
                const std::optional<torch::Tensor>& bias, int64_t handler);
 
+bool is_onednn_acl_supported();
+
 void mla_decode_kvcache(torch::Tensor& out, torch::Tensor& query,
                         torch::Tensor& kv_cache, double scale,
                         torch::Tensor& block_tables, torch::Tensor& seq_lens);

@@ -181,6 +183,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "int handler) -> ()");
   ops.impl("onednn_mm", torch::kCPU, &onednn_mm);
 
+  // Check if oneDNN was built with ACL backend
+  ops.def("is_onednn_acl_supported() -> bool", &is_onednn_acl_supported);
+
   // Create oneDNN W8A8 handler
   ops.def(
       "create_onednn_scaled_mm_handler(Tensor b, Tensor b_scales, ScalarType "

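The new query op follows the usual pattern for surfacing a compile-time flag to Python. A minimal sketch with a hypothetical extension namespace and macro (the real registration lives in vLLM's existing ops namespace and keys off VLLM_USE_ACL):

    // Sketch: expose a compile-time feature flag to Python as a custom op that
    // returns a bool. Namespace "my_ext" and macro MY_FEATURE_FLAG are made up.
    #include <torch/library.h>

    namespace {
    bool is_feature_enabled() {
    #ifdef MY_FEATURE_FLAG
      return true;
    #else
      return false;
    #endif
    }
    }  // namespace

    TORCH_LIBRARY(my_ext, m) {
      // Schema and implementation registered together; callable from Python as
      // torch.ops.my_ext.is_feature_enabled().
      m.def("is_feature_enabled() -> bool", &is_feature_enabled);
    }
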
setup.py (+5)

@@ -205,6 +205,11 @@ class cmake_build_ext(build_ext):
         # Make sure we use the nvcc from CUDA_HOME
         if _is_cuda():
             cmake_args += [f'-DCMAKE_CUDA_COMPILER={CUDA_HOME}/bin/nvcc']
+
+        other_cmake_args = os.environ.get("CMAKE_ARGS")
+        if other_cmake_args:
+            cmake_args += other_cmake_args.split()
+
         subprocess.check_call(
             ['cmake', ext.cmake_lists_dir, *build_tool, *cmake_args],
             cwd=self.build_temp)

@@ -1926,6 +1926,10 @@ else:
     _supports_onednn = False
 
 
+def is_onednn_acl_supported():
+    return torch.ops._C.is_onednn_acl_supported()
+
+
 def create_onednn_mm(
     weight: torch.Tensor,  # [K, N]
     primitive_cache_size: int = 128,

@@ -165,8 +165,9 @@ def dispatch_cpu_unquantized_gemm(
         if remove_weight:
             layer.weight = torch.nn.Parameter(torch.empty(0),
                                               requires_grad=False)
-    elif (ops._supports_onednn
-          and current_platform.get_cpu_architecture() == CpuArchEnum.X86):
+    elif ops._supports_onednn and (current_platform.get_cpu_architecture()
+                                   == CpuArchEnum.X86
+                                   or ops.is_onednn_acl_supported()):
         origin_weight = layer.weight
         if remove_weight:
             layer.weight = torch.nn.Parameter(torch.empty(0),
