Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2026-01-26 12:44:30 +08:00)
Vectorize RMS norm variance using vectorize_read_with_alignment (#26234)
Signed-off-by: Benji Beck <benjibeck@meta.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
parent de92d916fe
commit 1f491aa0c8
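In outline: both RMS-norm kernels previously read each row element by element to accumulate the sum of squares; this commit routes that read through vllm::vectorize_read_with_alignment, which takes a vectorized callback (vec_op, operating on vec_n_t<scalar_t, 8> chunks) and a scalar fallback (scalar_op) for elements that do not fill a chunk. A rough standalone sketch of the access pattern follows; vec8_t and sum_squares_strided are simplified stand-ins, not vLLM's helpers, and the 8-wide load assumes the row pointer is suitably aligned (the real helper checks alignment at runtime):

#include <cuda_runtime.h>

// Simplified stand-in for vec_n_t<T, 8>: eight contiguous elements read as one.
template <typename T>
struct vec8_t {
  T val[8];
};

// Per-thread sum of squares over one row: 8 elements at a time, scalar tail last.
template <typename scalar_t>
__device__ float sum_squares_strided(const scalar_t* row, int n) {
  float acc = 0.0f;
  // Main body: thread t owns chunks t, t + blockDim.x, t + 2 * blockDim.x, ...
  for (int i = threadIdx.x * 8; i + 8 <= n; i += blockDim.x * 8) {
    vec8_t<scalar_t> v = *reinterpret_cast<const vec8_t<scalar_t>*>(row + i);
#pragma unroll
    for (int j = 0; j < 8; ++j) {
      float x = static_cast<float>(v.val[j]);
      acc += x * x;
    }
  }
  // Scalar tail: the final n % 8 elements that do not fill a chunk.
  for (int i = (n / 8) * 8 + threadIdx.x; i < n; i += blockDim.x) {
    float x = static_cast<float>(row[i]);
    acc += x * x;
  }
  return acc;
}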
diff --git a/csrc/layernorm_kernels.cu b/csrc/layernorm_kernels.cu
@@ -2,6 +2,7 @@
 #include "dispatch_utils.h"
 #include "cub_helpers.h"
 #include "core/batch_invariant.hpp"
+#include "quantization/vectorization_utils.cuh"
 
 #include <torch/cuda.h>
 #include <c10/cuda/CUDAGuard.h>
@@ -18,11 +19,22 @@ __global__ void rms_norm_kernel(
     const float epsilon, const int num_tokens, const int hidden_size) {
   __shared__ float s_variance;
   float variance = 0.0f;
+  const scalar_t* input_row = input + blockIdx.x * input_stride;
 
-  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
-    const float x = (float)input[blockIdx.x * input_stride + idx];
-    variance += x * x;
-  }
+  constexpr int VEC_SIZE = 8;
+  auto vec_op = [&variance](const vec_n_t<scalar_t, VEC_SIZE>& vec) {
+#pragma unroll
+    for (int i = 0; i < VEC_SIZE; ++i) {
+      float x = static_cast<float>(vec.val[i]);
+      variance += x * x;
+    }
+  };
+  auto scalar_op = [&variance](const scalar_t& val) {
+    float x = static_cast<float>(val);
+    variance += x * x;
+  };
+  vllm::vectorize_read_with_alignment<VEC_SIZE>(
+      input_row, hidden_size, threadIdx.x, blockDim.x, vec_op, scalar_op);
 
   using BlockReduce = cub::BlockReduce<float, 1024>;
   __shared__ typename BlockReduce::TempStorage reduceStore;
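The trailing context lines show where the accumulated variance goes next: the pre-existing CUB block reduction. A hedged paraphrase of that continuation, using the declarations visible above (the exact reduce op in vLLM comes from cub_helpers.h and may differ from plain Sum):

  // After per-thread accumulation, reduce across the block and derive the
  // inverse RMS once; every thread then reads it back from shared memory.
  variance = BlockReduce(reduceStore).Sum(variance);
  if (threadIdx.x == 0) {
    s_variance = rsqrtf(variance / hidden_size + epsilon);
  }
  __syncthreads();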
diff --git a/csrc/layernorm_quant_kernels.cu b/csrc/layernorm_quant_kernels.cu
@@ -10,6 +10,7 @@
 #include "dispatch_utils.h"
 #include "cub_helpers.h"
 #include "core/batch_invariant.hpp"
+#include "quantization/vectorization_utils.cuh"
 
 #include <torch/cuda.h>
 #include <c10/cuda/CUDAGuard.h>
@@ -28,10 +29,22 @@ __global__ void rms_norm_static_fp8_quant_kernel(
   __shared__ float s_variance;
   float variance = 0.0f;
 
-  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
-    const float x = (float)input[blockIdx.x * input_stride + idx];
-    variance += x * x;
-  }
+  const scalar_t* input_row = input + blockIdx.x * input_stride;
+
+  constexpr int VEC_SIZE = 8;
+  auto vec_op = [&variance](const vec_n_t<scalar_t, VEC_SIZE>& vec) {
+#pragma unroll
+    for (int i = 0; i < VEC_SIZE; ++i) {
+      float x = static_cast<float>(vec.val[i]);
+      variance += x * x;
+    }
+  };
+  auto scalar_op = [&variance](const scalar_t& val) {
+    float x = static_cast<float>(val);
+    variance += x * x;
+  };
+  vllm::vectorize_read_with_alignment<VEC_SIZE>(
+      input_row, hidden_size, threadIdx.x, blockDim.x, vec_op, scalar_op);
 
   using BlockReduce = cub::BlockReduce<float, 1024>;
   __shared__ typename BlockReduce::TempStorage reduceStore;
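One property worth noting: within a single traversal, vectorizing the loads changes only how values are fetched, not which values are squared and summed. A hypothetical standalone check of that load-path equivalence (not part of the commit or vLLM's test suite):

#include <cstdio>
#include <cuda_runtime.h>

struct Vec8 {
  float val[8];
};

// A single thread compares the 8-wide load path against the plain scalar loop.
__global__ void check_sum_squares(const float* row, int n, float* out) {
  float vec_acc = 0.0f, ref_acc = 0.0f;
  for (int i = 0; i + 8 <= n; i += 8) {
    Vec8 v = *reinterpret_cast<const Vec8*>(row + i);
    for (int j = 0; j < 8; ++j) vec_acc += v.val[j] * v.val[j];
  }
  for (int i = (n / 8) * 8; i < n; ++i) vec_acc += row[i] * row[i];  // tail
  for (int i = 0; i < n; ++i) ref_acc += row[i] * row[i];  // scalar reference
  out[0] = vec_acc;
  out[1] = ref_acc;
}

int main() {
  const int n = 1000;  // deliberately not a multiple of 8, to exercise the tail
  float *row, *out;
  cudaMallocManaged(&row, n * sizeof(float));
  cudaMallocManaged(&out, 2 * sizeof(float));
  for (int i = 0; i < n; ++i) row[i] = 0.01f * i;
  check_sum_squares<<<1, 1>>>(row, n, out);
  cudaDeviceSynchronize();
  printf("vectorized=%f scalar=%f\n", out[0], out[1]);  // identical order, identical sums
  cudaFree(row);
  cudaFree(out);
  return 0;
}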