Add custom kernel for RMS normalization (#16)
This commit is contained in:
parent c45f3c3ab6
commit 09e9245478
cacheflow/models/layernorm.py  (new file, +26)
@@ -0,0 +1,26 @@
import torch
import torch.nn as nn

from cacheflow import layernorm_ops


class RMSNorm(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
    ) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = torch.empty_like(x)
        layernorm_ops.rms_norm(
            out,
            x,
            self.weight.data,
            self.variance_epsilon,
        )
        return out
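For context, a minimal usage sketch of the new RMSNorm module (not part of the diff; it assumes the cacheflow.layernorm_ops extension has been built and a CUDA device is available; the tensor sizes are hypothetical):

import torch

from cacheflow.models.layernorm import RMSNorm

# Hypothetical sizes; the kernel operates on a [num_tokens, hidden_size] input.
num_tokens, hidden_size = 8, 4096
norm = RMSNorm(hidden_size).half().cuda()
x = torch.randn(num_tokens, hidden_size, dtype=torch.half, device='cuda')
y = norm(x)  # same shape and dtype as x

Note that the op writes into a preallocated output tensor and defines no backward, so this path is inference-only.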
@@ -12,6 +12,7 @@ from transformers import LlamaConfig
 
 from cacheflow.models import InputMetadata
 from cacheflow.models.attention import LlamaCacheFlowAttention
+from cacheflow.models.layernorm import RMSNorm
 from cacheflow.models.sample import Sampler
 from cacheflow.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
@@ -23,22 +24,6 @@ from cacheflow.sequence import SequenceOutputs
 KVCache = Tuple[torch.Tensor, torch.Tensor]
 
 
-class LlamaRMSNorm(nn.Module):
-
-    def __init__(self, hidden_size, eps=1e-6):
-        super().__init__()
-        self.weight = nn.Parameter(torch.ones(hidden_size))
-        self.variance_epsilon = eps
-
-    def forward(self, hidden_states):
-        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
-        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
-        # convert into half-precision if necessary
-        if self.weight.dtype in [torch.float16, torch.bfloat16]:
-            hidden_states = hidden_states.to(self.weight.dtype)
-        return self.weight * hidden_states
-
-
 class LlamaMLP(nn.Module):
 
     def __init__(
@@ -148,8 +133,8 @@ class LlamaDecoderLayer(nn.Module):
             intermediate_size=config.intermediate_size,
             hidden_act=config.hidden_act,
         )
-        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
     def forward(
         self,
@@ -190,7 +175,7 @@ class LlamaModel(nn.Module):
         self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size,
                                                    perform_initialization=False)
         self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
-        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
     def forward(
         self,
@@ -3,6 +3,7 @@
 
 #include "attention_utils.h"
 #include "cuda_primitives.h"
+#include "reduction_utils.h"
 
 #include <algorithm>
 
@@ -159,45 +159,6 @@ struct Qk_dot<uint16_t, 4> {
   }
 };
 
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
-template<int WARPS_PER_BLOCK, int WARP_SIZE = 32>
-inline __device__ float block_sum(float* red_smem, float sum)
-{
-
-  // Decompose the thread index into warp / lane.
-  int warp = threadIdx.x / WARP_SIZE;
-  int lane = threadIdx.x % WARP_SIZE;
-
-  // Compute the sum per warp.
-#pragma unroll
-  for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
-    sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
-  }
-
-  // Warp leaders store the data to shared memory.
-  if (lane == 0) {
-    red_smem[warp] = sum;
-  }
-
-  // Make sure the data is in shared memory.
-  __syncthreads();
-
-  // The warps compute the final sums.
-  if (lane < WARPS_PER_BLOCK) {
-    sum = red_smem[lane];
-  }
-
-  // Parallel reduction inside the warp.
-#pragma unroll
-  for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) {
-    sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
-  }
-
-  // Broadcast to other threads.
-  return __shfl_sync(uint32_t(-1), sum, 0);
-}
-
 } // namespace cacheflow
 
 #undef MMHA_USE_FP32_ACUM_FOR_FMA
csrc/layernorm.cpp  (new file, +14)
@@ -0,0 +1,14 @@
#include <torch/extension.h>

void rms_norm(
  torch::Tensor& out,
  torch::Tensor& input,
  torch::Tensor& weight,
  float epsilon);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def(
    "rms_norm",
    &rms_norm,
    "Apply Root Mean Square (RMS) Normalization to the input tensor.");
}
csrc/layernorm_kernels.cu  (new file, +61)
@@ -0,0 +1,61 @@
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>

#include "reduction_utils.h"

namespace cacheflow {

// TODO(woosuk): Further optimize this kernel.
template<typename scalar_t>
__global__ void rms_norm_kernel(
  scalar_t* __restrict__ out,           // [num_tokens, hidden_size]
  const scalar_t* __restrict__ input,   // [num_tokens, hidden_size]
  const scalar_t* __restrict__ weight,  // [hidden_size]
  const float epsilon,
  const int num_tokens,
  const int hidden_size) {
  __shared__ float s_variance;
  float variance = 0.0f;

  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    const float x = (float) input[blockIdx.x * hidden_size + idx];
    variance += x * x;
  }
  variance = blockReduceSum<float>(variance);
  if (threadIdx.x == 0) {
    s_variance = rsqrtf(variance / hidden_size + epsilon);
  }
  __syncthreads();

  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    float x = (float) input[blockIdx.x * hidden_size + idx];
    out[blockIdx.x * hidden_size + idx] = ((scalar_t) (x * s_variance)) * weight[idx];
  }
}

} // namespace cacheflow

void rms_norm(
  torch::Tensor& out,     // [num_tokens, hidden_size]
  torch::Tensor& input,   // [num_tokens, hidden_size]
  torch::Tensor& weight,  // [hidden_size]
  float epsilon) {
  int num_tokens = input.size(0);
  int hidden_size = input.size(1);

  dim3 grid(num_tokens);
  dim3 block(std::min(hidden_size, 1024));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
    input.scalar_type(),
    "rms_norm_kernel",
    [&] {
      cacheflow::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
        out.data_ptr<scalar_t>(),
        input.data_ptr<scalar_t>(),
        weight.data_ptr<scalar_t>(),
        epsilon,
        num_tokens,
        hidden_size);
    });
}
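As a reference (not part of the commit), the per-token computation this kernel performs can be written with plain PyTorch ops as below; like the kernel, it accumulates the mean of squares in float32 and casts back to the input dtype before applying the weight:

import torch

def rms_norm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # x: [num_tokens, hidden_size], weight: [hidden_size]
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)  # mean(x^2) per token, in fp32
    inv_rms = torch.rsqrt(variance + eps)                   # 1 / sqrt(mean(x^2) + eps)
    return (x.float() * inv_rms).to(x.dtype) * weight       # cast back, then scale by weight

One thread block handles one token; the strided loops over hidden_size let a block of at most 1024 threads cover hidden sizes larger than the block.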
csrc/reduction_utils.h  (new file, +76)
@@ -0,0 +1,76 @@
#pragma once

namespace cacheflow {

template<int WARPS_PER_BLOCK, int WARP_SIZE = 32>
inline __device__ float block_sum(float* red_smem, float sum)
{

  // Decompose the thread index into warp / lane.
  int warp = threadIdx.x / WARP_SIZE;
  int lane = threadIdx.x % WARP_SIZE;

  // Compute the sum per warp.
#pragma unroll
  for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
    sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
  }

  // Warp leaders store the data to shared memory.
  if (lane == 0) {
    red_smem[warp] = sum;
  }

  // Make sure the data is in shared memory.
  __syncthreads();

  // The warps compute the final sums.
  if (lane < WARPS_PER_BLOCK) {
    sum = red_smem[lane];
  }

  // Parallel reduction inside the warp.
#pragma unroll
  for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) {
    sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
  }

  // Broadcast to other threads.
  return __shfl_sync(uint32_t(-1), sum, 0);
}

#define FINAL_MASK 0xffffffff

template<typename T>
__inline__ __device__ T warpReduceSum(T val)
{
#pragma unroll
  for (int mask = 16; mask > 0; mask >>= 1)
    val += __shfl_xor_sync(FINAL_MASK, val, mask, 32);
  return val;
}

/* Calculate the sum of all elements in a block */
template<typename T>
__inline__ __device__ T blockReduceSum(T val)
{
  static __shared__ T shared[32];
  int lane = threadIdx.x & 0x1f;
  int wid = threadIdx.x >> 5;

  val = warpReduceSum<T>(val);

  if (lane == 0)
    shared[wid] = val;

  __syncthreads();

  // Use blockDim.x / 32.f (rather than an integer shift) so that the last,
  // partial warp is still counted when blockDim.x is not divisible by 32.
  val = (threadIdx.x < (blockDim.x / 32.f)) ? shared[lane] : (T)(0.0f);
  val = warpReduceSum<T>(val);

  return val;
}

} // namespace cacheflow
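A rough NumPy sketch (not part of the commit) of the two-stage pattern implemented by warpReduceSum/blockReduceSum: each warp of 32 threads first reduces its own values, warp leaders deposit the partial sums into shared memory, and a final warp-level reduction combines the partials:

import numpy as np

def block_reduce_sum_emulated(vals: np.ndarray, warp_size: int = 32) -> float:
    # vals holds one value per thread of the block (any block size).
    num_warps = (len(vals) + warp_size - 1) // warp_size
    # Stage 1: intra-warp reduction (what warpReduceSum does with __shfl_xor_sync).
    partials = [vals[w * warp_size:(w + 1) * warp_size].sum() for w in range(num_warps)]
    # Stage 2: the first warp reduces the per-warp partials held in shared memory.
    return float(np.sum(partials))

The shuffle-based reduction avoids shared-memory traffic inside a warp; shared memory is only needed to exchange the at most 32 per-warp partial sums.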
setup.py  (+8)
@@ -31,6 +31,14 @@ positional_encoding_extension = cpp_extension.CUDAExtension(
 )
 ext_modules.append(positional_encoding_extension)
 
+# Layer normalization kernels.
+layernorm_extension = cpp_extension.CUDAExtension(
+    name='cacheflow.layernorm_ops',
+    sources=['csrc/layernorm.cpp', 'csrc/layernorm_kernels.cu'],
+    extra_compile_args={'cxx': CXX_FLAGS, 'nvcc': NVCC_FLAGS},
+)
+ext_modules.append(layernorm_extension)
+
 setuptools.setup(
     name='cacheflow',
     ext_modules=ext_modules,
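After rebuilding the package (for example with pip install -e .), a quick sanity check that the new extension is importable and exposes the op might look like this (a sketch, not part of the commit):

# Assumes the cacheflow package has been rebuilt so that cacheflow.layernorm_ops exists.
from cacheflow import layernorm_ops

assert hasattr(layernorm_ops, 'rms_norm')
print(layernorm_ops.rms_norm.__doc__)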
tests/kernels/layernorm.py  (new file, +53)
@@ -0,0 +1,53 @@
import torch
import torch.nn as nn

from cacheflow import layernorm_ops


class RefRMSNorm(nn.Module):

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        weight = torch.randn(hidden_size) / (hidden_size ** 0.5)
        self.weight = nn.Parameter(weight)
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        if self.weight.dtype in [torch.half, torch.float16, torch.bfloat16]:
            hidden_states = hidden_states.to(self.weight.dtype)
        return self.weight * hidden_states


@torch.inference_mode()
def test_rms_norm(
    num_tokens: int,
    hidden_size: int,
    dtype: torch.dtype,
) -> None:
    x = torch.randn(num_tokens, hidden_size, dtype=dtype, device='cuda')
    ref = RefRMSNorm(hidden_size).to(dtype).cuda()

    out = torch.empty_like(x)
    layernorm_ops.rms_norm(
        out,
        x,
        ref.weight.data,
        ref.variance_epsilon,
    )
    ref_out = ref(x)
    assert torch.allclose(out, ref_out, atol=1e-3, rtol=1e-5)


if __name__ == '__main__':
    for dtype in [torch.half, torch.float]:
        for num_tokens in [7, 128, 2048]:
            for hidden_size in [13, 64, 1024, 5120]:
                print(f'Testing RMS kernel with dtype={dtype}, num_tokens='
                      f'{num_tokens}, hidden_size={hidden_size}')
                test_rms_norm(
                    num_tokens=num_tokens,
                    hidden_size=hidden_size,
                    dtype=dtype,
                )
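The test drives test_rms_norm from a __main__ loop; a pytest-parametrized front end would be a natural alternative. A sketch (not part of the commit; it assumes the file sits next to tests/kernels/layernorm.py so the check can be imported, and it still needs a CUDA device):

import pytest
import torch

# Hypothetical sibling import; reuses the check defined in tests/kernels/layernorm.py.
from layernorm import test_rms_norm as run_rms_norm_check

@pytest.mark.parametrize('dtype', [torch.half, torch.float])
@pytest.mark.parametrize('num_tokens', [7, 128, 2048])
@pytest.mark.parametrize('hidden_size', [13, 64, 1024, 5120])
def test_rms_norm_parametrized(num_tokens: int, hidden_size: int, dtype: torch.dtype) -> None:
    run_rms_norm_check(num_tokens=num_tokens, hidden_size=hidden_size, dtype=dtype)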