[CI/Build][Kernel][AMD] Move extra dim to after load in _fwd_kv_parallel in lighting_attn.py (#29132)

Signed-off-by: Randall Smith <ransmith@amd.com>
Co-authored-by: Randall Smith <ransmith@amd.com>
This commit is contained in:
rasmith 2025-11-21 10:53:09 -06:00 committed by GitHub
parent a42ab317ac
commit e99e467384
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -198,7 +198,7 @@ def _fwd_kv_parallel(
)
# Load the decay factors for the current head and block
k_decay_ptr = K_decay + off_h * BLOCK + tl.arange(0, CBLOCK)[None, :]
k_decay_ptr = K_decay + off_h * BLOCK + tl.arange(0, CBLOCK)
kv_index = tl.arange(0, CBLOCK)
@ -228,6 +228,12 @@ def _fwd_kv_parallel(
# Load decay factor and compute weighted key-value outer product
k_decay = tl.load(k_decay_ptr)
# NOTE: Need to add the extra dim here due to AMD MLIR lowering error.
# Please don't move it back until issue is resolved.
# Issue: https://github.com/ROCm/triton/issues/907
k_decay = k_decay[None, :]
kv += tl.dot(k_trans * k_decay, v)
# Move to the next sub-block