Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-15 10:46:18 +08:00)

[Build] Avoid building too many extensions (#1624)

Commit e0c6f556e8, parent de23687d16
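This commit merges vLLM's per-kernel-family C++/CUDA extensions (vllm.cache_ops, vllm.attention_ops, vllm.pos_encoding_ops, vllm.layernorm_ops, vllm.activation_ops, vllm.quantization_ops, vllm.cuda_utils) into one compiled extension, vllm._C, and updates every call site. A minimal caller-side sketch of the change, assuming vLLM has been built at this commit (the submodule names are taken from the diff below):

    # Before this commit, each kernel family was its own compiled module, e.g.:
    # from vllm import attention_ops, cache_ops, layernorm_ops

    # After this commit, one extension re-exports the kernels through submodules.
    from vllm._C import ops, cache_ops, cuda_utils

    # The same kernels are still reachable, just under the new module.
    print(hasattr(ops, "paged_attention_v1"))  # True once vllm._C is built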
@@ -4,7 +4,7 @@ import time
 
 import torch
 
-from vllm import attention_ops
+from vllm._C import ops
 
 NUM_BLOCKS = 1024
 PARTITION_SIZE = 512
@@ -98,7 +98,7 @@ def main(
 
         for _ in range(num_iters):
             if version == "v1":
-                attention_ops.paged_attention_v1(
+                ops.paged_attention_v1(
                     output,
                     query,
                     key_cache,
@@ -112,7 +112,7 @@ def main(
                     alibi_slopes,
                 )
             elif version == "v2":
-                attention_ops.paged_attention_v2(
+                ops.paged_attention_v2(
                     output,
                     exp_sums,
                     max_logits,
csrc/activation.cpp (deleted)
@@ -1,28 +0,0 @@
-#include <torch/extension.h>
-
-void silu_and_mul(
-  torch::Tensor& out,
-  torch::Tensor& input);
-
-void gelu_new(
-  torch::Tensor& out,
-  torch::Tensor& input);
-
-void gelu_fast(
-  torch::Tensor& out,
-  torch::Tensor& input);
-
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def(
-    "silu_and_mul",
-    &silu_and_mul,
-    "Activation function used in SwiGLU.");
-  m.def(
-    "gelu_new",
-    &gelu_new,
-    "GELU implementation used in GPT-2.");
-  m.def(
-    "gelu_fast",
-    &gelu_fast,
-    "Approximate GELU implementation.");
-}
csrc/attention.cpp (deleted)
@@ -1,42 +0,0 @@
-#include <torch/extension.h>
-#include <c10/util/Optional.h>
-
-void paged_attention_v1(
-  torch::Tensor& out,
-  torch::Tensor& query,
-  torch::Tensor& key_cache,
-  torch::Tensor& value_cache,
-  torch::Tensor& head_mapping,
-  float scale,
-  torch::Tensor& block_tables,
-  torch::Tensor& context_lens,
-  int block_size,
-  int max_context_len,
-  const c10::optional<torch::Tensor>& alibi_slopes);
-
-void paged_attention_v2(
-  torch::Tensor& out,
-  torch::Tensor& exp_sums,
-  torch::Tensor& max_logits,
-  torch::Tensor& tmp_out,
-  torch::Tensor& query,
-  torch::Tensor& key_cache,
-  torch::Tensor& value_cache,
-  torch::Tensor& head_mapping,
-  float scale,
-  torch::Tensor& block_tables,
-  torch::Tensor& context_lens,
-  int block_size,
-  int max_context_len,
-  const c10::optional<torch::Tensor>& alibi_slopes);
-
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def(
-    "paged_attention_v1",
-    &paged_attention_v1,
-    "Compute the attention between an input query and the cached keys/values using PagedAttention.");
-  m.def(
-    "paged_attention_v2",
-    &paged_attention_v2,
-    "PagedAttention V2.");
-}
@@ -26,22 +26,3 @@ void gather_cached_kv(
   torch::Tensor& key_cache,
   torch::Tensor& value_cache,
   torch::Tensor& slot_mapping);
-
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def(
-    "swap_blocks",
-    &swap_blocks,
-    "Swap in (out) the cache blocks from src to dst");
-  m.def(
-    "copy_blocks",
-    &copy_blocks,
-    "Copy the cache blocks from src to dst");
-  m.def(
-    "reshape_and_cache",
-    &reshape_and_cache,
-    "Reshape the key and value tensors and cache them");
-  m.def(
-    "gather_cached_kv",
-    &gather_cached_kv,
-    "Gather key and value from the cache into contiguous QKV tensors");
-}
csrc/cuda_utils.cpp (deleted)
@@ -1,13 +0,0 @@
-#include <torch/extension.h>
-
-int get_device_attribute(
-  int attribute,
-  int device_id);
-
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def(
-    "get_device_attribute",
-    &get_device_attribute,
-    "Gets the specified device attribute.");
-}
-
csrc/cuda_utils.h (new file)
@@ -0,0 +1,5 @@
+#include <torch/extension.h>
+
+int get_device_attribute(
+  int attribute,
+  int device_id);
csrc/layernorm.cpp (deleted)
@@ -1,24 +0,0 @@
-#include <torch/extension.h>
-
-void rms_norm(
-  torch::Tensor& out,
-  torch::Tensor& input,
-  torch::Tensor& weight,
-  float epsilon);
-
-void fused_add_rms_norm(
-  torch::Tensor& input,
-  torch::Tensor& residual,
-  torch::Tensor& weight,
-  float epsilon);
-
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def(
-    "rms_norm",
-    &rms_norm,
-    "Apply Root Mean Square (RMS) Normalization to the input tensor.");
-  m.def(
-    "fused_add_rms_norm",
-    &fused_add_rms_norm,
-    "In-place fused Add and RMS Normalization");
-}
csrc/ops.h (new file)
@@ -0,0 +1,75 @@
+#include <torch/extension.h>
+
+void paged_attention_v1(
+  torch::Tensor& out,
+  torch::Tensor& query,
+  torch::Tensor& key_cache,
+  torch::Tensor& value_cache,
+  torch::Tensor& head_mapping,
+  float scale,
+  torch::Tensor& block_tables,
+  torch::Tensor& context_lens,
+  int block_size,
+  int max_context_len,
+  const c10::optional<torch::Tensor>& alibi_slopes);
+
+void paged_attention_v2(
+  torch::Tensor& out,
+  torch::Tensor& exp_sums,
+  torch::Tensor& max_logits,
+  torch::Tensor& tmp_out,
+  torch::Tensor& query,
+  torch::Tensor& key_cache,
+  torch::Tensor& value_cache,
+  torch::Tensor& head_mapping,
+  float scale,
+  torch::Tensor& block_tables,
+  torch::Tensor& context_lens,
+  int block_size,
+  int max_context_len,
+  const c10::optional<torch::Tensor>& alibi_slopes);
+
+void rms_norm(
+  torch::Tensor& out,
+  torch::Tensor& input,
+  torch::Tensor& weight,
+  float epsilon);
+
+void fused_add_rms_norm(
+  torch::Tensor& input,
+  torch::Tensor& residual,
+  torch::Tensor& weight,
+  float epsilon);
+
+void rotary_embedding(
+  torch::Tensor& positions,
+  torch::Tensor& query,
+  torch::Tensor& key,
+  int head_size,
+  torch::Tensor& cos_sin_cache,
+  bool is_neox);
+
+void silu_and_mul(
+  torch::Tensor& out,
+  torch::Tensor& input);
+
+void gelu_new(
+  torch::Tensor& out,
+  torch::Tensor& input);
+
+void gelu_fast(
+  torch::Tensor& out,
+  torch::Tensor& input);
+
+torch::Tensor awq_gemm(
+  torch::Tensor _in_feats,
+  torch::Tensor _kernel,
+  torch::Tensor _scaling_factors,
+  torch::Tensor _zeros,
+  int split_k_iters);
+
+void squeezellm_gemm(
+  torch::Tensor vec,
+  torch::Tensor mat,
+  torch::Tensor mul,
+  torch::Tensor lookup_table);
csrc/pos_encoding.cpp (deleted)
@@ -1,16 +0,0 @@
-#include <torch/extension.h>
-
-void rotary_embedding(
-  torch::Tensor& positions,
-  torch::Tensor& query,
-  torch::Tensor& key,
-  int head_size,
-  torch::Tensor& cos_sin_cache,
-  bool is_neox);
-
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def(
-    "rotary_embedding",
-    &rotary_embedding,
-    "Apply GPT-NeoX or GPT-J style rotary embedding to query and key");
-}
csrc/pybind.cpp (new file)
@@ -0,0 +1,80 @@
+#include "cache.h"
+#include "cuda_utils.h"
+#include "ops.h"
+#include <torch/extension.h>
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  // vLLM custom ops
+  pybind11::module ops = m.def_submodule("ops", "vLLM custom operators");
+
+  // Attention ops
+  ops.def(
+    "paged_attention_v1",
+    &paged_attention_v1,
+    "Compute the attention between an input query and the cached keys/values using PagedAttention.");
+  ops.def(
+    "paged_attention_v2",
+    &paged_attention_v2,
+    "PagedAttention V2.");
+
+  // Activation ops
+  ops.def(
+    "silu_and_mul",
+    &silu_and_mul,
+    "Activation function used in SwiGLU.");
+  ops.def(
+    "gelu_new",
+    &gelu_new,
+    "GELU implementation used in GPT-2.");
+  ops.def(
+    "gelu_fast",
+    &gelu_fast,
+    "Approximate GELU implementation.");
+
+  // Layernorm
+  ops.def(
+    "rms_norm",
+    &rms_norm,
+    "Apply Root Mean Square (RMS) Normalization to the input tensor.");
+
+  ops.def(
+    "fused_add_rms_norm",
+    &fused_add_rms_norm,
+    "In-place fused Add and RMS Normalization");
+
+  // Rotary embedding
+  ops.def(
+    "rotary_embedding",
+    &rotary_embedding,
+    "Apply GPT-NeoX or GPT-J style rotary embedding to query and key");
+
+  // Quantization ops
+  ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
+  ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");
+
+  // Cache ops
+  pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops");
+  cache_ops.def(
+    "swap_blocks",
+    &swap_blocks,
+    "Swap in (out) the cache blocks from src to dst");
+  cache_ops.def(
+    "copy_blocks",
+    &copy_blocks,
+    "Copy the cache blocks from src to dst");
+  cache_ops.def(
+    "reshape_and_cache",
+    &reshape_and_cache,
+    "Reshape the key and value tensors and cache them");
+  cache_ops.def(
+    "gather_cached_kv",
+    &gather_cached_kv,
+    "Gather key and value from the cache into contiguous QKV tensors");
+
+  // Cuda utils
+  pybind11::module cuda_utils = m.def_submodule("cuda_utils", "vLLM cuda utils");
+  cuda_utils.def(
+    "get_device_attribute",
+    &get_device_attribute,
+    "Gets the specified device attribute.");
+}
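The file above is the single pybind11 entry point for the consolidated extension; the kernel families become submodules (ops, cache_ops, cuda_utils) of vllm._C. A small smoke-test sketch of the resulting Python surface, assuming a CUDA build of vLLM at this commit and an available GPU (the tensor shapes and the 1e-6 epsilon are illustrative, not taken from the diff):

    import torch
    from vllm._C import ops

    x = torch.randn(4, 64, dtype=torch.float16, device="cuda")
    weight = torch.ones(64, dtype=torch.float16, device="cuda")
    out = torch.empty_like(x)
    # Signature follows csrc/ops.h above: rms_norm(out, input, weight, epsilon).
    ops.rms_norm(out, x, weight, 1e-6)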
csrc/quantization.cpp (deleted)
@@ -1,19 +0,0 @@
-#include <torch/extension.h>
-
-torch::Tensor awq_gemm(
-  torch::Tensor _in_feats,
-  torch::Tensor _kernel,
-  torch::Tensor _scaling_factors,
-  torch::Tensor _zeros,
-  int split_k_iters);
-
-void squeezellm_gemm(
-  torch::Tensor vec,
-  torch::Tensor mat,
-  torch::Tensor mul,
-  torch::Tensor lookup_table);
-
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
-  m.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");
-}
setup.py
@@ -142,87 +142,25 @@ if nvcc_cuda_version >= Version("11.2"):
     NVCC_FLAGS += ["--threads", str(num_threads)]
 
 ext_modules = []
-
-# Cache operations.
-cache_extension = CUDAExtension(
-    name="vllm.cache_ops",
-    sources=["csrc/cache.cpp", "csrc/cache_kernels.cu"],
-    extra_compile_args={
-        "cxx": CXX_FLAGS,
-        "nvcc": NVCC_FLAGS,
-    },
-)
-ext_modules.append(cache_extension)
-
-# Attention kernels.
-attention_extension = CUDAExtension(
-    name="vllm.attention_ops",
-    sources=["csrc/attention.cpp", "csrc/attention/attention_kernels.cu"],
-    extra_compile_args={
-        "cxx": CXX_FLAGS,
-        "nvcc": NVCC_FLAGS,
-    },
-)
-ext_modules.append(attention_extension)
-
-# Positional encoding kernels.
-positional_encoding_extension = CUDAExtension(
-    name="vllm.pos_encoding_ops",
-    sources=["csrc/pos_encoding.cpp", "csrc/pos_encoding_kernels.cu"],
-    extra_compile_args={
-        "cxx": CXX_FLAGS,
-        "nvcc": NVCC_FLAGS,
-    },
-)
-ext_modules.append(positional_encoding_extension)
-
-# Layer normalization kernels.
-layernorm_extension = CUDAExtension(
-    name="vllm.layernorm_ops",
-    sources=["csrc/layernorm.cpp", "csrc/layernorm_kernels.cu"],
-    extra_compile_args={
-        "cxx": CXX_FLAGS,
-        "nvcc": NVCC_FLAGS,
-    },
-)
-ext_modules.append(layernorm_extension)
-
-# Activation kernels.
-activation_extension = CUDAExtension(
-    name="vllm.activation_ops",
-    sources=["csrc/activation.cpp", "csrc/activation_kernels.cu"],
-    extra_compile_args={
-        "cxx": CXX_FLAGS,
-        "nvcc": NVCC_FLAGS,
-    },
-)
-ext_modules.append(activation_extension)
-
-# Quantization kernels.
-quantization_extension = CUDAExtension(
-    name="vllm.quantization_ops",
+vllm_extension = CUDAExtension(
+    name="vllm._C",
     sources=[
-        "csrc/quantization.cpp",
+        "csrc/cache_kernels.cu",
+        "csrc/attention/attention_kernels.cu",
+        "csrc/pos_encoding_kernels.cu",
+        "csrc/activation_kernels.cu",
+        "csrc/layernorm_kernels.cu",
         "csrc/quantization/awq/gemm_kernels.cu",
         "csrc/quantization/squeezellm/quant_cuda_kernel.cu",
+        "csrc/cuda_utils_kernels.cu",
+        "csrc/pybind.cpp",
     ],
     extra_compile_args={
        "cxx": CXX_FLAGS,
        "nvcc": NVCC_FLAGS,
     },
 )
-ext_modules.append(quantization_extension)
-
-# Misc. CUDA utils.
-cuda_utils_extension = CUDAExtension(
-    name="vllm.cuda_utils",
-    sources=["csrc/cuda_utils.cpp", "csrc/cuda_utils_kernels.cu"],
-    extra_compile_args={
-        "cxx": CXX_FLAGS,
-        "nvcc": NVCC_FLAGS,
-    },
-)
-ext_modules.append(cuda_utils_extension)
+ext_modules.append(vllm_extension)
 
 
 def get_path(*filepath) -> str:
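With the sources consolidated above, the build defines one CUDAExtension (vllm._C) instead of seven, which is what the commit title refers to. A quick sanity-check sketch after building (the glob pattern and the Linux-style .so suffix are assumptions, not part of the diff):

    import glob
    # Expect a single compiled artifact for vllm._C rather than per-family
    # modules such as vllm/attention_ops*.so or vllm/cache_ops*.so.
    print(glob.glob("vllm/_C*.so"))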
@@ -3,7 +3,7 @@ import torch
 import torch.nn.functional as F
 from transformers.activations import get_activation
 
-from vllm import activation_ops
+from vllm._C import ops
 
 DTYPES = [torch.half, torch.bfloat16, torch.float]
 NUM_TOKENS = [7, 83, 2048]  # Arbitrary values for testing
@@ -31,7 +31,7 @@ def test_silu_and_mul(
     torch.cuda.manual_seed(seed)
     x = torch.randn(num_tokens, 2 * d, dtype=dtype, device="cuda")
     out = torch.empty(num_tokens, d, dtype=dtype, device="cuda")
-    activation_ops.silu_and_mul(out, x)
+    ops.silu_and_mul(out, x)
     ref_out = ref_silu_and_mul(x)
     assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
 
@@ -51,7 +51,7 @@ def test_gelu_new(
     torch.cuda.manual_seed(seed)
     x = torch.randn(num_tokens, d, dtype=dtype, device="cuda")
     out = torch.empty(num_tokens, d, dtype=dtype, device="cuda")
-    activation_ops.gelu_new(out, x)
+    ops.gelu_new(out, x)
     ref_out = get_activation("gelu_new")(x)
     assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
 
@@ -70,6 +70,6 @@ def test_gelu_fast(
     torch.cuda.manual_seed(seed)
     x = torch.randn(num_tokens, d, dtype=dtype, device="cuda")
     out = torch.empty(num_tokens, d, dtype=dtype, device="cuda")
-    activation_ops.gelu_fast(out, x)
+    ops.gelu_fast(out, x)
     ref_out = get_activation("gelu_fast")(x)
     assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
@@ -6,7 +6,7 @@ import torch
 from xformers import ops as xops
 from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
 
-from vllm import attention_ops
+from vllm._C import ops
 from vllm.utils import get_max_shared_memory_bytes
 
 FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
@@ -165,7 +165,7 @@ def test_paged_attention(
     # Call the paged attention kernel.
     output = torch.empty_like(query)
     if version == "v1":
-        attention_ops.paged_attention_v1(
+        ops.paged_attention_v1(
             output,
             query,
             key_cache,
@@ -194,7 +194,7 @@ def test_paged_attention(
             device=output.device,
         )
         max_logits = torch.empty_like(exp_sums)
-        attention_ops.paged_attention_v2(
+        ops.paged_attention_v2(
             output,
             exp_sums,
             max_logits,
@@ -3,7 +3,7 @@ import random
 import pytest
 import torch
 
-from vllm import cache_ops
+from vllm._C import cache_ops
 
 DTYPES = [torch.half, torch.bfloat16, torch.float]
 NUM_TOKENS = [83]  # Arbitrary values for testing
@@ -2,7 +2,7 @@ import pytest
 import torch
 import torch.nn as nn
 
-from vllm import layernorm_ops
+from vllm._C import ops
 
 DTYPES = [torch.half, torch.bfloat16, torch.float]
 HIDDEN_SIZES = [67, 768, 2048, 5120, 8192]  # Arbitrary values for testing
@@ -48,7 +48,7 @@ def test_rms_norm(
     ref = RefRMSNorm(hidden_size).to(dtype).cuda()
 
     out = torch.empty_like(x)
-    layernorm_ops.rms_norm(
+    ops.rms_norm(
         out,
         x,
         ref.weight.data,
@@ -5,7 +5,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from vllm import pos_encoding_ops
+from vllm._C import ops
 
 IS_NEOX_STYLE = [True, False]
 DTYPES = [torch.half, torch.bfloat16, torch.float]
@@ -145,7 +145,7 @@ def test_rotary_embedding(
     # Run the kernel. The kernel is in-place, so we need to clone the inputs.
     out_query = query.clone()
     out_key = key.clone()
-    pos_encoding_ops.rotary_embedding(
+    ops.rotary_embedding(
         positions,
         out_query,
         out_key,
@@ -4,7 +4,7 @@ from typing import Optional
 import torch
 import torch.nn as nn
 
-from vllm import activation_ops
+from vllm._C import ops
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
@@ -26,7 +26,7 @@ class SiluAndMul(nn.Module):
         d = x.shape[-1] // 2
         output_shape = (x.shape[:-1] + (d, ))
         out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
-        activation_ops.silu_and_mul(out, x)
+        ops.silu_and_mul(out, x)
         return out
 
 
@@ -34,7 +34,7 @@ class NewGELU(nn.Module):
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         out = torch.empty_like(x)
-        activation_ops.gelu_new(out, x)
+        ops.gelu_new(out, x)
         return out
 
 
@@ -42,7 +42,7 @@ class FastGELU(nn.Module):
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         out = torch.empty_like(x)
-        activation_ops.gelu_fast(out, x)
+        ops.gelu_fast(out, x)
         return out
 
 
@@ -7,8 +7,8 @@ from xformers import ops as xops
 from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask,
                                          LowerTriangularMaskWithTensorBias)
 
-from vllm import attention_ops
-from vllm import cache_ops
+from vllm._C import ops
+from vllm._C import cache_ops
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.rotary_embedding import get_rope
 
@@ -163,7 +163,7 @@ class PagedAttention(nn.Module):
                 max_num_partitions == 1 or num_seqs * num_heads > 512)
             if use_v1:
                 # Run PagedAttention V1.
-                attention_ops.paged_attention_v1(
+                ops.paged_attention_v1(
                     output,
                     query,
                     key_cache,
@@ -190,7 +190,7 @@ class PagedAttention(nn.Module):
                     device=output.device,
                 )
                 max_logits = torch.empty_like(exp_sums)
-                attention_ops.paged_attention_v2(
+                ops.paged_attention_v2(
                     output,
                     exp_sums,
                     max_logits,
@@ -4,7 +4,7 @@ from typing import Optional, Tuple, Union
 import torch
 import torch.nn as nn
 
-from vllm import layernorm_ops
+from vllm._C import ops
 
 
 class RMSNorm(nn.Module):
@@ -29,7 +29,7 @@ class RMSNorm(nn.Module):
         residual: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
         if residual is not None:
-            layernorm_ops.fused_add_rms_norm(
+            ops.fused_add_rms_norm(
                 x,
                 residual,
                 self.weight.data,
@@ -37,7 +37,7 @@ class RMSNorm(nn.Module):
             )
             return x, residual
         out = torch.empty_like(x)
-        layernorm_ops.rms_norm(
+        ops.rms_norm(
             out,
             x,
             self.weight.data,
@@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional
 import torch
 from torch.nn.parameter import Parameter
 
-from vllm import quantization_ops
+from vllm._C import ops
 from vllm.model_executor.layers.linear import (LinearMethodBase,
                                                set_weight_attrs)
 from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
@@ -151,8 +151,7 @@ class AWQLinearMethod(LinearMethodBase):
         pack_factor = self.quant_config.pack_factor
         out_shape = (x.shape[:-1] + (qweight.shape[-1] * pack_factor, ))
         reshaped_x = x.reshape(-1, x.shape[-1])
-        out = quantization_ops.awq_gemm(reshaped_x, qweight, scales, qzeros,
-                                        pack_factor)
+        out = ops.awq_gemm(reshaped_x, qweight, scales, qzeros, pack_factor)
         if bias is not None:
             out = out + bias
         return out.reshape(out_shape)
@@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional
 import torch
 from torch.nn.parameter import Parameter
 
-from vllm import quantization_ops
+from vllm._C import ops
 from vllm.model_executor.layers.linear import (LinearMethodBase,
                                                set_weight_attrs)
 from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
@@ -116,8 +116,7 @@ class SqueezeLLMLinearMethod(LinearMethodBase):
         reshaped_x = x.reshape(-1, x.shape[-1])
         # NOTE: The output tensor should be zero-initialized.
         out = torch.zeros(out_shape, device="cuda", dtype=torch.float16)
-        quantization_ops.squeezellm_gemm(reshaped_x, qweight, out,
-                                         lookup_table)
+        ops.squeezellm_gemm(reshaped_x, qweight, out, lookup_table)
 
         if bias is not None:
             out = out + bias
@@ -27,7 +27,7 @@ from typing import Any, Dict, Optional, Tuple, Union
 import torch
 import torch.nn as nn
 
-from vllm import pos_encoding_ops
+from vllm._C import ops
 
 
 class RotaryEmbedding(nn.Module):
@@ -87,11 +87,10 @@ class RotaryEmbedding(nn.Module):
         query: torch.Tensor,
         key: torch.Tensor,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        # pos_encoding_ops.rotary_embedding() is an in-place operation that
+        # ops.rotary_embedding() is an in-place operation that
         # updates the query and key tensors.
-        pos_encoding_ops.rotary_embedding(positions, query, key,
-                                          self.head_size, self.cos_sin_cache,
-                                          self.is_neox_style)
+        ops.rotary_embedding(positions, query, key, self.head_size,
+                             self.cos_sin_cache, self.is_neox_style)
         return query, key
 
 
@@ -5,7 +5,7 @@ from platform import uname
 import psutil
 import torch
 
-from vllm import cuda_utils
+from vllm._C import cuda_utils
 
 
 class Device(enum.Enum):
@@ -3,7 +3,7 @@ from typing import Dict, List, Tuple
 
 import torch
 
-from vllm import cache_ops
+from vllm._C import cache_ops
 from vllm.config import CacheConfig, ModelConfig, ParallelConfig
 from vllm.logger import init_logger
 from vllm.utils import in_wsl