diff --git a/csrc/rocm/attention.cu b/csrc/rocm/attention.cu index a339c5641bb4a..78a8daebfcffb 100644 --- a/csrc/rocm/attention.cu +++ b/csrc/rocm/attention.cu @@ -1839,6 +1839,13 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel( reinterpret_cast(q_fetch_ptr); Qlocal[qkhe_depth] = *q_fetch_ptr_32B; } + } else { + // Zero out Qlocal for lanes that don't load Q data to prevent + // uninitialized register values from contaminating wmma results + #pragma unroll + for (int qkhe_depth = 0; qkhe_depth < QKHELOOP / 2; qkhe_depth++) { + Qlocal[qkhe_depth] = {}; + } } } else { // fetch Q in shared across warps and then write to registers @@ -2608,6 +2615,13 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel( reinterpret_cast(q_fetch_ptr); Qlocal[qkhe_depth] = *q_fetch_ptr_16B; } + } else { + // Zero out Qlocal for lanes that don't load Q data to prevent + // uninitialized register values from contaminating wmma results + #pragma unroll + for (int qkhe_depth = 0; qkhe_depth < QKHELOOP; qkhe_depth++) { + Qlocal[qkhe_depth] = {}; + } } } else { // fetch Q in shared across warps and then write to registers