From e0c6f556e85053059c74ab6b5cee396baf3b4316 Mon Sep 17 00:00:00 2001
From: Yanming W
Date: Thu, 23 Nov 2023 16:31:19 -0800
Subject: [PATCH] [Build] Avoid building too many extensions (#1624)

---
 .../kernels/benchmark_paged_attention.py      |  6 +-
 csrc/activation.cpp                           | 28 -------
 csrc/attention.cpp                            | 42 ----------
 csrc/{cache.cpp => cache.h}                   | 19 -----
 csrc/cuda_utils.cpp                           | 13 ---
 csrc/cuda_utils.h                             |  5 ++
 csrc/layernorm.cpp                            | 24 ------
 csrc/ops.h                                    | 75 +++++++++++++++++
 csrc/pos_encoding.cpp                         | 16 ----
 csrc/pybind.cpp                               | 80 ++++++++++++++++++
 csrc/quantization.cpp                         | 19 -----
 setup.py                                      | 82 +++----------------
 tests/kernels/test_activation.py              |  8 +-
 tests/kernels/test_attention.py               |  6 +-
 tests/kernels/test_cache.py                   |  2 +-
 tests/kernels/test_layernorm.py               |  4 +-
 tests/kernels/test_pos_encoding.py            |  4 +-
 vllm/model_executor/layers/activation.py      |  8 +-
 vllm/model_executor/layers/attention.py       |  8 +-
 vllm/model_executor/layers/layernorm.py       |  6 +-
 .../model_executor/layers/quantization/awq.py |  5 +-
 .../layers/quantization/squeezellm.py         |  5 +-
 .../model_executor/layers/rotary_embedding.py |  9 +-
 vllm/utils.py                                 |  2 +-
 vllm/worker/cache_engine.py                   |  2 +-
 25 files changed, 206 insertions(+), 272 deletions(-)
 delete mode 100644 csrc/activation.cpp
 delete mode 100644 csrc/attention.cpp
 rename csrc/{cache.cpp => cache.h} (58%)
 delete mode 100644 csrc/cuda_utils.cpp
 create mode 100644 csrc/cuda_utils.h
 delete mode 100644 csrc/layernorm.cpp
 create mode 100644 csrc/ops.h
 delete mode 100644 csrc/pos_encoding.cpp
 create mode 100644 csrc/pybind.cpp
 delete mode 100644 csrc/quantization.cpp

diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py
index 0ef8030767677..91fcf5340298a 100644
--- a/benchmarks/kernels/benchmark_paged_attention.py
+++ b/benchmarks/kernels/benchmark_paged_attention.py
@@ -4,7 +4,7 @@ import time
 
 import torch
 
-from vllm import attention_ops
+from vllm._C import ops
 
 NUM_BLOCKS = 1024
 PARTITION_SIZE = 512
@@ -98,7 +98,7 @@ def main(
 
     for _ in range(num_iters):
         if version == "v1":
-            attention_ops.paged_attention_v1(
+            ops.paged_attention_v1(
                 output,
                 query,
                 key_cache,
@@ -112,7 +112,7 @@ def main(
                 alibi_slopes,
             )
         elif version == "v2":
-            attention_ops.paged_attention_v2(
+            ops.paged_attention_v2(
                 output,
                 exp_sums,
                 max_logits,
diff --git a/csrc/activation.cpp b/csrc/activation.cpp
deleted file mode 100644
index c100f89ac7377..0000000000000
--- a/csrc/activation.cpp
+++ /dev/null
@@ -1,28 +0,0 @@
-#include <torch/extension.h>
-
-void silu_and_mul(
-  torch::Tensor& out,
-  torch::Tensor& input);
-
-void gelu_new(
-  torch::Tensor& out,
-  torch::Tensor& input);
-
-void gelu_fast(
-  torch::Tensor& out,
-  torch::Tensor& input);
-
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def(
-    "silu_and_mul",
-    &silu_and_mul,
-    "Activation function used in SwiGLU.");
-  m.def(
-    "gelu_new",
-    &gelu_new,
-    "GELU implementation used in GPT-2.");
-  m.def(
-    "gelu_fast",
-    &gelu_fast,
-    "Approximate GELU implementation.");
-}
diff --git a/csrc/attention.cpp b/csrc/attention.cpp
deleted file mode 100644
index bd93fd71b733d..0000000000000
--- a/csrc/attention.cpp
+++ /dev/null
@@ -1,42 +0,0 @@
-#include <torch/extension.h>
-#include <c10/util/Optional.h>
-
-void paged_attention_v1(
-  torch::Tensor& out,
-  torch::Tensor& query,
-  torch::Tensor& key_cache,
-  torch::Tensor& value_cache,
-  torch::Tensor& head_mapping,
-  float scale,
-  torch::Tensor& block_tables,
-  torch::Tensor& context_lens,
-  int block_size,
-  int max_context_len,
-  const c10::optional<torch::Tensor>& alibi_slopes);
-
-void paged_attention_v2(
-  torch::Tensor& out,
-  torch::Tensor& exp_sums,
-  torch::Tensor& max_logits,
-  torch::Tensor& tmp_out,
-  torch::Tensor& query,
-  torch::Tensor& key_cache,
-  torch::Tensor& value_cache,
-  torch::Tensor& head_mapping,
-  float scale,
-  torch::Tensor& block_tables,
-  torch::Tensor& context_lens,
-  int block_size,
-  int max_context_len,
-  const c10::optional<torch::Tensor>& alibi_slopes);
-
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def(
-    "paged_attention_v1",
-    &paged_attention_v1,
-    "Compute the attention between an input query and the cached keys/values using PagedAttention.");
-  m.def(
-    "paged_attention_v2",
-    &paged_attention_v2,
-    "PagedAttention V2.");
-}
diff --git a/csrc/cache.cpp b/csrc/cache.h
similarity index 58%
rename from csrc/cache.cpp
rename to csrc/cache.h
index 9ae17bb2985c6..da49d9103214b 100644
--- a/csrc/cache.cpp
+++ b/csrc/cache.h
@@ -26,22 +26,3 @@ void gather_cached_kv(
   torch::Tensor& key_cache,
   torch::Tensor& value_cache,
   torch::Tensor& slot_mapping);
-
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def(
-    "swap_blocks",
-    &swap_blocks,
-    "Swap in (out) the cache blocks from src to dst");
-  m.def(
-    "copy_blocks",
-    &copy_blocks,
-    "Copy the cache blocks from src to dst");
-  m.def(
-    "reshape_and_cache",
-    &reshape_and_cache,
-    "Reshape the key and value tensors and cache them");
-  m.def(
-    "gather_cached_kv",
-    &gather_cached_kv,
-    "Gather key and value from the cache into contiguous QKV tensors");
-}
diff --git a/csrc/cuda_utils.cpp b/csrc/cuda_utils.cpp
deleted file mode 100644
index e7f22ec89d7b4..0000000000000
--- a/csrc/cuda_utils.cpp
+++ /dev/null
@@ -1,13 +0,0 @@
-#include <torch/extension.h>
-
-int get_device_attribute(
-    int attribute,
-    int device_id);
-
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def(
-    "get_device_attribute",
-    &get_device_attribute,
-    "Gets the specified device attribute.");
-}
-
diff --git a/csrc/cuda_utils.h b/csrc/cuda_utils.h
new file mode 100644
index 0000000000000..85cb199b9aa0c
--- /dev/null
+++ b/csrc/cuda_utils.h
@@ -0,0 +1,5 @@
+#include <torch/extension.h>
+
+int get_device_attribute(
+    int attribute,
+    int device_id);
diff --git a/csrc/layernorm.cpp b/csrc/layernorm.cpp
deleted file mode 100644
index c341a7097962c..0000000000000
--- a/csrc/layernorm.cpp
+++ /dev/null
@@ -1,24 +0,0 @@
-#include <torch/extension.h>
-
-void rms_norm(
-  torch::Tensor& out,
-  torch::Tensor& input,
-  torch::Tensor& weight,
-  float epsilon);
-
-void fused_add_rms_norm(
-  torch::Tensor& input,
-  torch::Tensor& residual,
-  torch::Tensor& weight,
-  float epsilon);
-
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def(
-    "rms_norm",
-    &rms_norm,
-    "Apply Root Mean Square (RMS) Normalization to the input tensor.");
-  m.def(
-    "fused_add_rms_norm",
-    &fused_add_rms_norm,
-    "In-place fused Add and RMS Normalization");
-}
diff --git a/csrc/ops.h b/csrc/ops.h
new file mode 100644
index 0000000000000..cfb18fbefd7a9
--- /dev/null
+++ b/csrc/ops.h
@@ -0,0 +1,75 @@
+#include <torch/extension.h>
+
+void paged_attention_v1(
+  torch::Tensor& out,
+  torch::Tensor& query,
+  torch::Tensor& key_cache,
+  torch::Tensor& value_cache,
+  torch::Tensor& head_mapping,
+  float scale,
+  torch::Tensor& block_tables,
+  torch::Tensor& context_lens,
+  int block_size,
+  int max_context_len,
+  const c10::optional<torch::Tensor>& alibi_slopes);
+
+void paged_attention_v2(
+  torch::Tensor& out,
+  torch::Tensor& exp_sums,
+  torch::Tensor& max_logits,
+  torch::Tensor& tmp_out,
+  torch::Tensor& query,
+  torch::Tensor& key_cache,
+  torch::Tensor& value_cache,
+  torch::Tensor& head_mapping,
+  float scale,
+  torch::Tensor& block_tables,
+  torch::Tensor& context_lens,
+  int block_size,
+  int max_context_len,
+  const c10::optional<torch::Tensor>& alibi_slopes);
+
+void rms_norm(
+  torch::Tensor& out,
+  torch::Tensor& input,
+  torch::Tensor& weight,
+  float epsilon);
+
+void fused_add_rms_norm(
+  torch::Tensor& input,
+  torch::Tensor& residual,
+  torch::Tensor& weight,
+  float epsilon);
+
+void rotary_embedding(
+  torch::Tensor& positions,
+  torch::Tensor& query,
+  torch::Tensor& key,
+  int head_size,
+  torch::Tensor& cos_sin_cache,
+  bool is_neox);
+
+void silu_and_mul(
+  torch::Tensor& out,
+  torch::Tensor& input);
+
+void gelu_new(
+  torch::Tensor& out,
+  torch::Tensor& input);
+
+void gelu_fast(
+  torch::Tensor& out,
+  torch::Tensor& input);
+
+torch::Tensor awq_gemm(
+  torch::Tensor _in_feats,
+  torch::Tensor _kernel,
+  torch::Tensor _scaling_factors,
+  torch::Tensor _zeros,
+  int split_k_iters);
+
+void squeezellm_gemm(
+  torch::Tensor vec,
+  torch::Tensor mat,
+  torch::Tensor mul,
+  torch::Tensor lookup_table);
diff --git a/csrc/pos_encoding.cpp b/csrc/pos_encoding.cpp
deleted file mode 100644
index eee0cf0d0fa09..0000000000000
--- a/csrc/pos_encoding.cpp
+++ /dev/null
@@ -1,16 +0,0 @@
-#include <torch/extension.h>
-
-void rotary_embedding(
-  torch::Tensor& positions,
-  torch::Tensor& query,
-  torch::Tensor& key,
-  int head_size,
-  torch::Tensor& cos_sin_cache,
-  bool is_neox);
-
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def(
-    "rotary_embedding",
-    &rotary_embedding,
-    "Apply GPT-NeoX or GPT-J style rotary embedding to query and key");
-}
diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp
new file mode 100644
index 0000000000000..9e31429690021
--- /dev/null
+++ b/csrc/pybind.cpp
@@ -0,0 +1,80 @@
+#include "cache.h"
+#include "cuda_utils.h"
+#include "ops.h"
+#include <torch/extension.h>
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  // vLLM custom ops
+  pybind11::module ops = m.def_submodule("ops", "vLLM custom operators");
+
+  // Attention ops
+  ops.def(
+    "paged_attention_v1",
+    &paged_attention_v1,
+    "Compute the attention between an input query and the cached keys/values using PagedAttention.");
+  ops.def(
+    "paged_attention_v2",
+    &paged_attention_v2,
+    "PagedAttention V2.");
+
+  // Activation ops
+  ops.def(
+    "silu_and_mul",
+    &silu_and_mul,
+    "Activation function used in SwiGLU.");
+  ops.def(
+    "gelu_new",
+    &gelu_new,
+    "GELU implementation used in GPT-2.");
+  ops.def(
+    "gelu_fast",
+    &gelu_fast,
+    "Approximate GELU implementation.");
+
+  // Layernorm
+  ops.def(
+    "rms_norm",
+    &rms_norm,
+    "Apply Root Mean Square (RMS) Normalization to the input tensor.");
+
+  ops.def(
+    "fused_add_rms_norm",
+    &fused_add_rms_norm,
+    "In-place fused Add and RMS Normalization");
+
+  // Rotary embedding
+  ops.def(
+    "rotary_embedding",
+    &rotary_embedding,
+    "Apply GPT-NeoX or GPT-J style rotary embedding to query and key");
+
+  // Quantization ops
+  ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
+  ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");
+
+  // Cache ops
+  pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops");
+  cache_ops.def(
+    "swap_blocks",
+    &swap_blocks,
+    "Swap in (out) the cache blocks from src to dst");
+  cache_ops.def(
+    "copy_blocks",
+    &copy_blocks,
+    "Copy the cache blocks from src to dst");
+  cache_ops.def(
+    "reshape_and_cache",
+    &reshape_and_cache,
+    "Reshape the key and value tensors and cache them");
+  cache_ops.def(
+    "gather_cached_kv",
+    &gather_cached_kv,
+    "Gather key and value from the cache into contiguous QKV tensors");
+
+  // Cuda utils
+  pybind11::module cuda_utils = m.def_submodule("cuda_utils", "vLLM cuda utils");
+  cuda_utils.def(
+    "get_device_attribute",
+    &get_device_attribute,
+    "Gets the specified device attribute.");
+}
diff --git a/csrc/quantization.cpp b/csrc/quantization.cpp
deleted file mode 100644
index dfe17a496c780..0000000000000
--- a/csrc/quantization.cpp
+++ /dev/null
@@ -1,19 +0,0 @@
-#include <torch/extension.h>
-
-torch::Tensor awq_gemm(
-  torch::Tensor _in_feats,
-  torch::Tensor _kernel,
-  torch::Tensor _scaling_factors,
-  torch::Tensor _zeros,
-  int split_k_iters);
-
-void squeezellm_gemm(
-  torch::Tensor vec,
-  torch::Tensor mat,
-  torch::Tensor mul,
-  torch::Tensor lookup_table);
-
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
-  m.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");
-}
diff --git a/setup.py b/setup.py
index 36f4913435628..2b040e88f0aa4 100644
--- a/setup.py
+++ b/setup.py
@@ -142,87 +142,25 @@ if nvcc_cuda_version >= Version("11.2"):
     NVCC_FLAGS += ["--threads", str(num_threads)]
 
 ext_modules = []
-
-# Cache operations.
-cache_extension = CUDAExtension(
-    name="vllm.cache_ops",
-    sources=["csrc/cache.cpp", "csrc/cache_kernels.cu"],
-    extra_compile_args={
-        "cxx": CXX_FLAGS,
-        "nvcc": NVCC_FLAGS,
-    },
-)
-ext_modules.append(cache_extension)
-
-# Attention kernels.
-attention_extension = CUDAExtension(
-    name="vllm.attention_ops",
-    sources=["csrc/attention.cpp", "csrc/attention/attention_kernels.cu"],
-    extra_compile_args={
-        "cxx": CXX_FLAGS,
-        "nvcc": NVCC_FLAGS,
-    },
-)
-ext_modules.append(attention_extension)
-
-# Positional encoding kernels.
-positional_encoding_extension = CUDAExtension(
-    name="vllm.pos_encoding_ops",
-    sources=["csrc/pos_encoding.cpp", "csrc/pos_encoding_kernels.cu"],
-    extra_compile_args={
-        "cxx": CXX_FLAGS,
-        "nvcc": NVCC_FLAGS,
-    },
-)
-ext_modules.append(positional_encoding_extension)
-
-# Layer normalization kernels.
-layernorm_extension = CUDAExtension(
-    name="vllm.layernorm_ops",
-    sources=["csrc/layernorm.cpp", "csrc/layernorm_kernels.cu"],
-    extra_compile_args={
-        "cxx": CXX_FLAGS,
-        "nvcc": NVCC_FLAGS,
-    },
-)
-ext_modules.append(layernorm_extension)
-
-# Activation kernels.
-activation_extension = CUDAExtension(
-    name="vllm.activation_ops",
-    sources=["csrc/activation.cpp", "csrc/activation_kernels.cu"],
-    extra_compile_args={
-        "cxx": CXX_FLAGS,
-        "nvcc": NVCC_FLAGS,
-    },
-)
-ext_modules.append(activation_extension)
-
-# Quantization kernels.
-quantization_extension = CUDAExtension(
-    name="vllm.quantization_ops",
+vllm_extension = CUDAExtension(
+    name="vllm._C",
     sources=[
-        "csrc/quantization.cpp",
+        "csrc/cache_kernels.cu",
+        "csrc/attention/attention_kernels.cu",
+        "csrc/pos_encoding_kernels.cu",
+        "csrc/activation_kernels.cu",
+        "csrc/layernorm_kernels.cu",
         "csrc/quantization/awq/gemm_kernels.cu",
         "csrc/quantization/squeezellm/quant_cuda_kernel.cu",
+        "csrc/cuda_utils_kernels.cu",
+        "csrc/pybind.cpp",
     ],
     extra_compile_args={
         "cxx": CXX_FLAGS,
         "nvcc": NVCC_FLAGS,
     },
 )
-ext_modules.append(quantization_extension)
-
-# Misc. CUDA utils.
-cuda_utils_extension = CUDAExtension( - name="vllm.cuda_utils", - sources=["csrc/cuda_utils.cpp", "csrc/cuda_utils_kernels.cu"], - extra_compile_args={ - "cxx": CXX_FLAGS, - "nvcc": NVCC_FLAGS, - }, -) -ext_modules.append(cuda_utils_extension) +ext_modules.append(vllm_extension) def get_path(*filepath) -> str: diff --git a/tests/kernels/test_activation.py b/tests/kernels/test_activation.py index 0b3ad0aa255a1..978b377ea94d4 100644 --- a/tests/kernels/test_activation.py +++ b/tests/kernels/test_activation.py @@ -3,7 +3,7 @@ import torch import torch.nn.functional as F from transformers.activations import get_activation -from vllm import activation_ops +from vllm._C import ops DTYPES = [torch.half, torch.bfloat16, torch.float] NUM_TOKENS = [7, 83, 2048] # Arbitrary values for testing @@ -31,7 +31,7 @@ def test_silu_and_mul( torch.cuda.manual_seed(seed) x = torch.randn(num_tokens, 2 * d, dtype=dtype, device="cuda") out = torch.empty(num_tokens, d, dtype=dtype, device="cuda") - activation_ops.silu_and_mul(out, x) + ops.silu_and_mul(out, x) ref_out = ref_silu_and_mul(x) assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5) @@ -51,7 +51,7 @@ def test_gelu_new( torch.cuda.manual_seed(seed) x = torch.randn(num_tokens, d, dtype=dtype, device="cuda") out = torch.empty(num_tokens, d, dtype=dtype, device="cuda") - activation_ops.gelu_new(out, x) + ops.gelu_new(out, x) ref_out = get_activation("gelu_new")(x) assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5) @@ -70,6 +70,6 @@ def test_gelu_fast( torch.cuda.manual_seed(seed) x = torch.randn(num_tokens, d, dtype=dtype, device="cuda") out = torch.empty(num_tokens, d, dtype=dtype, device="cuda") - activation_ops.gelu_fast(out, x) + ops.gelu_fast(out, x) ref_out = get_activation("gelu_fast")(x) assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5) diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index e76416d88311d..a65d4d54d7c82 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -6,7 +6,7 @@ import torch from xformers import ops as xops from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask -from vllm import attention_ops +from vllm._C import ops from vllm.utils import get_max_shared_memory_bytes FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 @@ -165,7 +165,7 @@ def test_paged_attention( # Call the paged attention kernel. 
output = torch.empty_like(query) if version == "v1": - attention_ops.paged_attention_v1( + ops.paged_attention_v1( output, query, key_cache, @@ -194,7 +194,7 @@ def test_paged_attention( device=output.device, ) max_logits = torch.empty_like(exp_sums) - attention_ops.paged_attention_v2( + ops.paged_attention_v2( output, exp_sums, max_logits, diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py index e15e7ba91bcb0..9b5d7687a3fec 100644 --- a/tests/kernels/test_cache.py +++ b/tests/kernels/test_cache.py @@ -3,7 +3,7 @@ import random import pytest import torch -from vllm import cache_ops +from vllm._C import cache_ops DTYPES = [torch.half, torch.bfloat16, torch.float] NUM_TOKENS = [83] # Arbitrary values for testing diff --git a/tests/kernels/test_layernorm.py b/tests/kernels/test_layernorm.py index a63ef5cc76ffd..ee5228d68e4db 100644 --- a/tests/kernels/test_layernorm.py +++ b/tests/kernels/test_layernorm.py @@ -2,7 +2,7 @@ import pytest import torch import torch.nn as nn -from vllm import layernorm_ops +from vllm._C import ops DTYPES = [torch.half, torch.bfloat16, torch.float] HIDDEN_SIZES = [67, 768, 2048, 5120, 8192] # Arbitrary values for testing @@ -48,7 +48,7 @@ def test_rms_norm( ref = RefRMSNorm(hidden_size).to(dtype).cuda() out = torch.empty_like(x) - layernorm_ops.rms_norm( + ops.rms_norm( out, x, ref.weight.data, diff --git a/tests/kernels/test_pos_encoding.py b/tests/kernels/test_pos_encoding.py index d660417440844..7d22bdab4625b 100644 --- a/tests/kernels/test_pos_encoding.py +++ b/tests/kernels/test_pos_encoding.py @@ -5,7 +5,7 @@ import torch import torch.nn as nn import torch.nn.functional as F -from vllm import pos_encoding_ops +from vllm._C import ops IS_NEOX_STYLE = [True, False] DTYPES = [torch.half, torch.bfloat16, torch.float] @@ -145,7 +145,7 @@ def test_rotary_embedding( # Run the kernel. The kernel is in-place, so we need to clone the inputs. 
out_query = query.clone() out_key = key.clone() - pos_encoding_ops.rotary_embedding( + ops.rotary_embedding( positions, out_query, out_key, diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py index ecab0c8d3256a..5c0def823edea 100644 --- a/vllm/model_executor/layers/activation.py +++ b/vllm/model_executor/layers/activation.py @@ -4,7 +4,7 @@ from typing import Optional import torch import torch.nn as nn -from vllm import activation_ops +from vllm._C import ops from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) @@ -26,7 +26,7 @@ class SiluAndMul(nn.Module): d = x.shape[-1] // 2 output_shape = (x.shape[:-1] + (d, )) out = torch.empty(output_shape, dtype=x.dtype, device=x.device) - activation_ops.silu_and_mul(out, x) + ops.silu_and_mul(out, x) return out @@ -34,7 +34,7 @@ class NewGELU(nn.Module): def forward(self, x: torch.Tensor) -> torch.Tensor: out = torch.empty_like(x) - activation_ops.gelu_new(out, x) + ops.gelu_new(out, x) return out @@ -42,7 +42,7 @@ class FastGELU(nn.Module): def forward(self, x: torch.Tensor) -> torch.Tensor: out = torch.empty_like(x) - activation_ops.gelu_fast(out, x) + ops.gelu_fast(out, x) return out diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index e51bb311decd9..63271ba5b9327 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -7,8 +7,8 @@ from xformers import ops as xops from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask, LowerTriangularMaskWithTensorBias) -from vllm import attention_ops -from vllm import cache_ops +from vllm._C import ops +from vllm._C import cache_ops from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.rotary_embedding import get_rope @@ -163,7 +163,7 @@ class PagedAttention(nn.Module): max_num_partitions == 1 or num_seqs * num_heads > 512) if use_v1: # Run PagedAttention V1. 
- attention_ops.paged_attention_v1( + ops.paged_attention_v1( output, query, key_cache, @@ -190,7 +190,7 @@ class PagedAttention(nn.Module): device=output.device, ) max_logits = torch.empty_like(exp_sums) - attention_ops.paged_attention_v2( + ops.paged_attention_v2( output, exp_sums, max_logits, diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py index 275efa0b7dc3f..69fba087099ef 100644 --- a/vllm/model_executor/layers/layernorm.py +++ b/vllm/model_executor/layers/layernorm.py @@ -4,7 +4,7 @@ from typing import Optional, Tuple, Union import torch import torch.nn as nn -from vllm import layernorm_ops +from vllm._C import ops class RMSNorm(nn.Module): @@ -29,7 +29,7 @@ class RMSNorm(nn.Module): residual: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: if residual is not None: - layernorm_ops.fused_add_rms_norm( + ops.fused_add_rms_norm( x, residual, self.weight.data, @@ -37,7 +37,7 @@ class RMSNorm(nn.Module): ) return x, residual out = torch.empty_like(x) - layernorm_ops.rms_norm( + ops.rms_norm( out, x, self.weight.data, diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py index 0ab5819d930aa..95d419e64f049 100644 --- a/vllm/model_executor/layers/quantization/awq.py +++ b/vllm/model_executor/layers/quantization/awq.py @@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional import torch from torch.nn.parameter import Parameter -from vllm import quantization_ops +from vllm._C import ops from vllm.model_executor.layers.linear import (LinearMethodBase, set_weight_attrs) from vllm.model_executor.layers.quantization.base_config import QuantizationConfig @@ -151,8 +151,7 @@ class AWQLinearMethod(LinearMethodBase): pack_factor = self.quant_config.pack_factor out_shape = (x.shape[:-1] + (qweight.shape[-1] * pack_factor, )) reshaped_x = x.reshape(-1, x.shape[-1]) - out = quantization_ops.awq_gemm(reshaped_x, qweight, scales, qzeros, - pack_factor) + out = ops.awq_gemm(reshaped_x, qweight, scales, qzeros, pack_factor) if bias is not None: out = out + bias return out.reshape(out_shape) diff --git a/vllm/model_executor/layers/quantization/squeezellm.py b/vllm/model_executor/layers/quantization/squeezellm.py index 61ec8b79b6ddc..aa6bd0652424f 100644 --- a/vllm/model_executor/layers/quantization/squeezellm.py +++ b/vllm/model_executor/layers/quantization/squeezellm.py @@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional import torch from torch.nn.parameter import Parameter -from vllm import quantization_ops +from vllm._C import ops from vllm.model_executor.layers.linear import (LinearMethodBase, set_weight_attrs) from vllm.model_executor.layers.quantization.base_config import QuantizationConfig @@ -116,8 +116,7 @@ class SqueezeLLMLinearMethod(LinearMethodBase): reshaped_x = x.reshape(-1, x.shape[-1]) # NOTE: The output tensor should be zero-initialized. 
out = torch.zeros(out_shape, device="cuda", dtype=torch.float16) - quantization_ops.squeezellm_gemm(reshaped_x, qweight, out, - lookup_table) + ops.squeezellm_gemm(reshaped_x, qweight, out, lookup_table) if bias is not None: out = out + bias diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py index 1b88e9a3b8057..162bb0b533e4f 100644 --- a/vllm/model_executor/layers/rotary_embedding.py +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -27,7 +27,7 @@ from typing import Any, Dict, Optional, Tuple, Union import torch import torch.nn as nn -from vllm import pos_encoding_ops +from vllm._C import ops class RotaryEmbedding(nn.Module): @@ -87,11 +87,10 @@ class RotaryEmbedding(nn.Module): query: torch.Tensor, key: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: - # pos_encoding_ops.rotary_embedding() is an in-place operation that + # ops.rotary_embedding() is an in-place operation that # updates the query and key tensors. - pos_encoding_ops.rotary_embedding(positions, query, key, - self.head_size, self.cos_sin_cache, - self.is_neox_style) + ops.rotary_embedding(positions, query, key, self.head_size, + self.cos_sin_cache, self.is_neox_style) return query, key diff --git a/vllm/utils.py b/vllm/utils.py index 34d3084856af8..47e51048fed45 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -5,7 +5,7 @@ from platform import uname import psutil import torch -from vllm import cuda_utils +from vllm._C import cuda_utils class Device(enum.Enum): diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py index cdb7902082653..1dd0243f8f3a3 100644 --- a/vllm/worker/cache_engine.py +++ b/vllm/worker/cache_engine.py @@ -3,7 +3,7 @@ from typing import Dict, List, Tuple import torch -from vllm import cache_ops +from vllm._C import cache_ops from vllm.config import CacheConfig, ModelConfig, ParallelConfig from vllm.logger import init_logger from vllm.utils import in_wsl
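
A minimal usage sketch of the consolidated extension (not part of the patch), assuming a CUDA build of this branch; the tensor shapes are arbitrary, and the calls simply mirror the updated tests and the vllm/utils.py lookup shown above:

    import torch

    from vllm._C import ops, cuda_utils  # one extension, several submodules

    num_tokens, d = 16, 128
    x = torch.randn(num_tokens, 2 * d, dtype=torch.float16, device="cuda")
    out = torch.empty(num_tokens, d, dtype=torch.float16, device="cuda")
    ops.silu_and_mul(out, x)  # same call pattern as tests/kernels/test_activation.py

    # cudaDevAttrMaxSharedMemoryPerBlockOptin has enum value 97 in the CUDA
    # runtime; vllm/utils.py queries it through the cuda_utils submodule.
    print(cuda_utils.get_device_attribute(97, torch.cuda.current_device()))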