mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-24 16:15:45 +08:00
[Bugfix][Hardware][AMD] Fix uninitialized Qlocal registers in ROCm attention kernel
In the ROCm PagedAttention wmma kernel, when GQA_RATIO == 1, only lane 0
loads valid Q data into the Qlocal registers. Lanes 1-15 never write
Qlocal, so their registers keep whatever values earlier work left there.
These uninitialized values then feed into the wmma (Wave Matrix
Multiply-Accumulate) instruction and contaminate its results, causing
subtle numerical accuracy issues.
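The contamination described above can be illustrated with a small host-side model. This is only a sketch under stated assumptions: `kLanes`, `kDepth`, `kStale`, `load_q`, and `reduce_all_lanes` are hypothetical names, the sentinel value stands in for leftover register contents, and the reduction is a toy stand-in for an operation that consumes every lane's registers. It does not reproduce actual wmma semantics.

```cpp
#include <array>
#include <cassert>

constexpr int kLanes = 16;        // lanes in one group (assumption for this sketch)
constexpr int kDepth = 4;         // hypothetical number of Q fragments per lane
constexpr float kStale = 1e30f;   // sentinel standing in for leftover register contents

using LaneRegs = std::array<float, kDepth>;

// Fill per-lane Q registers. Lanes below gqa_ratio load real data; the rest
// are either zeroed (the fix) or left holding the stale sentinel (the bug).
std::array<LaneRegs, kLanes> load_q(int gqa_ratio, bool zero_idle_lanes) {
  assert(gqa_ratio >= 1 && gqa_ratio <= kLanes);
  std::array<LaneRegs, kLanes> q;
  for (auto& regs : q) regs.fill(kStale);  // model pre-existing register state
  for (int lane = 0; lane < kLanes; ++lane) {
    if (lane < gqa_ratio) {
      for (int d = 0; d < kDepth; ++d) q[lane][d] = 1.0f;  // "valid" Q data
    } else if (zero_idle_lanes) {
      q[lane] = {};  // value-initialize: every element becomes 0.0f
    }
  }
  return q;
}

// Toy stand-in for a matrix op that reads the registers of every lane.
float reduce_all_lanes(const std::array<LaneRegs, kLanes>& q) {
  float acc = 0.0f;
  for (const auto& regs : q) {
    for (float v : regs) acc += v * 0.5f;  // 0.5f plays the role of a K element
  }
  return acc;
}
```

In this model, `reduce_all_lanes(load_q(1, true))` depends only on lane 0, whereas `reduce_all_lanes(load_q(1, false))` is dominated by the sentinel values in lanes 1-15.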
The bug exists in two locations within paged_attention_ll4mi_QKV_mfma16_kernel:
1. Lines 1834-1842: _B16x16 Qlocal for 16-bit cache types
2. Lines 2610-2617: _B16x8 Qlocal for 8-bit cache types
Both locations have an `if (lane16id < GQA_RATIO)` block that loads Q data
but lack an `else` clause to zero out Qlocal for non-loading lanes.
The correct pattern already exists elsewhere in the file (lines 1067-1070)
where unused Qlocal slots are explicitly zeroed:
```cpp
} else {
  Qlocal[QHLOOP - 1].xy[0] = {0};
  Qlocal[QHLOOP - 1].xy[1] = {0};
}
```
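The existing pattern zeroes each `xy` member with `= {0}`, while the fix assigns `= {}` to the whole element. The sketch below suggests the two forms are equivalent for a plain aggregate. `Half4`, `Frag8`, `make_stale`, and the helper functions are hypothetical stand-ins chosen for illustration, not the kernel's actual `_B16x8` / `_B16x16` definitions.

```cpp
#include <cstdint>
#include <cstring>

// Hypothetical stand-ins for the packed 16-bit fragment types.
struct Half4 { std::uint16_t v[4]; };
struct Frag8 { Half4 xy[2]; };

// Returns true when every byte of the fragment is zero.
bool all_zero(const Frag8& f) {
  static const Frag8 zero{};
  return std::memcmp(&f, &zero, sizeof(Frag8)) == 0;
}

// Build a fragment filled with a recognizable non-zero byte pattern,
// modelling a register that still holds stale data.
Frag8 make_stale() {
  Frag8 f;
  std::memset(&f, 0xAB, sizeof(f));
  return f;
}

// Member-wise zeroing, in the style of the existing code.
void zero_per_member(Frag8& f) {
  f.xy[0] = {0};
  f.xy[1] = {0};
}

// Whole-object zeroing via value-initialization, in the style of the fix.
void zero_whole(Frag8& f) {
  f = {};
}
```

Both helpers leave the fragment fully zeroed. The whole-object form is shorter and does not need to know how many members the type has.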
This fix adds the missing `else` clauses to zero out Qlocal registers for
lanes that don't load Q data, preventing garbage values from propagating
into the attention score computation.
Impact:
- Affects non-GQA models (GQA_RATIO == 1) such as Llama-2 7B and 13B
- Symptom: Random numerical drift, potential NaNs in softmax
- Fix ensures deterministic behavior across all wave lanes
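To make the softmax symptom concrete, here is a minimal sketch of a max-subtracted softmax on the host. The `softmax` function is a hypothetical helper written for this illustration and is not the kernel's reduction code. It shows how a single corrupted score can either collapse the distribution onto one entry or turn every output into NaN.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Numerically stable softmax: subtract the maximum before exponentiating.
std::vector<float> softmax(const std::vector<float>& scores) {
  assert(!scores.empty());
  float max_score = scores[0];
  for (float s : scores) {
    if (s > max_score) max_score = s;
  }
  std::vector<float> out(scores.size());
  float sum = 0.0f;
  for (std::size_t i = 0; i < scores.size(); ++i) {
    out[i] = std::exp(scores[i] - max_score);
    sum += out[i];
  }
  for (float& p : out) p /= sum;  // normalize so the outputs sum to 1
  return out;
}
```

With clean scores such as `{1.0f, 1.0f}` the outputs are equal. If one score is a huge finite value, all probability mass moves to that entry. If one score is infinite, `inf - inf` yields NaN, the sum becomes NaN, and every output is NaN.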
Signed-off-by: c0de128 <kevin.mckay@outlook.com>
parent d201807339
commit 348de41b52
@@ -1839,6 +1839,13 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
             reinterpret_cast<const _B16x16*>(q_fetch_ptr);
         Qlocal[qkhe_depth] = *q_fetch_ptr_32B;
       }
+    } else {
+      // Zero out Qlocal for lanes that don't load Q data to prevent
+      // uninitialized register values from contaminating wmma results
+#pragma unroll
+      for (int qkhe_depth = 0; qkhe_depth < QKHELOOP / 2; qkhe_depth++) {
+        Qlocal[qkhe_depth] = {};
+      }
     }
   } else {
     // fetch Q in shared across warps and then write to registers
@@ -2608,6 +2615,13 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
             reinterpret_cast<const _B16x8*>(q_fetch_ptr);
         Qlocal[qkhe_depth] = *q_fetch_ptr_16B;
       }
+    } else {
+      // Zero out Qlocal for lanes that don't load Q data to prevent
+      // uninitialized register values from contaminating wmma results
+#pragma unroll
+      for (int qkhe_depth = 0; qkhe_depth < QKHELOOP; qkhe_depth++) {
+        Qlocal[qkhe_depth] = {};
+      }
     }
   } else {
     // fetch Q in shared across warps and then write to registers