diff --git a/csrc/layernorm_kernels.cu b/csrc/layernorm_kernels.cu index 6c3685f6f7cdc..aa7927f09cbbf 100644 --- a/csrc/layernorm_kernels.cu +++ b/csrc/layernorm_kernels.cu @@ -2,6 +2,7 @@ #include "dispatch_utils.h" #include "cub_helpers.h" #include "core/batch_invariant.hpp" +#include "quantization/vectorization_utils.cuh" #include #include @@ -18,11 +19,22 @@ __global__ void rms_norm_kernel( const float epsilon, const int num_tokens, const int hidden_size) { __shared__ float s_variance; float variance = 0.0f; + const scalar_t* input_row = input + blockIdx.x * input_stride; - for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { - const float x = (float)input[blockIdx.x * input_stride + idx]; + constexpr int VEC_SIZE = 8; + auto vec_op = [&variance](const vec_n_t& vec) { +#pragma unroll + for (int i = 0; i < VEC_SIZE; ++i) { + float x = static_cast(vec.val[i]); + variance += x * x; + } + }; + auto scalar_op = [&variance](const scalar_t& val) { + float x = static_cast(val); variance += x * x; - } + }; + vllm::vectorize_read_with_alignment( + input_row, hidden_size, threadIdx.x, blockDim.x, vec_op, scalar_op); using BlockReduce = cub::BlockReduce; __shared__ typename BlockReduce::TempStorage reduceStore; diff --git a/csrc/layernorm_quant_kernels.cu b/csrc/layernorm_quant_kernels.cu index 0fc462194fcde..7f9a0bccdd348 100644 --- a/csrc/layernorm_quant_kernels.cu +++ b/csrc/layernorm_quant_kernels.cu @@ -10,6 +10,7 @@ #include "dispatch_utils.h" #include "cub_helpers.h" #include "core/batch_invariant.hpp" +#include "quantization/vectorization_utils.cuh" #include #include @@ -28,10 +29,22 @@ __global__ void rms_norm_static_fp8_quant_kernel( __shared__ float s_variance; float variance = 0.0f; - for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { - const float x = (float)input[blockIdx.x * input_stride + idx]; + const scalar_t* input_row = input + blockIdx.x * input_stride; + + constexpr int VEC_SIZE = 8; + auto vec_op = [&variance](const vec_n_t& vec) { +#pragma unroll + for (int i = 0; i < VEC_SIZE; ++i) { + float x = static_cast(vec.val[i]); + variance += x * x; + } + }; + auto scalar_op = [&variance](const scalar_t& val) { + float x = static_cast(val); variance += x * x; - } + }; + vllm::vectorize_read_with_alignment( + input_row, hidden_size, threadIdx.x, blockDim.x, vec_op, scalar_op); using BlockReduce = cub::BlockReduce; __shared__ typename BlockReduce::TempStorage reduceStore;