diff --git a/csrc/dispatch_utils.h b/csrc/dispatch_utils.h
index 9ae0ed975edd..e1d131e4a785 100644
--- a/csrc/dispatch_utils.h
+++ b/csrc/dispatch_utils.h
@@ -117,3 +117,24 @@
       break;                                                                 \
     }                                                                        \
   }
+
+#define VLLM_DISPATCH_RANK234(NUM_DIMS, ...)                                 \
+  switch (NUM_DIMS) {                                                        \
+    case 2: {                                                                \
+      constexpr int tensor_rank = 2;                                         \
+      __VA_ARGS__();                                                         \
+      break;                                                                 \
+    }                                                                        \
+    case 3: {                                                                \
+      constexpr int tensor_rank = 3;                                         \
+      __VA_ARGS__();                                                         \
+      break;                                                                 \
+    }                                                                        \
+    case 4: {                                                                \
+      constexpr int tensor_rank = 4;                                         \
+      __VA_ARGS__();                                                         \
+      break;                                                                 \
+    }                                                                        \
+    default:                                                                 \
+      TORCH_CHECK(false, "Expects rank 2, 3 or 4 tensors but got ", NUM_DIMS); \
+  }
diff --git a/csrc/layernorm_kernels.cu b/csrc/layernorm_kernels.cu
index 48771e4b3aff..dfc67b933cca 100644
--- a/csrc/layernorm_kernels.cu
+++ b/csrc/layernorm_kernels.cu
@@ -10,16 +10,38 @@
 namespace vllm {
 
 // TODO(woosuk): Further optimize this kernel.
-template <typename scalar_t, int VEC_SIZE>
+template <typename scalar_t, int VEC_SIZE, int NUM_DIMS>
 __global__ void rms_norm_kernel(
-    scalar_t* __restrict__ out,          // [..., hidden_size]
-    const scalar_t* __restrict__ input,  // [..., hidden_size]
-    const int64_t input_stride,
+    scalar_t* __restrict__ out,           // [..., hidden_size]
+    const scalar_t* __restrict__ input,   // [..., hidden_size]
+    const int64_t input_stride_d2,        // input.stride(-2)
+    const int64_t input_stride_d3,        // input.stride(-3)
+    const int64_t input_stride_d4,        // input.stride(-4)
+    const int64_t input_shape_d2,         // input.size(-2)
+    const int64_t input_shape_d3,         // input.size(-3)
     const scalar_t* __restrict__ weight,  // [hidden_size]
     const float epsilon, const int num_tokens, const int hidden_size) {
   __shared__ float s_variance;
   float variance = 0.0f;
-  const scalar_t* input_row = input + blockIdx.x * input_stride;
+  const scalar_t* input_row;
+  if constexpr (NUM_DIMS == 2) {
+    // 2D for layernorm normal case [batch_size, hidden]
+    input_row = input + blockIdx.x * input_stride_d2;
+  } else if constexpr (NUM_DIMS == 3) {
+    // 3D for q/k norm [batch_size, num_heads, head_size]
+    int batch_idx = blockIdx.x / input_shape_d2;
+    int head_idx = blockIdx.x % input_shape_d2;
+    input_row =
+        input + batch_idx * input_stride_d3 + head_idx * input_stride_d2;
+  } else if constexpr (NUM_DIMS == 4) {
+    // 4D for transformers model_impl qk norm [batch, seq, head, head_dim]
+    int batch_idx = blockIdx.x / (input_shape_d3 * input_shape_d2);
+    int remaining = blockIdx.x % (input_shape_d3 * input_shape_d2);
+    int seq_idx = remaining / input_shape_d2;
+    int head_idx = remaining % input_shape_d2;
+    input_row = input + batch_idx * input_stride_d4 +
+                seq_idx * input_stride_d3 + head_idx * input_stride_d2;
+  }
 
   auto vec_op = [&variance](const vec_n_t<scalar_t, VEC_SIZE>& vec) {
 #pragma unroll
@@ -164,38 +186,44 @@ void rms_norm(torch::Tensor& out,    // [..., hidden_size]
               torch::Tensor& weight,  // [hidden_size]
               double epsilon) {
   TORCH_CHECK(out.is_contiguous());
+  if (input.stride(-1) != 1) {
+    input = input.contiguous();
+  }
   TORCH_CHECK(input.stride(-1) == 1);
   TORCH_CHECK(weight.is_contiguous());
 
   int hidden_size = input.size(-1);
 
-  // We cannot just use `input.stride(-2)` if the tensor is not row-major.
-  // Instead, we use a 2d view to get the second-innermost stride.
-  // That way the dimensions (except the last one) can be arbitrarily permuted.
-  torch::Tensor input_view = input.view({-1, hidden_size});
-
-  int num_tokens = input_view.numel() / hidden_size;
-  int64_t input_stride = input_view.stride(-2);
+  int num_tokens = input.numel() / hidden_size;
+  int num_dims = input.dim();
+  int64_t input_stride_d2 = input.stride(-2);
+  int64_t input_stride_d3 = (num_dims >= 3) ? input.stride(-3) : 0;
+  int64_t input_stride_d4 = (num_dims >= 4) ? input.stride(-4) : 0;
+  int64_t input_shape_d2 = (num_dims >= 3) ? input.size(-2) : 0;
+  int64_t input_shape_d3 = (num_dims >= 4) ? input.size(-3) : 0;
 
   // For large num_tokens, use smaller blocks to increase SM concurrency.
   const int max_block_size = (num_tokens < 256) ? 1024 : 256;
   dim3 grid(num_tokens);
-  const at::cuda::OptionalCUDAGuard device_guard(device_of(input_view));
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  VLLM_DISPATCH_FLOATING_TYPES(
-      input_view.scalar_type(), "rms_norm_kernel", [&] {
-        const int calculated_vec_size =
-            std::gcd(16 / sizeof(scalar_t), hidden_size);
-        const int block_size =
-            std::min(hidden_size / calculated_vec_size, max_block_size);
-        dim3 block(block_size);
-        VLLM_DISPATCH_VEC_SIZE(calculated_vec_size, [&] {
-          vllm::rms_norm_kernel<scalar_t, vec_size><<<grid, block, 0, stream>>>(
-              out.data_ptr<scalar_t>(), input_view.data_ptr<scalar_t>(),
-              input_stride, weight.data_ptr<scalar_t>(), epsilon, num_tokens,
-              hidden_size);
-        });
+  VLLM_DISPATCH_RANK234(num_dims, [&] {
+    VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] {
+      const int calculated_vec_size =
+          std::gcd(16 / sizeof(scalar_t), hidden_size);
+      const int block_size =
+          std::min(hidden_size / calculated_vec_size, max_block_size);
+      dim3 block(block_size);
+      VLLM_DISPATCH_VEC_SIZE(calculated_vec_size, [&] {
+        vllm::rms_norm_kernel<scalar_t, vec_size, tensor_rank>
+            <<<grid, block, 0, stream>>>(
+                out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
+                input_stride_d2, input_stride_d3, input_stride_d4,
+                input_shape_d2, input_shape_d3, weight.data_ptr<scalar_t>(),
+                epsilon, num_tokens, hidden_size);
       });
+    });
+  });
 }
 
 #define LAUNCH_FUSED_ADD_RMS_NORM(width) \
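Aside on the rank-3 branch above: the kernel launches one block per normalized row, decomposes blockIdx.x into (batch_idx, head_idx) via input.size(-2), and offsets the base pointer by batch_idx * stride(-3) + head_idx * stride(-2). The sketch below is illustrative only (not part of the patch; shapes and variable names are invented) and reproduces that addressing on the host with torch strides:

# Illustrative sketch (not part of the diff): the rank-3 row addressing used by
# rms_norm_kernel, reproduced on the host with torch strides. Shapes are made up.
import torch

batch_size, num_heads, head_size = 2, 3, 4
base = torch.arange(batch_size * num_heads * 2 * head_size, dtype=torch.float32)
base = base.view(batch_size, num_heads, 2 * head_size)
x = base[..., :head_size]  # non-contiguous view, but stride(-1) == 1

stride_d2, stride_d3 = x.stride(-2), x.stride(-3)
shape_d2 = x.size(-2)  # num_heads
flat = base.reshape(-1)  # stands in for the raw `input` pointer

for block in range(batch_size * num_heads):  # one CUDA block per (batch, head) row
    batch_idx, head_idx = divmod(block, shape_d2)
    row_start = x.storage_offset() + batch_idx * stride_d3 + head_idx * stride_d2
    row = flat[row_start : row_start + head_size]
    assert torch.equal(row, x[batch_idx, head_idx])

Because only stride(-1) == 1 is required, the same arithmetic covers q/k views sliced out of a fused QKV buffer without a copy; the rank-4 branch just adds one more level of the same decomposition.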
diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index 66cf6472eee4..0f625a794524 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -328,10 +328,7 @@ def rotary_embedding(
 def rms_norm(
     out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, epsilon: float
 ) -> None:
-    # TODO: Remove this contiguous call when the kernel is updated to support non-contiguous input
-    # If removed, also need to remove contiguous in MatcherRMSNorm
-    input_contiguous = input.contiguous()
-    torch.ops._C.rms_norm(out, input_contiguous, weight, epsilon)
+    torch.ops._C.rms_norm(out, input, weight, epsilon)
 
 
 def fused_add_rms_norm(
diff --git a/vllm/compilation/matcher_utils.py b/vllm/compilation/matcher_utils.py
index 38eb4e5301a1..e4cd063d2aee 100644
--- a/vllm/compilation/matcher_utils.py
+++ b/vllm/compilation/matcher_utils.py
@@ -162,12 +162,10 @@ class MatcherRMSNorm(MatcherCustomOp):
         weight: torch.Tensor,
     ) -> torch.Tensor:
         result = torch.empty_like(input)
-        # TODO: support non-contiguous input for RMSNorm and remove this
-        input_contiguous = input.contiguous()
         _, result = auto_functionalized(
             RMS_OP,
             result=result,
-            input=input_contiguous,
+            input=input,
             weight=weight,
             epsilon=self.epsilon,
         )
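For context, a minimal usage sketch of the updated Python wrapper on a strided q/k view; it is illustrative only (shapes are invented) and assumes a CUDA build of vLLM that exposes torch.ops._C.rms_norm:

# Illustrative only: q/k-norm style call on a non-contiguous view, without .contiguous().
# Shapes and names are made up for the example; requires a CUDA build of vLLM.
import torch
from vllm import _custom_ops as ops

num_tokens, num_heads, head_size = 8, 4, 64
# Slicing the fused QKV projection output yields a strided (non-contiguous) view
# whose last-dim stride is still 1 -- exactly the rank-3 case the kernel now handles.
qkv = torch.randn(
    num_tokens, 3 * num_heads * head_size, dtype=torch.float16, device="cuda"
)
q = qkv[:, : num_heads * head_size].view(num_tokens, num_heads, head_size)
weight = torch.ones(head_size, dtype=torch.float16, device="cuda")

out = torch.empty(q.shape, dtype=q.dtype, device=q.device)  # contiguous output
ops.rms_norm(out, q, weight, 1e-6)  # previously required q.contiguous()

Before this change, the wrapper forced input.contiguous(), so the call above would have materialized a copy of q first; now only inputs with stride(-1) != 1 fall back to the copy path in the C++ entry point.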