diff --git a/csrc/layernorm_kernels.cu b/csrc/layernorm_kernels.cu index 732ed4760f35..3a8f9bc3b5a6 100644 --- a/csrc/layernorm_kernels.cu +++ b/csrc/layernorm_kernels.cu @@ -364,18 +364,26 @@ void rms_norm(torch::Tensor& out, // [..., hidden_size] TORCH_CHECK(weight.is_contiguous()); int hidden_size = input.size(-1); - int num_tokens = input.numel() / hidden_size; - int64_t input_stride = input.stride(-2); + + // We cannot just use `input.stride(-2)` if the tensor is not row-major. + // Instead, we use a 2d view to get the second-innermost stride. + // That way the dimensions (except the last one) can be arbitrarily permuted. + torch::Tensor input_view = input.view({-1, hidden_size}); + + int num_tokens = input_view.numel() / hidden_size; + int64_t input_stride = input_view.stride(-2); dim3 grid(num_tokens); dim3 block(std::min(hidden_size, 1024)); - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); + const at::cuda::OptionalCUDAGuard device_guard(device_of(input_view)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] { - vllm::rms_norm_kernel<<>>( - out.data_ptr(), input.data_ptr(), input_stride, - weight.data_ptr(), epsilon, num_tokens, hidden_size); - }); + VLLM_DISPATCH_FLOATING_TYPES( + input_view.scalar_type(), "rms_norm_kernel", [&] { + vllm::rms_norm_kernel<<>>( + out.data_ptr(), input_view.data_ptr(), + input_stride, weight.data_ptr(), epsilon, num_tokens, + hidden_size); + }); } #define LAUNCH_FUSED_ADD_RMS_NORM(width) \