diff --git a/csrc/dispatch_utils.h b/csrc/dispatch_utils.h
index 995374a50b037..9ae0ed975edde 100644
--- a/csrc/dispatch_utils.h
+++ b/csrc/dispatch_utils.h
@@ -88,3 +88,32 @@
 #define VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(TYPE, NAME, ...) \
   AT_DISPATCH_SWITCH(                                              \
       TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_AND_UNSIGNED_TYPES(__VA_ARGS__))
+
+#define VLLM_DISPATCH_VEC_SIZE(VEC_SIZE, ...) \
+  switch (VEC_SIZE) {                         \
+    case 16: {                                \
+      constexpr int vec_size = 16;            \
+      __VA_ARGS__();                          \
+      break;                                  \
+    }                                         \
+    case 8: {                                 \
+      constexpr int vec_size = 8;             \
+      __VA_ARGS__();                          \
+      break;                                  \
+    }                                         \
+    case 4: {                                 \
+      constexpr int vec_size = 4;             \
+      __VA_ARGS__();                          \
+      break;                                  \
+    }                                         \
+    case 2: {                                 \
+      constexpr int vec_size = 2;             \
+      __VA_ARGS__();                          \
+      break;                                  \
+    }                                         \
+    default: {                                \
+      constexpr int vec_size = 1;             \
+      __VA_ARGS__();                          \
+      break;                                  \
+    }                                         \
+  }
diff --git a/csrc/layernorm_kernels.cu b/csrc/layernorm_kernels.cu
index 8cfcf9f41283a..48771e4b3aff9 100644
--- a/csrc/layernorm_kernels.cu
+++ b/csrc/layernorm_kernels.cu
@@ -10,7 +10,7 @@
 namespace vllm {
 
 // TODO(woosuk): Further optimize this kernel.
-template <typename scalar_t>
+template <typename scalar_t, int VEC_SIZE>
 __global__ void rms_norm_kernel(
     scalar_t* __restrict__ out,          // [..., hidden_size]
     const scalar_t* __restrict__ input,  // [..., hidden_size]
@@ -21,7 +21,6 @@ __global__ void rms_norm_kernel(
   float variance = 0.0f;
 
   const scalar_t* input_row = input + blockIdx.x * input_stride;
-  constexpr int VEC_SIZE = 8;
   auto vec_op = [&variance](const vec_n_t<scalar_t, VEC_SIZE>& vec) {
 #pragma unroll
     for (int i = 0; i < VEC_SIZE; ++i) {
@@ -45,10 +44,20 @@
   }
   __syncthreads();
 
-  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
-    float x = (float)input[blockIdx.x * input_stride + idx];
-    out[blockIdx.x * hidden_size + idx] =
-        ((scalar_t)(x * s_variance)) * weight[idx];
+  scalar_t* out_row = out + blockIdx.x * hidden_size;
+  auto* v_in = reinterpret_cast<const vec_n_t<scalar_t, VEC_SIZE>*>(input_row);
+  auto* v_w = reinterpret_cast<const vec_n_t<scalar_t, VEC_SIZE>*>(weight);
+  auto* v_out = reinterpret_cast<vec_n_t<scalar_t, VEC_SIZE>*>(out_row);
+  for (int i = threadIdx.x; i < hidden_size / VEC_SIZE; i += blockDim.x) {
+    vec_n_t<scalar_t, VEC_SIZE> dst;
+    vec_n_t<scalar_t, VEC_SIZE> src1 = v_in[i];
+    vec_n_t<scalar_t, VEC_SIZE> src2 = v_w[i];
+#pragma unroll
+    for (int j = 0; j < VEC_SIZE; j++) {
+      float x = static_cast<float>(src1.val[j]);
+      dst.val[j] = ((scalar_t)(x * s_variance)) * src2.val[j];
+    }
+    v_out[i] = dst;
   }
 }
 
@@ -168,16 +177,24 @@ void rms_norm(torch::Tensor& out,    // [..., hidden_size]
 
   int num_tokens = input_view.numel() / hidden_size;
   int64_t input_stride = input_view.stride(-2);
+  // For large num_tokens, use smaller blocks to increase SM concurrency.
+  const int max_block_size = (num_tokens < 256) ? 1024 : 256;
   dim3 grid(num_tokens);
-  dim3 block(std::min(hidden_size, 1024));
   const at::cuda::OptionalCUDAGuard device_guard(device_of(input_view));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   VLLM_DISPATCH_FLOATING_TYPES(
       input_view.scalar_type(), "rms_norm_kernel", [&] {
-        vllm::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
-            out.data_ptr<scalar_t>(), input_view.data_ptr<scalar_t>(),
-            input_stride, weight.data_ptr<scalar_t>(), epsilon, num_tokens,
-            hidden_size);
+        const int calculated_vec_size =
+            std::gcd(16 / sizeof(scalar_t), hidden_size);
+        const int block_size =
+            std::min(hidden_size / calculated_vec_size, max_block_size);
+        dim3 block(block_size);
+        VLLM_DISPATCH_VEC_SIZE(calculated_vec_size, [&] {
+          vllm::rms_norm_kernel<scalar_t, vec_size><<<grid, block, 0, stream>>>(
+              out.data_ptr<scalar_t>(), input_view.data_ptr<scalar_t>(),
+              input_stride, weight.data_ptr<scalar_t>(), epsilon, num_tokens,
+              hidden_size);
+        });
       });
 }
 
diff --git a/csrc/layernorm_quant_kernels.cu b/csrc/layernorm_quant_kernels.cu
index 0f7f034ee180b..0880b8d50a795 100644
--- a/csrc/layernorm_quant_kernels.cu
+++ b/csrc/layernorm_quant_kernels.cu
@@ -18,7 +18,7 @@
 namespace vllm {
 
 // TODO(woosuk): Further optimize this kernel.
-template <typename scalar_t, typename fp8_type>
+template <typename scalar_t, typename fp8_type, int VEC_SIZE>
 __global__ void rms_norm_static_fp8_quant_kernel(
     fp8_type* __restrict__ out,          // [..., hidden_size]
     const scalar_t* __restrict__ input,  // [..., hidden_size]
@@ -31,7 +31,6 @@ __global__ void rms_norm_static_fp8_quant_kernel(
 
   const scalar_t* input_row = input + blockIdx.x * input_stride;
 
-  constexpr int VEC_SIZE = 8;
   auto vec_op = [&variance](const vec_n_t<scalar_t, VEC_SIZE>& vec) {
 #pragma unroll
     for (int i = 0; i < VEC_SIZE; ++i) {
@@ -58,11 +57,18 @@
   // invert scale to avoid division
   float const scale_inv = 1.0f / *scale;
 
-  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
-    float x = (float)input[blockIdx.x * input_stride + idx];
-    float const out_norm = ((scalar_t)(x * s_variance)) * weight[idx];
-    out[blockIdx.x * hidden_size + idx] =
-        scaled_fp8_conversion<true, fp8_type>(out_norm, scale_inv);
+  auto* v_in = reinterpret_cast<const vec_n_t<scalar_t, VEC_SIZE>*>(input_row);
+  auto* v_w = reinterpret_cast<const vec_n_t<scalar_t, VEC_SIZE>*>(weight);
+  for (int idx = threadIdx.x; idx < hidden_size / VEC_SIZE; idx += blockDim.x) {
+    vec_n_t<scalar_t, VEC_SIZE> src1 = v_in[idx];
+    vec_n_t<scalar_t, VEC_SIZE> src2 = v_w[idx];
+#pragma unroll
+    for (int j = 0; j < VEC_SIZE; j++) {
+      float x = static_cast<float>(src1.val[j]);
+      float const out_norm = ((scalar_t)(x * s_variance)) * src2.val[j];
+      out[blockIdx.x * hidden_size + idx * VEC_SIZE + j] =
+          scaled_fp8_conversion<true, fp8_type>(out_norm, scale_inv);
+    }
   }
 }
 
@@ -188,20 +194,29 @@ void rms_norm_static_fp8_quant(torch::Tensor& out,  // [..., hidden_size]
 
   int input_stride = input.stride(-2);
   int num_tokens = input.numel() / hidden_size;
+  // For large num_tokens, use smaller blocks to increase SM concurrency.
+  const int max_block_size = (num_tokens < 256) ? 1024 : 256;
   dim3 grid(num_tokens);
-  dim3 block(std::min(hidden_size, 1024));
   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   VLLM_DISPATCH_FLOATING_TYPES(
       input.scalar_type(), "rms_norm_kernel_scalar_type", [&] {
         VLLM_DISPATCH_FP8_TYPES(
             out.scalar_type(), "rms_norm_kernel_fp8_type", [&] {
-              vllm::rms_norm_static_fp8_quant_kernel<scalar_t, fp8_t>
-                  <<<grid, block, 0, stream>>>(
-                      out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
-                      input_stride, weight.data_ptr<scalar_t>(),
-                      scale.data_ptr<float>(), epsilon, num_tokens,
-                      hidden_size);
+              const int calculated_vec_size =
+                  std::gcd(16 / sizeof(scalar_t), hidden_size);
+              const int block_size =
+                  std::min(hidden_size / calculated_vec_size, max_block_size);
+              dim3 block(block_size);
+              VLLM_DISPATCH_VEC_SIZE(calculated_vec_size, [&] {
+                vllm::rms_norm_static_fp8_quant_kernel<scalar_t, fp8_t,
+                                                       vec_size>
+                    <<<grid, block, 0, stream>>>(
+                        out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
+                        input_stride, weight.data_ptr<scalar_t>(),
+                        scale.data_ptr<float>(), epsilon, num_tokens,
+                        hidden_size);
+              });
             });
       });
 }
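
For reference, a minimal host-only sketch of how the new VLLM_DISPATCH_VEC_SIZE macro maps a runtime vector width onto the compile-time constant vec_size that the kernels above take as a template argument. The macro body mirrors the diff (reformatted); print_width, main, and the use of short as a stand-in for a 2-byte element type are illustrative assumptions, not vLLM code.

// Host-only sketch: VLLM_DISPATCH_VEC_SIZE turns a runtime value into the
// compile-time constant vec_size visible inside the dispatched lambda.
#include <cstdio>
#include <numeric>

#define VLLM_DISPATCH_VEC_SIZE(VEC_SIZE, ...)                       \
  switch (VEC_SIZE) {                                               \
    case 16: { constexpr int vec_size = 16; __VA_ARGS__(); break; } \
    case 8:  { constexpr int vec_size = 8;  __VA_ARGS__(); break; } \
    case 4:  { constexpr int vec_size = 4;  __VA_ARGS__(); break; } \
    case 2:  { constexpr int vec_size = 2;  __VA_ARGS__(); break; } \
    default: { constexpr int vec_size = 1;  __VA_ARGS__(); break; } \
  }

// Hypothetical stand-in for a kernel template: just prints the chosen width.
template <int VEC_SIZE>
void print_width() { std::printf("vec_size = %d\n", VEC_SIZE); }

int main() {
  const int hidden_size = 4096;  // divisible by 8 for a 2-byte element type
  // Same selection rule as the diff: the widest vector (up to 16 bytes)
  // whose element count also divides hidden_size.
  const int runtime_vec =
      std::gcd(static_cast<int>(16 / sizeof(short)), hidden_size);  // -> 8
  VLLM_DISPATCH_VEC_SIZE(runtime_vec, [&] { print_width<vec_size>(); });
  return 0;
}

With a 2-byte dtype and hidden_size divisible by 8, the gcd picks 8 elements (16 bytes) per access; the default branch (vec_size 1) is only reached when hidden_size shares no larger power-of-two factor with the 16-byte vector width.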