Add custom kernel for RMS normalization (#16)
parent c45f3c3ab6
commit 09e9245478

cacheflow/models/layernorm.py  (new file, 26 lines)
@@ -0,0 +1,26 @@
import torch
import torch.nn as nn

from cacheflow import layernorm_ops


class RMSNorm(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
    ) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = torch.empty_like(x)
        layernorm_ops.rms_norm(
            out,
            x,
            self.weight.data,
            self.variance_epsilon,
        )
        return out
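
Note for readers of the diff: the module above delegates all of the arithmetic to the fused layernorm_ops.rms_norm op, so the formula never appears in this file. As a reading aid, the op computes the same thing as the Hugging Face-style LlamaRMSNorm removed from the model file below, with the squared-mean reduction carried out in fp32. A minimal PyTorch reference, illustrative only and not part of the commit:

    # Illustrative PyTorch reference for what layernorm_ops.rms_norm computes.
    import torch

    def rms_norm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
        variance = x.float().pow(2).mean(dim=-1, keepdim=True)   # reduce in fp32
        return weight * (x * torch.rsqrt(variance + eps)).to(x.dtype)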

@@ -12,6 +12,7 @@ from transformers import LlamaConfig
 
 from cacheflow.models import InputMetadata
 from cacheflow.models.attention import LlamaCacheFlowAttention
+from cacheflow.models.layernorm import RMSNorm
 from cacheflow.models.sample import Sampler
 from cacheflow.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
@@ -23,22 +24,6 @@ from cacheflow.sequence import SequenceOutputs
 KVCache = Tuple[torch.Tensor, torch.Tensor]
 
 
-class LlamaRMSNorm(nn.Module):
-
-    def __init__(self, hidden_size, eps=1e-6):
-        super().__init__()
-        self.weight = nn.Parameter(torch.ones(hidden_size))
-        self.variance_epsilon = eps
-
-    def forward(self, hidden_states):
-        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
-        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
-        # convert into half-precision if necessary
-        if self.weight.dtype in [torch.float16, torch.bfloat16]:
-            hidden_states = hidden_states.to(self.weight.dtype)
-        return self.weight * hidden_states
-
-
 class LlamaMLP(nn.Module):
 
     def __init__(
@@ -148,8 +133,8 @@ class LlamaDecoderLayer(nn.Module):
             intermediate_size=config.intermediate_size,
             hidden_act=config.hidden_act,
         )
-        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
     def forward(
         self,
@@ -190,7 +175,7 @@ class LlamaModel(nn.Module):
         self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size,
                                                    perform_initialization=False)
         self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
-        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
     def forward(
         self,

@@ -3,6 +3,7 @@
 
 #include "attention_utils.h"
 #include "cuda_primitives.h"
+#include "reduction_utils.h"
 
 #include <algorithm>
 
@@ -159,45 +159,6 @@ struct Qk_dot<uint16_t, 4> {
   }
 };
 
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
-template<int WARPS_PER_BLOCK, int WARP_SIZE = 32>
-inline __device__ float block_sum(float* red_smem, float sum)
-{
-
-  // Decompose the thread index into warp / lane.
-  int warp = threadIdx.x / WARP_SIZE;
-  int lane = threadIdx.x % WARP_SIZE;
-
-  // Compute the sum per warp.
-#pragma unroll
-  for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
-    sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
-  }
-
-  // Warp leaders store the data to shared memory.
-  if (lane == 0) {
-    red_smem[warp] = sum;
-  }
-
-  // Make sure the data is in shared memory.
-  __syncthreads();
-
-  // The warps compute the final sums.
-  if (lane < WARPS_PER_BLOCK) {
-    sum = red_smem[lane];
-  }
-
-  // Parallel reduction inside the warp.
-#pragma unroll
-  for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) {
-    sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
-  }
-
-  // Broadcast to other threads.
-  return __shfl_sync(uint32_t(-1), sum, 0);
-}
-
 } // namespace cacheflow
 
 #undef MMHA_USE_FP32_ACUM_FOR_FMA

csrc/layernorm.cpp  (new file, 14 lines)
@@ -0,0 +1,14 @@
#include <torch/extension.h>

void rms_norm(
  torch::Tensor& out,
  torch::Tensor& input,
  torch::Tensor& weight,
  float epsilon);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def(
    "rms_norm",
    &rms_norm,
    "Apply Root Mean Square (RMS) Normalization to the input tensor.");
}
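
The binding follows an out-parameter convention: rms_norm fills a caller-allocated out tensor in place and returns nothing. A minimal call sketch from Python, with example shapes, assuming the extension has been built and a GPU is available:

    # Illustrative call sketch for the bound op (not part of the commit).
    import torch
    from cacheflow import layernorm_ops

    x = torch.randn(4, 16, dtype=torch.half, device='cuda')    # [num_tokens, hidden_size]
    weight = torch.ones(16, dtype=torch.half, device='cuda')   # [hidden_size]
    out = torch.empty_like(x)    # preallocated; written in place by the kernel
    layernorm_ops.rms_norm(out, x, weight, 1e-6)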

csrc/layernorm_kernels.cu  (new file, 61 lines)
@@ -0,0 +1,61 @@
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>

#include "reduction_utils.h"

namespace cacheflow {

// TODO(woosuk): Further optimize this kernel.
template<typename scalar_t>
__global__ void rms_norm_kernel(
  scalar_t* __restrict__ out,             // [num_tokens, hidden_size]
  const scalar_t* __restrict__ input,     // [num_tokens, hidden_size]
  const scalar_t* __restrict__ weight,    // [hidden_size]
  const float epsilon,
  const int num_tokens,
  const int hidden_size) {
  __shared__ float s_variance;
  float variance = 0.0f;

  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    const float x = (float) input[blockIdx.x * hidden_size + idx];
    variance += x * x;
  }
  variance = blockReduceSum<float>(variance);
  if (threadIdx.x == 0) {
    s_variance = rsqrtf(variance / hidden_size + epsilon);
  }
  __syncthreads();

  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    float x = (float) input[blockIdx.x * hidden_size + idx];
    out[blockIdx.x * hidden_size + idx] = ((scalar_t) (x * s_variance)) * weight[idx];
  }
}

} // namespace cacheflow

void rms_norm(
  torch::Tensor& out,      // [num_tokens, hidden_size]
  torch::Tensor& input,    // [num_tokens, hidden_size]
  torch::Tensor& weight,   // [hidden_size]
  float epsilon) {
  int num_tokens = input.size(0);
  int hidden_size = input.size(1);

  dim3 grid(num_tokens);
  dim3 block(std::min(hidden_size, 1024));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
    input.scalar_type(),
    "rms_norm_kernel",
    [&] {
      cacheflow::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
        out.data_ptr<scalar_t>(),
        input.data_ptr<scalar_t>(),
        weight.data_ptr<scalar_t>(),
        epsilon,
        num_tokens,
        hidden_size);
    });
}
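
On the launch configuration: the host wrapper launches one thread block per token with at most 1024 threads, so each thread strides over its row in steps of blockDim.x and the per-row sum of squares is combined with blockReduceSum. A small sketch of how that mapping works out for one of the shapes used in the test below, illustrative and not part of the commit:

    # Illustrative sketch of the kernel's work decomposition for an example shape.
    num_tokens, hidden_size = 2048, 5120       # example shape, taken from the test sweep
    grid_dim = num_tokens                      # dim3 grid(num_tokens): one block per token
    block_dim = min(hidden_size, 1024)         # dim3 block(std::min(hidden_size, 1024))
    iters_per_thread = -(-hidden_size // block_dim)   # ceil division: strided loop passes per thread
    print(grid_dim, block_dim, iters_per_thread)      # 2048 1024 5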

csrc/reduction_utils.h  (new file, 76 lines)
@@ -0,0 +1,76 @@
#pragma once

namespace cacheflow {

template<int WARPS_PER_BLOCK, int WARP_SIZE = 32>
inline __device__ float block_sum(float* red_smem, float sum)
{

  // Decompose the thread index into warp / lane.
  int warp = threadIdx.x / WARP_SIZE;
  int lane = threadIdx.x % WARP_SIZE;

  // Compute the sum per warp.
#pragma unroll
  for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
    sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
  }

  // Warp leaders store the data to shared memory.
  if (lane == 0) {
    red_smem[warp] = sum;
  }

  // Make sure the data is in shared memory.
  __syncthreads();

  // The warps compute the final sums.
  if (lane < WARPS_PER_BLOCK) {
    sum = red_smem[lane];
  }

  // Parallel reduction inside the warp.
#pragma unroll
  for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) {
    sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
  }

  // Broadcast to other threads.
  return __shfl_sync(uint32_t(-1), sum, 0);
}

#define FINAL_MASK 0xffffffff

template<typename T>
__inline__ __device__ T warpReduceSum(T val)
{
#pragma unroll
  for (int mask = 16; mask > 0; mask >>= 1)
    val += __shfl_xor_sync(FINAL_MASK, val, mask, 32);
  return val;
}

/* Calculate the sum of all elements in a block */
template<typename T>
__inline__ __device__ T blockReduceSum(T val)
{
  static __shared__ T shared[32];
  int lane = threadIdx.x & 0x1f;
  int wid = threadIdx.x >> 5;

  val = warpReduceSum<T>(val);

  if (lane == 0)
    shared[wid] = val;

  __syncthreads();

  // Use blockDim.x / 32.f here so the guard also works when blockDim.x is not
  // a multiple of 32.
  val = (threadIdx.x < (blockDim.x / 32.f)) ? shared[lane] : (T)(0.0f);
  val = warpReduceSum<T>(val);

  return val;
}

} // namespace cacheflow
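
blockReduceSum combines per-thread values in two stages: each warp reduces its own 32 values with shuffle instructions, the lane-0 partial sums are staged in shared memory, and a final warp-level reduction sums those partials. A plain-Python model of that structure, illustrative and not part of the commit; the real code uses __shfl_xor_sync rather than Python sums:

    # Illustrative Python model of the two-level reduction in blockReduceSum.
    def block_reduce_sum(values, warp_size=32):
        # Stage 1: each warp reduces its own slice of values.
        warp_sums = [sum(values[i:i + warp_size]) for i in range(0, len(values), warp_size)]
        # Stage 2: reduce the per-warp partial sums (another warp-level reduction on the GPU).
        return sum(warp_sums)

    print(block_reduce_sum([1.0] * 1024))   # 1024.0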

setup.py  (8 lines added)
@@ -31,6 +31,14 @@ positional_encoding_extension = cpp_extension.CUDAExtension(
 )
 ext_modules.append(positional_encoding_extension)
 
+# Layer normalization kernels.
+layernorm_extension = cpp_extension.CUDAExtension(
+    name='cacheflow.layernorm_ops',
+    sources=['csrc/layernorm.cpp', 'csrc/layernorm_kernels.cu'],
+    extra_compile_args={'cxx': CXX_FLAGS, 'nvcc': NVCC_FLAGS},
+)
+ext_modules.append(layernorm_extension)
+
 setuptools.setup(
     name='cacheflow',
     ext_modules=ext_modules,

tests/kernels/layernorm.py  (new file, 53 lines)
@@ -0,0 +1,53 @@
import torch
import torch.nn as nn

from cacheflow import layernorm_ops


class RefRMSNorm(nn.Module):

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        weight = torch.randn(hidden_size) / (hidden_size ** 0.5)
        self.weight = nn.Parameter(weight)
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        if self.weight.dtype in [torch.half, torch.float16, torch.bfloat16]:
            hidden_states = hidden_states.to(self.weight.dtype)
        return self.weight * hidden_states


@torch.inference_mode()
def test_rms_norm(
    num_tokens: int,
    hidden_size: int,
    dtype: torch.dtype,
) -> None:
    x = torch.randn(num_tokens, hidden_size, dtype=dtype, device='cuda')
    ref = RefRMSNorm(hidden_size).to(dtype).cuda()

    out = torch.empty_like(x)
    layernorm_ops.rms_norm(
        out,
        x,
        ref.weight.data,
        ref.variance_epsilon,
    )
    ref_out = ref(x)
    assert torch.allclose(out, ref_out, atol=1e-3, rtol=1e-5)


if __name__ == '__main__':
    for dtype in [torch.half, torch.float]:
        for num_tokens in [7, 128, 2048]:
            for hidden_size in [13, 64, 1024, 5120]:
                print(f'Testing RMS kernel with dtype={dtype}, num_tokens='
                      f'{num_tokens}, hidden_size={hidden_size}')
                test_rms_norm(
                    num_tokens=num_tokens,
                    hidden_size=hidden_size,
                    dtype=dtype,
                )