[Bugfix][Hardware][AMD] Fix uninitialized Qlocal registers in ROCm attention kernel

In the ROCm PagedAttention wmma kernel, when GQA_RATIO == 1, only lane 0
loads valid Q data into the Qlocal registers. Lanes 1-15 are never written
and keep whatever values those registers held before. These uninitialized
values then feed the wmma (Wave Matrix Multiply-Accumulate) instruction and
contaminate its results, causing subtle numerical accuracy issues.
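
The failure mode can be illustrated with a host-side sketch. This is not the
kernel code: `tile_matmul`, `make_q`, `make_k`, and `count_non_finite` are
hypothetical helpers, a plain 16x16 float product stands in for one wmma tile
operation, and NaN stands in for whatever stale bits a register might hold.
It shows only that unwritten input rows produce non-finite output rows; how
the real kernel consumes those rows is not modeled here.

```cpp
#include <array>
#include <cassert>
#include <cmath>
#include <limits>

constexpr int kTile = 16;  // one row per lane16id, as an illustration
using Tile = std::array<std::array<float, kTile>, kTile>;

// Plain 16x16 matrix product, standing in for one wmma tile operation.
Tile tile_matmul(const Tile& a, const Tile& b) {
  Tile c{};
  for (int i = 0; i < kTile; ++i)
    for (int j = 0; j < kTile; ++j)
      for (int k = 0; k < kTile; ++k)
        c[i][j] += a[i][k] * b[k][j];
  return c;
}

// Q tile whose first `valid_rows` rows hold real data (1.0f);
// every other row is filled with `filler` (stale or zeroed).
Tile make_q(int valid_rows, float filler) {
  assert(valid_rows >= 0 && valid_rows <= kTile);
  Tile q{};
  for (int i = 0; i < kTile; ++i)
    for (int k = 0; k < kTile; ++k)
      q[i][k] = (i < valid_rows) ? 1.0f : filler;
  return q;
}

// K tile with an arbitrary finite value in every slot.
Tile make_k() {
  Tile k{};
  for (auto& row : k) row.fill(0.5f);
  return k;
}

// Count the NaN or Inf entries in a tile.
int count_non_finite(const Tile& t) {
  int n = 0;
  for (const auto& row : t)
    for (float v : row)
      if (!std::isfinite(v)) ++n;
  return n;
}
```

With one valid row (the GQA_RATIO == 1 case), a NaN filler leaves 240 of the
256 output entries non-finite, while a zero filler leaves none.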

The bug exists in two locations within paged_attention_ll4mi_QKV_mfma16_kernel:
1. Lines 1834-1842: _B16x16 Qlocal for 16-bit cache types
2. Lines 2610-2617: _B16x8 Qlocal for 8-bit cache types

Both locations have an `if (lane16id < GQA_RATIO)` block that loads Q data
but lack an `else` clause to zero out Qlocal for non-loading lanes.

The correct pattern already exists elsewhere in the file (lines 1067-1070)
where unused Qlocal slots are explicitly zeroed:
```cpp
} else {
  Qlocal[QHLOOP - 1].xy[0] = {0};
  Qlocal[QHLOOP - 1].xy[1] = {0};
}
```

This fix adds the missing `else` clauses to zero out Qlocal registers for
lanes that don't load Q data, preventing garbage values from propagating
into the attention score computation.
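
The load-or-zero pattern can be sketched on the host as follows. This is a
minimal illustration under stated assumptions, not the kernel code: `B16x8`
here is a hypothetical stand-in for the kernel's `_B16x8` register type,
`kQkheLoop` is an assumed loop count, and `load_q_or_zero` and `all_zero` are
names invented for this example.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for the kernel's packed register type:
// eight 16-bit values.
struct B16x8 {
  std::array<std::uint16_t, 8> v;
};

constexpr int kQkheLoop = 4;  // assumed loop count, for illustration only
using QRegs = std::array<B16x8, kQkheLoop>;

// Emulates one lane's Q fetch. Lanes below gqa_ratio copy from q_src;
// all other lanes explicitly zero their registers instead of keeping
// whatever was there before.
void load_q_or_zero(int lane16id, int gqa_ratio, const B16x8* q_src,
                    QRegs& qlocal) {
  assert(lane16id >= 0 && lane16id < 16);
  if (lane16id < gqa_ratio) {
    for (int d = 0; d < kQkheLoop; ++d) {
      qlocal[d] = q_src[d];
    }
  } else {
    for (int d = 0; d < kQkheLoop; ++d) {
      qlocal[d] = {};  // value-initialization zeroes every element
    }
  }
}

// True if every element of every register is zero.
bool all_zero(const QRegs& qlocal) {
  for (const auto& reg : qlocal)
    for (std::uint16_t x : reg.v)
      if (x != 0) return false;
  return true;
}
```

With `gqa_ratio == 1`, calling `load_q_or_zero` for lane 5 on registers
pre-filled with a stale pattern leaves them all zero, while lane 0 ends up
with a copy of `q_src`.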

Impact:
- Affects non-GQA models (GQA_RATIO == 1) such as Llama-2 7B and 13B
- Symptom: Random numerical drift, potential NaNs in softmax
- Fix ensures deterministic behavior across all wave lanes

Signed-off-by: c0de128 <kevin.mckay@outlook.com>
c0de128 2025-12-24 08:45:44 -06:00
parent d201807339
commit 348de41b52

@@ -1839,6 +1839,13 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
             reinterpret_cast<const _B16x16*>(q_fetch_ptr);
         Qlocal[qkhe_depth] = *q_fetch_ptr_32B;
       }
+    } else {
+      // Zero out Qlocal for lanes that don't load Q data to prevent
+      // uninitialized register values from contaminating wmma results
+      #pragma unroll
+      for (int qkhe_depth = 0; qkhe_depth < QKHELOOP / 2; qkhe_depth++) {
+        Qlocal[qkhe_depth] = {};
+      }
     }
   } else {
     // fetch Q in shared across warps and then write to registers
@@ -2608,6 +2615,13 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
             reinterpret_cast<const _B16x8*>(q_fetch_ptr);
         Qlocal[qkhe_depth] = *q_fetch_ptr_16B;
       }
+    } else {
+      // Zero out Qlocal for lanes that don't load Q data to prevent
+      // uninitialized register values from contaminating wmma results
+      #pragma unroll
+      for (int qkhe_depth = 0; qkhe_depth < QKHELOOP; qkhe_depth++) {
+        Qlocal[qkhe_depth] = {};
+      }
     }
   } else {
     // fetch Q in shared across warps and then write to registers