From 348de41b5253318b4a578d7d972511aebb50f7dc Mon Sep 17 00:00:00 2001
From: c0de128
Date: Wed, 24 Dec 2025 08:45:44 -0600
Subject: [PATCH] [Bugfix][Hardware][AMD] Fix uninitialized Qlocal registers in
 ROCm attention kernel

In the ROCm PagedAttention wmma kernel, when GQA_RATIO == 1, only lane 0
loads valid Q data into the Qlocal registers. Lanes 1-15 keep whatever
values those registers held before. These uninitialized values are then
fed to the wmma (Wave Matrix Multiply-Accumulate) instruction along with
the valid data, which can cause subtle numerical accuracy issues.

The bug exists in two locations within
paged_attention_ll4mi_QKV_mfma16_kernel:
1. Lines 1834-1842: _B16x16 Qlocal for 16-bit cache types
2. Lines 2610-2617: _B16x8 Qlocal for 8-bit cache types

Both locations have an `if (lane16id < GQA_RATIO)` block that loads Q
data but lack an `else` clause to zero out Qlocal for non-loading lanes.

The correct pattern already exists elsewhere in the file (lines
1067-1070), where unused Qlocal slots are explicitly zeroed:

```cpp
} else {
  Qlocal[QHLOOP - 1].xy[0] = {0};
  Qlocal[QHLOOP - 1].xy[1] = {0};
}
```

This fix adds the missing `else` clauses so that lanes which do not load
Q data zero their Qlocal registers, preventing garbage values from
propagating into the attention score computation.
Impact:
- Affects non-GQA models (GQA_RATIO == 1) like Llama-2
- Symptom: Random numerical drift, potential NaNs in softmax
- Fix ensures deterministic behavior across all wave lanes

Signed-off-by: c0de128
---
 csrc/rocm/attention.cu | 14 ++++++++++++++
 1 file changed, 14 insertions(+)

diff --git a/csrc/rocm/attention.cu b/csrc/rocm/attention.cu
index a339c5641bb4a..78a8daebfcffb 100644
--- a/csrc/rocm/attention.cu
+++ b/csrc/rocm/attention.cu
@@ -1839,6 +1839,13 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
             reinterpret_cast(q_fetch_ptr);
         Qlocal[qkhe_depth] = *q_fetch_ptr_32B;
       }
+    } else {
+      // Zero out Qlocal for lanes that don't load Q data to prevent
+      // uninitialized register values from contaminating wmma results
+      #pragma unroll
+      for (int qkhe_depth = 0; qkhe_depth < QKHELOOP / 2; qkhe_depth++) {
+        Qlocal[qkhe_depth] = {};
+      }
     }
   } else {
     // fetch Q in shared across warps and then write to registers
@@ -2608,6 +2615,13 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
             reinterpret_cast(q_fetch_ptr);
         Qlocal[qkhe_depth] = *q_fetch_ptr_16B;
       }
+    } else {
+      // Zero out Qlocal for lanes that don't load Q data to prevent
+      // uninitialized register values from contaminating wmma results
+      #pragma unroll
+      for (int qkhe_depth = 0; qkhe_depth < QKHELOOP; qkhe_depth++) {
+        Qlocal[qkhe_depth] = {};
+      }
     }
   } else {
     // fetch Q in shared across warps and then write to registers