From 57b1ce94f7ac13624a32da90658b36b60a276a60 Mon Sep 17 00:00:00 2001
From: "Li, Jiang"
Date: Thu, 4 Sep 2025 14:28:45 +0800
Subject: [PATCH] [CPU] Refactor CPU unquantized linear (#24150)

Signed-off-by: jiang1.li
---
 csrc/cpu/dnnl_helper.cpp                      | 177 ++++++++++++++++++
 csrc/cpu/dnnl_helper.h                        |  74 ++++++++
 csrc/cpu/dnnl_kernels.cpp                     |  54 ++++++
 csrc/cpu/torch_bindings.cpp                   |  18 ++
 tests/kernels/test_onednn.py                  |  70 +++++++
 vllm/_custom_ops.py                           |  29 +++
 vllm/model_executor/layers/linear.py          |  25 +--
 vllm/model_executor/layers/utils.py           |  39 +++-
 .../layers/vocab_parallel_embedding.py        |   6 +
 9 files changed, 466 insertions(+), 26 deletions(-)

diff --git a/csrc/cpu/dnnl_helper.cpp b/csrc/cpu/dnnl_helper.cpp
index f3f00edb36068..6def0e061fa96 100644
--- a/csrc/cpu/dnnl_helper.cpp
+++ b/csrc/cpu/dnnl_helper.cpp
@@ -22,6 +22,23 @@ void release_dnnl_matmul_handler(int64_t handler) {
   delete ptr;
 }
 
+DNNLScratchPadManager::DNNLScratchPadManager() : size_(0), ptr_(nullptr) {
+  this->realloc(allocation_unit * 128);
+}
+
+void DNNLScratchPadManager::realloc(size_t new_size) {
+  new_size = round(new_size);
+  if (new_size > size_) {
+    ptr_ = std::aligned_alloc(64, new_size);
+    size_ = new_size;
+  }
+}
+
+DNNLScratchPadManager* DNNLScratchPadManager::get_dnnl_scratchpad_manager() {
+  static DNNLScratchPadManager manager;
+  return &manager;
+}
+
 template <typename KT, typename VT>
 class DNNLPrimitiveCache {
  public:
@@ -166,6 +183,23 @@ struct hash<W8A8MatMulPrimitiveHandler::MSizeCacheKey> {
            hash<int>()(static_cast<int>(val.bias_type));
   }
 };
+
+template <>
+struct hash<MatMulPrimitiveHandler::ClassMatmulCacheKey> {
+  size_t operator()(
+      const MatMulPrimitiveHandler::ClassMatmulCacheKey& val) const {
+    return hash<dnnl_dim_t>()(val.b_n_size) ^ hash<dnnl_dim_t>()(val.b_k_size);
+  }
+};
+
+template <>
+struct hash<MatMulPrimitiveHandler::MSizeCacheKey> {
+  size_t operator()(const MatMulPrimitiveHandler::MSizeCacheKey& val) const {
+    return hash<dnnl_dim_t>()(val.a_m_size) ^
+           hash<dnnl_dim_t>()(val.a_m_stride) ^ hash<bool>()(val.use_bias) ^
+           hash<int>()(static_cast<int>(val.bias_type));
+  }
+};
 }  // namespace std
 
 bool operator==(const W8A8MatMulPrimitiveHandler::ClassMatmulCacheKey& l,
@@ -181,6 +215,17 @@ bool operator==(const W8A8MatMulPrimitiveHandler::MSizeCacheKey& l,
          l.bias_type == r.bias_type;
 }
 
+bool operator==(const MatMulPrimitiveHandler::ClassMatmulCacheKey& l,
+                const MatMulPrimitiveHandler::ClassMatmulCacheKey& r) {
+  return l.b_n_size == r.b_n_size && l.b_k_size == r.b_k_size;
+}
+
+bool operator==(const MatMulPrimitiveHandler::MSizeCacheKey& l,
+                const MatMulPrimitiveHandler::MSizeCacheKey& r) {
+  return l.a_m_size == r.a_m_size && l.a_m_stride == r.a_m_stride &&
+         l.use_bias == r.use_bias && l.bias_type == r.bias_type;
+}
+
 static std::shared_ptr<W8A8MatMulPrimitiveHandler::MSizeCache>
 get_w8a8_class_primitive_cache(
     const W8A8MatMulPrimitiveHandler::ClassMatmulCacheKey& key,
@@ -239,6 +284,11 @@ void W8A8MatMulPrimitiveHandler::execute(ExecArgs& args) {
   }
 
   dnnl::matmul matmul = get_matmul_cache(args);
+
+  auto&& [scratchpad_storage, scratchpad_mem_desc] = get_runtime_memory_ptr(5);
+  scratchpad_storage->set_data_handle(
+      DNNLScratchPadManager::get_dnnl_scratchpad_manager()->get_data<void>());
+
   matmul.execute(default_stream(), memory_cache_);
   default_stream().wait();
 }
@@ -257,6 +307,8 @@ dnnl::matmul W8A8MatMulPrimitiveHandler::get_matmul_cache(
   return m_size_cache_->get_or_create(key, [&]() {
     dnnl::matmul::primitive_desc desc =
         this->create_primitive_desc(key, false);
+    auto manager = DNNLScratchPadManager::get_dnnl_scratchpad_manager();
+    manager->realloc(desc.scratchpad_desc().get_size());
     return dnnl::matmul(desc);
   });
 }
@@ -300,6 +352,11 @@ void W8A8MatMulPrimitiveHandler::init_runtime_memory_cache(const Args& args) {
       dnnl::memory({{b_n_size_}, dnnl::memory::data_type::f32, {1}},
                    default_engine(), nullptr);
   set_runtime_memory_ptr(4, memory_cache_[DNNL_ARG_BIAS].get());
+
+  memory_cache_[DNNL_ARG_SCRATCHPAD] =
+      dnnl::memory({{b_n_size_}, dnnl::memory::data_type::f32, {1}},
+                   default_engine(), nullptr);
+  set_runtime_memory_ptr(5, memory_cache_[DNNL_ARG_SCRATCHPAD].get());
 }
 
 dnnl::matmul::primitive_desc W8A8MatMulPrimitiveHandler::create_primitive_desc(
@@ -319,6 +376,9 @@ dnnl::matmul::primitive_desc W8A8MatMulPrimitiveHandler::create_primitive_desc(
       dnnl::memory::format_tag::ab);
 
   dnnl::primitive_attr attr;
+
+  attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
+
   // For PER_TOKEN, scales will be applied in outside epilogue
   if (a_qs_ == QuantizationStrategy::PER_TENSOR) {
     attr.set_scales_mask(DNNL_ARG_SRC, 0);
@@ -344,3 +404,120 @@ dnnl::matmul::primitive_desc W8A8MatMulPrimitiveHandler::create_primitive_desc(
       attr);
   }
 }
+
+MatMulPrimitiveHandler::MatMulPrimitiveHandler(const Args& args)
+    : DNNLMatMulPrimitiveHandler(
+          static_cast<DNNLMatMulPrimitiveHandler::Args>(args), args.ab_type),
+      m_size_cache_(nullptr) {
+  assert(ab_type_ == dnnl::memory::data_type::f32 ||
+         ab_type_ == dnnl::memory::data_type::bf16 ||
+         ab_type_ == dnnl::memory::data_type::f16);
+  prepack_weight(args.b_ptr,
+                 create_primitive_desc(
+                     MSizeCacheKey{.a_m_size = DNNL_RUNTIME_DIM_VAL,
+                                   .a_m_stride = DNNL_RUNTIME_DIM_VAL,
+                                   .use_bias = false,
+                                   .bias_type = dnnl::memory::data_type::undef},
+                     true)
+                     .weights_desc());
+  init_runtime_memory_cache(args);
+}
+
+static std::shared_ptr<MatMulPrimitiveHandler::MSizeCache>
+get_matmul_class_primitive_cache(
+    const MatMulPrimitiveHandler::ClassMatmulCacheKey& key,
+    int64_t cache_size) {
+  static MatMulPrimitiveHandler::ClassMatmulCache cache(128);
+  assert(cache_size > 0);
+  return cache.get_or_create(key, [&]() {
+    return std::make_shared<MatMulPrimitiveHandler::MSizeCache>(cache_size);
+  });
+}
+
+void MatMulPrimitiveHandler::execute(ExecArgs& args) {
+  auto&& [a_storage, a_mem_desc] = get_runtime_memory_ptr(0);
+  auto&& [c_storage, c_mem_desc] = get_runtime_memory_ptr(1);
+  a_storage->set_data_handle((void*)args.a_ptr);
+  a_mem_desc->dims[0] = args.a_m_size;
+  a_mem_desc->format_desc.blocking.strides[0] = args.a_m_stride;
+  c_storage->set_data_handle((void*)args.c_ptr);
+  c_mem_desc->dims[0] = args.a_m_size;
+
+  if (args.use_bias) {
+    auto&& [bias_storage, bias_mem_desc] = get_runtime_memory_ptr(2);
+    bias_storage->set_data_handle((void*)args.bias_ptr);
+  }
+
+  dnnl::matmul matmul = get_matmul_cache(args);
+
+  auto&& [scratchpad_storage, scratchpad_mem_desc] = get_runtime_memory_ptr(3);
+  scratchpad_storage->set_data_handle(
+      DNNLScratchPadManager::get_dnnl_scratchpad_manager()->get_data<void>());
+
+  matmul.execute(default_stream(), memory_cache_);
+  default_stream().wait();
+}
+
+dnnl::matmul MatMulPrimitiveHandler::get_matmul_cache(
+    const MSizeCacheKey& key) {
+  if (m_size_cache_.get() == nullptr) {
+    ClassMatmulCacheKey class_key = {.b_n_size = b_n_size_,
+                                     .b_k_size = b_k_size_};
+    m_size_cache_ =
+        get_matmul_class_primitive_cache(class_key, primitive_cache_size_);
+  }
+  return m_size_cache_->get_or_create(key, [&]() {
+    dnnl::matmul::primitive_desc desc =
+        this->create_primitive_desc(key, false);
+    auto manager = DNNLScratchPadManager::get_dnnl_scratchpad_manager();
+    manager->realloc(desc.scratchpad_desc().get_size());
+    return dnnl::matmul(desc);
+  });
+}
+
+dnnl::matmul::primitive_desc MatMulPrimitiveHandler::create_primitive_desc(
+    const MSizeCacheKey& key, bool first_time) {
+  dnnl::memory::desc a_md;
+  dnnl::memory::desc b_md;
+  if (first_time) {
+    a_md = dnnl::memory::desc({key.a_m_size, b_k_size_}, b_type_,
+                              dnnl::memory::format_tag::ab);
+    b_md = dnnl::memory::desc({b_k_size_, b_n_size_}, b_type_,
+                              dnnl::memory::format_tag::any);
+  } else {
+    a_md = dnnl::memory::desc({key.a_m_size, b_k_size_}, b_type_,
+                              {key.a_m_stride, 1});
+    b_md = b_target_mem_desc_;
+  }
+  dnnl::memory::desc c_md({key.a_m_size, b_n_size_}, c_type_,
+                          dnnl::memory::format_tag::ab);
+
+  dnnl::primitive_attr attr;
+  attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
+
+  if (key.use_bias) {
+    dnnl::memory::desc bias_md({1, b_n_size_}, key.bias_type, {b_n_size_, 1});
+    return dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, bias_md,
+                                        c_md, attr);
+  } else {
+    return dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, c_md,
+                                        attr);
+  }
+}
+
+void MatMulPrimitiveHandler::init_runtime_memory_cache(const Args& args) {
+  memory_cache_[DNNL_ARG_SRC] = dnnl::memory(
+      {{1, b_k_size_}, b_type_, {b_k_size_, 1}}, default_engine(), nullptr);
+  set_runtime_memory_ptr(0, memory_cache_[DNNL_ARG_SRC].get());
+  memory_cache_[DNNL_ARG_DST] =
+      dnnl::memory({{1, b_n_size_}, c_type_, dnnl::memory::format_tag::ab},
+                   default_engine(), nullptr);
+  set_runtime_memory_ptr(1, memory_cache_[DNNL_ARG_DST].get());
+
+  memory_cache_[DNNL_ARG_BIAS] =
+      dnnl::memory({{b_n_size_}, dnnl::memory::data_type::f32, {1}},
+                   default_engine(), nullptr);
+  set_runtime_memory_ptr(2, memory_cache_[DNNL_ARG_BIAS].get());
+
+  memory_cache_[DNNL_ARG_SCRATCHPAD] =
+      dnnl::memory({{b_n_size_}, dnnl::memory::data_type::f32, {1}},
+                   default_engine(), nullptr);
+  set_runtime_memory_ptr(3, memory_cache_[DNNL_ARG_SCRATCHPAD].get());
+}
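The handler above caches compiled primitives at two levels: a process-wide cache keyed on the weight shape (b_n_size, b_k_size) whose values are per-shape caches, each keyed on the runtime configuration (a_m_size, a_m_stride, use_bias, bias_type). A rough Python sketch of that keying scheme, for illustration only (the names and LRU policy here are assumptions, not part of the patch):

    from collections import OrderedDict

    class LRUCache(OrderedDict):
        """Rough stand-in for DNNLPrimitiveCache (illustration only)."""
        def __init__(self, capacity: int):
            super().__init__()
            self.capacity = capacity

        def get_or_create(self, key, create):
            if key in self:
                self.move_to_end(key)      # mark as most recently used
                return self[key]
            value = create()
            self[key] = value
            if len(self) > self.capacity:  # evict least recently used
                self.popitem(last=False)
            return value

    class_cache = LRUCache(128)
    # one M-size cache per weight shape (b_n_size, b_k_size) ...
    m_cache = class_cache.get_or_create((4096, 4096), lambda: LRUCache(32))
    # ... and one compiled primitive per (a_m_size, a_m_stride, use_bias, bias_type)
    prim = m_cache.get_or_create((16, 4096, False, None),
                                 lambda: "compiled matmul primitive")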
diff --git a/csrc/cpu/dnnl_helper.h b/csrc/cpu/dnnl_helper.h
index 54ceefced9e98..ad6773d2b9fd6 100644
--- a/csrc/cpu/dnnl_helper.h
+++ b/csrc/cpu/dnnl_helper.h
@@ -59,6 +59,30 @@ constexpr inline dnnl::memory::data_type get_dnnl_type() {
   return DNNLType<std::decay_t<T>>::type;
 }
 
+class DNNLScratchPadManager {
+ public:
+  static constexpr size_t allocation_unit = 4 * 1024 * 1024;  // 4MB
+
+  static DNNLScratchPadManager* get_dnnl_scratchpad_manager();
+
+  DNNLScratchPadManager();
+
+  template <typename T>
+  T* get_data() {
+    return reinterpret_cast<T*>(ptr_);
+  }
+
+  static size_t round(size_t size) {
+    return ((size + allocation_unit - 1) / allocation_unit) * allocation_unit;
+  }
+
+  void realloc(size_t new_size);
+
+ private:
+  size_t size_;
+  void* ptr_;
+};
+
 class DNNLMatMulPrimitiveHandler {
  public:
   virtual ~DNNLMatMulPrimitiveHandler() = default;
@@ -166,4 +190,54 @@ class W8A8MatMulPrimitiveHandler : public DNNLMatMulPrimitiveHandler {
   std::shared_ptr<MSizeCache> m_size_cache_;
 };
 
+class MatMulPrimitiveHandler : public DNNLMatMulPrimitiveHandler {
+ public:
+  struct Args : public DNNLMatMulPrimitiveHandler::Args {
+    dnnl::memory::data_type ab_type;
+  };
+
+  struct ClassMatmulCacheKey {
+    dnnl_dim_t b_n_size;
+    dnnl_dim_t b_k_size;
+
+    friend bool operator==(const ClassMatmulCacheKey& l,
+                           const ClassMatmulCacheKey& r);
+  };
+
+  struct MSizeCacheKey {
+    dnnl_dim_t a_m_size;
+    dnnl_dim_t a_m_stride;
+    bool use_bias;
+    dnnl::memory::data_type bias_type;
+
+    friend bool operator==(const MSizeCacheKey& l, const MSizeCacheKey& r);
+  };
+
+  using MSizeCache = DNNLPrimitiveCache<MSizeCacheKey, dnnl::matmul>;
+  using ClassMatmulCache =
+      DNNLPrimitiveCache<ClassMatmulCacheKey, std::shared_ptr<MSizeCache>>;
+
+  struct ExecArgs : public MSizeCacheKey {
+    const void* a_ptr;
+    const void* bias_ptr;
+    void* c_ptr;
+  };
+
+ public:
+  MatMulPrimitiveHandler(const Args& args);
+
+  void execute(ExecArgs& args);
+
+ private:
+  dnnl::matmul::primitive_desc create_primitive_desc(const MSizeCacheKey& key,
+                                                     bool first_time);
+
+  void init_runtime_memory_cache(const Args& args);
+
+  dnnl::matmul get_matmul_cache(const MSizeCacheKey& key);
+
+ private:
+  std::shared_ptr<MSizeCache> m_size_cache_;
+};
+
 #endif
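DNNLScratchPadManager keeps one shared scratch buffer that only grows: the constructor pre-reserves 128 units (512 MiB), and round pads every request up to whole 4 MiB units so nearby sizes do not trigger repeated reallocation. The arithmetic, checked in a short Python sketch:

    ALLOCATION_UNIT = 4 * 1024 * 1024  # 4 MiB, as in the header

    def round_up(size: int) -> int:
        # same formula as DNNLScratchPadManager::round
        return ((size + ALLOCATION_UNIT - 1) // ALLOCATION_UNIT) * ALLOCATION_UNIT

    assert round_up(1) == ALLOCATION_UNIT                     # tiny request -> 4 MiB
    assert round_up(5 * 1024 * 1024) == 2 * ALLOCATION_UNIT   # 5 MiB -> 8 MiB
    assert round_up(8 * 1024 * 1024) == 2 * ALLOCATION_UNIT   # exact multiple unchanged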
diff --git a/csrc/cpu/dnnl_kernels.cpp b/csrc/cpu/dnnl_kernels.cpp
index acc3b9ecde143..1aa99614926df 100644
--- a/csrc/cpu/dnnl_kernels.cpp
+++ b/csrc/cpu/dnnl_kernels.cpp
@@ -379,6 +379,7 @@ void onednn_scaled_mm(
     exec_args.a_ptr = a.data_ptr<int8_t>();
     exec_args.a_m_size = a.size(0);
     exec_args.bias_ptr = nullptr;
+    exec_args.bias_type = get_dnnl_type<void>();
     exec_args.use_bias = false;
     exec_args.a_scales_ptr = nullptr;
     exec_args.a_zero_points_ptr = nullptr;
@@ -492,3 +493,56 @@ void dynamic_scaled_int8_quant(
     }
   });
 }
+
+int64_t create_onednn_mm_handler(const torch::Tensor& b,
+                                 int64_t primitive_cache_size) {
+  TORCH_CHECK(b.dim() == 2);
+
+  MatMulPrimitiveHandler::Args args;
+  args.primitive_cache_size = primitive_cache_size;
+
+  args.b_k_size = b.size(0);
+  args.b_k_stride = b.stride(0);
+  args.b_n_size = b.size(1);
+  args.b_n_stride = b.stride(1);
+  args.b_ptr = b.data_ptr();
+
+  VLLM_DISPATCH_FLOATING_TYPES(b.scalar_type(), "create_onednn_mm_handler",
+                               [&] {
+                                 args.c_type = get_dnnl_type<scalar_t>();
+                                 args.ab_type = get_dnnl_type<scalar_t>();
+                               });
+
+  return reinterpret_cast<int64_t>(new MatMulPrimitiveHandler(args));
+}
+
+void onednn_mm(torch::Tensor& c,        // [M, OC], row-major
+               const torch::Tensor& a,  // [M, IC], row-major
+               const std::optional<torch::Tensor>& bias, int64_t handler) {
+  CPU_KERNEL_GUARD_IN(onednn_mm)
+  TORCH_CHECK(a.dim() == 2);
+  TORCH_CHECK(a.stride(-1) == 1);
+  TORCH_CHECK(c.is_contiguous());
+  MatMulPrimitiveHandler* ptr =
+      reinterpret_cast<MatMulPrimitiveHandler*>(handler);
+
+  MatMulPrimitiveHandler::ExecArgs exec_args;
+  exec_args.a_m_size = a.size(0);
+  exec_args.a_m_stride = a.stride(0);
+
+  VLLM_DISPATCH_FLOATING_TYPES(a.scalar_type(), "onednn_mm", [&] {
+    if (bias.has_value()) {
+      exec_args.use_bias = true;
+      exec_args.bias_type = get_dnnl_type<scalar_t>();
+      exec_args.bias_ptr = bias->data_ptr<scalar_t>();
+    } else {
+      exec_args.use_bias = false;
+      exec_args.bias_type = get_dnnl_type<void>();
+      exec_args.bias_ptr = nullptr;
+    }
+    exec_args.a_ptr = a.data_ptr<scalar_t>();
+    exec_args.c_ptr = c.data_ptr<scalar_t>();
+
+    ptr->execute(exec_args);
+  });
+}
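The TORCH_CHECKs above pin down the calling contract: the activation must be 2-D and row-major with unit inner stride, but its leading stride may exceed K, so sliced views work without a copy; only the output must be contiguous. A quick Python check with illustrative shapes:

    import torch

    m, k = 8, 64
    a = torch.rand((m, 2 * k), dtype=torch.bfloat16)[:, :k]  # strided view, no copy
    assert a.dim() == 2
    assert a.stride(-1) == 1     # inner dim must be unit-stride
    assert a.stride(0) == 2 * k  # leading stride may be larger than K
    c = torch.empty((m, 128), dtype=a.dtype)
    assert c.is_contiguous()     # output buffer must be contiguous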
diff --git a/csrc/cpu/torch_bindings.cpp b/csrc/cpu/torch_bindings.cpp
index c9f426bdf618a..98c3ebc5a75f8 100644
--- a/csrc/cpu/torch_bindings.cpp
+++ b/csrc/cpu/torch_bindings.cpp
@@ -21,6 +21,12 @@ void onednn_scaled_mm(torch::Tensor& c, const torch::Tensor& a,
                       const std::optional<torch::Tensor>& bias,
                       int64_t handler);
 
+int64_t create_onednn_mm_handler(const torch::Tensor& b,
+                                 int64_t primitive_cache_size);
+
+void onednn_mm(torch::Tensor& c, const torch::Tensor& a,
+               const std::optional<torch::Tensor>& bias, int64_t handler);
+
 void mla_decode_kvcache(torch::Tensor& out, torch::Tensor& query,
                         torch::Tensor& kv_cache, double scale,
                         torch::Tensor& block_tables, torch::Tensor& seq_lens);
@@ -153,6 +159,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.def("release_dnnl_matmul_handler(int handler) -> ()",
           &release_dnnl_matmul_handler);
 
+  // Create oneDNN GEMM handler
+  ops.def(
+      "create_onednn_mm_handler(Tensor b, int "
+      "primitive_cache_size) -> int",
+      &create_onednn_mm_handler);
+
+  // oneDNN GEMM
+  ops.def(
+      "onednn_mm(Tensor! c, Tensor a, Tensor? bias, "
+      "int handler) -> ()");
+  ops.impl("onednn_mm", torch::kCPU, &onednn_mm);
+
   // Create oneDNN W8A8 handler
   ops.def(
       "create_onednn_scaled_mm_handler(Tensor b, Tensor b_scales, ScalarType "
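In the schema, `Tensor! c` marks the output as mutated in place, and the handler travels across the boundary as a plain int. On a CPU build that registers these ops, direct use of the raw bindings looks roughly like this (shapes are illustrative; normal code goes through the vllm._custom_ops wrappers shown below):

    import torch

    # w must be [K, N]; the handler prepacks it into oneDNN's preferred layout
    w = torch.rand((64, 128), dtype=torch.float32)
    handler = torch.ops._C.create_onednn_mm_handler(w, 128)

    a = torch.rand((8, 64), dtype=torch.float32)
    c = torch.empty((8, 128), dtype=torch.float32)  # preallocated, mutated in place
    torch.ops._C.onednn_mm(c, a, None, handler)

    torch.ops._C.release_dnnl_matmul_handler(handler)  # frees the prepacked weight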
diff --git a/tests/kernels/test_onednn.py b/tests/kernels/test_onednn.py
index 17692384ac9a9..37772464a209b 100644
--- a/tests/kernels/test_onednn.py
+++ b/tests/kernels/test_onednn.py
@@ -111,6 +111,49 @@ def onednn_int8_gemm_test_helper(primitive_cache_size: int,
     torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)
 
 
+def onednn_gemm_test_helper(primitive_cache_size: int,
+                            m: int,
+                            n: int,
+                            k: int,
+                            use_bias: bool,
+                            use_stride: bool,
+                            dtype: torch.dtype = torch.bfloat16,
+                            device: str = "cpu"):
+    if use_stride:
+        a = torch.rand((m, 2 * k), dtype=dtype, device=device) * 1.5
+        a = a[:, :k]
+    else:
+        a = torch.rand((m, k), dtype=dtype, device=device) * 1.5
+
+    b = torch.rand((n, k), dtype=dtype, device=device) * 1.5
+
+    if use_bias:
+        bias = torch.rand((n, ), device=device, dtype=dtype) * 5
+        bias_f32 = bias.float()
+    else:
+        bias = None
+        bias_f32 = None
+
+    handler = ops.create_onednn_mm(
+        b.t(),
+        primitive_cache_size,
+    )
+
+    out = ops.onednn_mm(handler, a, bias)
+    baseline = torch.nn.functional.linear(a.float(), b.float(),
+                                          bias_f32).to(dtype=a.dtype)
+
+    torch.testing.assert_close(out, baseline)
+
+    if use_bias:
+        # To test runtime bias setting
+        out = ops.onednn_mm(handler, a, None)
+        baseline = torch.nn.functional.linear(a.float(), b.float(),
+                                              None).to(dtype=a.dtype)
+
+        torch.testing.assert_close(out, baseline)
+
+
 @pytest.mark.parametrize("n,k", NK_FACTORS)
 @pytest.mark.parametrize("m_list", M_FACTORS)
 @pytest.mark.parametrize("per_tensor_a_scale", [True, False])
@@ -142,3 +185,30 @@ def test_onednn_int8_scaled_gemm(
         use_azp=use_azp,
         out_dtype=output_type,
     )
+
+
+@pytest.mark.parametrize("n,k", NK_FACTORS)
+@pytest.mark.parametrize("m_list", M_FACTORS)
+@pytest.mark.parametrize("use_bias", [True, False])
+@pytest.mark.parametrize("use_stride", [True, False])
+@pytest.mark.parametrize("dtype", DTYPE)
+@pytest.mark.parametrize("primitive_cache_size", CACHE_SIZES)
+def test_onednn_gemm(
+    n: int,
+    k: int,
+    m_list: tuple[int],
+    use_bias: bool,
+    use_stride: bool,
+    dtype: torch.dtype,
+    primitive_cache_size: int,
+):
+    for m in m_list:
+        onednn_gemm_test_helper(
+            primitive_cache_size=primitive_cache_size,
+            m=m,
+            n=n,
+            k=k,
+            use_bias=use_bias,
+            use_stride=use_stride,
+            dtype=dtype,
+        )
diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index 340d6e1164e4f..bb67d5790aaaa 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -1928,6 +1928,35 @@ class CPUDNNLGEMMHandler:
             torch.ops._C.release_dnnl_matmul_handler(self.handler)
 
 
+if hasattr(torch.ops._C, "create_onednn_mm_handler"):
+    _supports_onednn = True
+else:
+    _supports_onednn = False
+
+
+def create_onednn_mm(
+        weight: torch.Tensor,  # [K, N]
+        primitive_cache_size: int = 128,
+) -> CPUDNNLGEMMHandler:
+    handler = CPUDNNLGEMMHandler()
+    handler.k, handler.n = weight.size()
+    handler.handler = torch.ops._C.create_onednn_mm_handler(
+        weight, primitive_cache_size)
+    return handler
+
+
+def onednn_mm(
+    dnnl_handler: CPUDNNLGEMMHandler,
+    x: torch.Tensor,
+    bias: Optional[torch.Tensor],
+) -> torch.Tensor:
+    output = torch.empty((*x.shape[0:-1], dnnl_handler.n), dtype=x.dtype)
+    torch.ops._C.onednn_mm(output, x.reshape(-1, dnnl_handler.k), bias,
+                           dnnl_handler.handler)
+
+    return output
+
+
 def create_onednn_scaled_mm(
         weight: torch.Tensor,  # [K, N]
         weight_scales: torch.Tensor,
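The wrappers keep the usual [N, K] nn.Linear convention at the call site: pass weight.t() when creating the handler (it wants [K, N]), and onednn_mm flattens any leading batch dims before the matmul and restores them on the output. Usage mirroring the test helper, on a CPU build of vLLM:

    import torch
    from vllm import _custom_ops as ops

    w = torch.rand((256, 128), dtype=torch.bfloat16)  # [N, K], as nn.Linear stores it
    handler = ops.create_onednn_mm(w.t(), 128)        # handler expects [K, N]

    x = torch.rand((4, 7, 128), dtype=torch.bfloat16)
    bias = torch.rand((256, ), dtype=torch.bfloat16)
    y = ops.onednn_mm(handler, x, bias)               # -> [4, 7, 256]
    assert y.shape == (4, 7, 256)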
diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py
index f24c87dbf4509..1224b94d56e06 100644
--- a/vllm/model_executor/layers/linear.py
+++ b/vllm/model_executor/layers/linear.py
@@ -9,7 +9,6 @@ import torch
 import torch.nn as nn
 from torch.nn.parameter import Parameter, UninitializedParameter
 
-from vllm import envs
 from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
                               split_tensor_along_last_dim,
@@ -200,26 +199,10 @@ class UnquantizedLinearMethod(LinearMethodBase):
         set_weight_attrs(weight, extra_weight_attrs)
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
-        # special postprocessing for CPU SGL
-        if current_platform.is_cpu() and envs.VLLM_CPU_SGL_KERNEL:
-            from vllm.model_executor.layers.utils import check_cpu_sgl_kernel
-            N, K = layer.weight.size()
-            dtype = layer.weight.dtype
-            if check_cpu_sgl_kernel(N, K, dtype):
-                packed_weight = torch.ops._C.convert_weight_packed(
-                    layer.weight)
-                assert packed_weight.size() == layer.weight.size()
-                layer.weight.copy_(packed_weight)
-                if layer.bias is not None:
-                    layer.bias = Parameter(layer.bias.to(torch.float32),
-                                           requires_grad=False)
-                layer.use_cpu_sgl = True
-            else:
-                logger.warning(
-                    "CPU SGL kernels require Intel AMX support,"
-                    " bf16/fp16/int8 weight, IC and OC are divisible by "
-                    "32 and 16.")
-                layer.use_cpu_sgl = False
+        if current_platform.is_cpu():
+            from vllm.model_executor.layers.utils import (
+                dispatch_cpu_unquantized_gemm)
+            dispatch_cpu_unquantized_gemm(layer, remove_weight=True)
 
     def apply(self,
               layer: torch.nn.Module,
diff --git a/vllm/model_executor/layers/utils.py b/vllm/model_executor/layers/utils.py
index 2897f75b3129e..d2b135c1e4d4e 100644
--- a/vllm/model_executor/layers/utils.py
+++ b/vllm/model_executor/layers/utils.py
@@ -142,20 +142,49 @@ direct_register_custom_op(
 )
 
 
-def check_cpu_sgl_kernel(n: int, k: int, dtype: torch.dtype):
+def check_cpu_sgl_kernel(n: int, k: int, dtype: torch.dtype) -> bool:
     return (torch._C._cpu._is_amx_tile_supported()
             and (dtype in (torch.bfloat16, torch.int8)) and k % 32 == 0
             and n % 16 == 0)
 
 
+def dispatch_cpu_unquantized_gemm(
+    layer: torch.nn.Module,
+    remove_weight: bool,
+) -> None:
+    N, K = layer.weight.size()
+    dtype = layer.weight.dtype
+    if envs.VLLM_CPU_SGL_KERNEL and check_cpu_sgl_kernel(N, K, dtype):
+        packed_weight = torch.ops._C.convert_weight_packed(layer.weight)
+        if getattr(layer, "bias", None) is not None:
+            bias_f32 = layer.bias.to(torch.float32)
+        else:
+            bias_f32 = None
+        layer.cpu_linear = (
+            lambda x, weight, bias: torch.ops._C.weight_packed_linear(
+                x, packed_weight, bias_f32
+                if bias is not None else None, True))
+        if remove_weight:
+            layer.weight = torch.nn.Parameter(torch.empty(0),
+                                              requires_grad=False)
+    elif ops._supports_onednn:
+        origin_weight = layer.weight
+        if remove_weight:
+            layer.weight = torch.nn.Parameter(torch.empty(0),
+                                              requires_grad=False)
+        handler = ops.create_onednn_mm(origin_weight.t(), 32)
+        layer.cpu_linear = lambda x, weight, bias: ops.onednn_mm(
+            handler, x, bias)
+    else:
+        layer.cpu_linear = lambda x, weight, bias: torch.nn.functional.linear(
+            x, weight, bias)
+
+
 def cpu_unquantized_gemm(layer: torch.nn.Module,
                          x: torch.Tensor,
                          weight: torch.Tensor,
                          bias: Optional[torch.Tensor] = None):
-    if getattr(layer, "use_cpu_sgl", False):
-        return torch.ops._C.weight_packed_linear(x, weight, bias, True)
-    else:
-        return torch.nn.functional.linear(x, weight, bias)
+    return layer.cpu_linear(x, weight, bias)
 
 
 def dispatch_unquantized_gemm() -> Callable[..., torch.Tensor]:
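dispatch_cpu_unquantized_gemm picks a backend once per layer at weight-loading time instead of branching on every forward pass: the AMX-packed SGL kernel when the env flag and shape checks pass, the oneDNN handler when those bindings are compiled in, and plain torch linear as the last resort. A compact mirror of that precedence, for illustration only (the function name and boolean arguments are assumptions):

    import torch

    def choose_cpu_linear_backend(n: int, k: int, dtype: torch.dtype,
                                  sgl_flag: bool, amx: bool,
                                  onednn: bool) -> str:
        # mirrors dispatch_cpu_unquantized_gemm / check_cpu_sgl_kernel
        if (sgl_flag and amx and dtype in (torch.bfloat16, torch.int8)
                and k % 32 == 0 and n % 16 == 0):
            return "sgl"
        if onednn:
            return "onednn"
        return "torch"

    assert choose_cpu_linear_backend(4096, 4096, torch.bfloat16,
                                     True, True, True) == "sgl"
    assert choose_cpu_linear_backend(4096, 4100, torch.bfloat16,
                                     True, True, True) == "onednn"  # K % 32 != 0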
diff --git a/vllm/model_executor/layers/vocab_parallel_embedding.py b/vllm/model_executor/layers/vocab_parallel_embedding.py
index 9f223998e554f..c92a7978195bc 100644
--- a/vllm/model_executor/layers/vocab_parallel_embedding.py
+++ b/vllm/model_executor/layers/vocab_parallel_embedding.py
@@ -40,6 +40,12 @@ class UnquantizedEmbeddingMethod(QuantizeMethodBase):
         layer.register_parameter("weight", weight)
         set_weight_attrs(weight, extra_weight_attrs)
 
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        if current_platform.is_cpu():
+            from vllm.model_executor.layers.utils import (
+                dispatch_cpu_unquantized_gemm)
+            dispatch_cpu_unquantized_gemm(layer, remove_weight=False)
+
     def apply(self,
               layer: torch.nn.Module,
               x: torch.Tensor,
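The embedding path passes remove_weight=False, presumably because, unlike Linear, the embedding layer still needs the original weight for the token lookup; only the matmul path (for example, a weight-tied LM head) routes through cpu_linear. Roughly, with hypothetical tensors for illustration:

    import torch

    weight = torch.rand((32000, 128))   # embedding table stays resident
    token_ids = torch.tensor([1, 5, 42])
    emb = weight[token_ids]             # lookup still reads the weight directly

    hidden = torch.rand((3, 128))
    logits = hidden @ weight.t()        # what cpu_linear computes for the head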