From b10c64c8344693500bab14d19ef246c8299b2e15 Mon Sep 17 00:00:00 2001 From: rasmith Date: Fri, 17 Oct 2025 13:17:18 -0500 Subject: [PATCH] [ROCm][Bugfix][Model] Fix illegal memory access when running qwen3_moe models with rms_norm (Qwen3-235B-A22B, Qwen3-30B-A3B, etc.) (#26192) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Randall Smith Signed-off-by: Randall Smith Signed-off-by: rasmith Co-authored-by: Randall Smith Co-authored-by: Luka Govedič --- csrc/layernorm_kernels.cu | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/csrc/layernorm_kernels.cu b/csrc/layernorm_kernels.cu index 732ed4760f35..3a8f9bc3b5a6 100644 --- a/csrc/layernorm_kernels.cu +++ b/csrc/layernorm_kernels.cu @@ -364,18 +364,26 @@ void rms_norm(torch::Tensor& out, // [..., hidden_size] TORCH_CHECK(weight.is_contiguous()); int hidden_size = input.size(-1); - int num_tokens = input.numel() / hidden_size; - int64_t input_stride = input.stride(-2); + + // We cannot just use `input.stride(-2)` if the tensor is not row-major. + // Instead, we use a 2d view to get the second-innermost stride. + // That way the dimensions (except the last one) can be arbitrarily permuted. 
+ torch::Tensor input_view = input.view({-1, hidden_size}); + + int num_tokens = input_view.numel() / hidden_size; + int64_t input_stride = input_view.stride(-2); dim3 grid(num_tokens); dim3 block(std::min(hidden_size, 1024)); - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); + const at::cuda::OptionalCUDAGuard device_guard(device_of(input_view)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] { - vllm::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>( - out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), input_stride, - weight.data_ptr<scalar_t>(), epsilon, num_tokens, hidden_size); - }); + VLLM_DISPATCH_FLOATING_TYPES( + input_view.scalar_type(), "rms_norm_kernel", [&] { + vllm::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>( + out.data_ptr<scalar_t>(), input_view.data_ptr<scalar_t>(), + input_stride, weight.data_ptr<scalar_t>(), epsilon, num_tokens, + hidden_size); + }); } #define LAUNCH_FUSED_ADD_RMS_NORM(width) \