From e99e467384001e284e0722a33362866b10fed65b Mon Sep 17 00:00:00 2001 From: rasmith Date: Fri, 21 Nov 2025 10:53:09 -0600 Subject: [PATCH] [CI/Build][Kernel][AMD] Move extra dim to after load in _fwd_kv_parallel in lighting_attn.py (#29132) Signed-off-by: Randall Smith Co-authored-by: Randall Smith --- vllm/model_executor/layers/lightning_attn.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/lightning_attn.py b/vllm/model_executor/layers/lightning_attn.py index 99853680eac6..ffccdc12241c 100644 --- a/vllm/model_executor/layers/lightning_attn.py +++ b/vllm/model_executor/layers/lightning_attn.py @@ -198,7 +198,7 @@ def _fwd_kv_parallel( ) # Load the decay factors for the current head and block - k_decay_ptr = K_decay + off_h * BLOCK + tl.arange(0, CBLOCK)[None, :] + k_decay_ptr = K_decay + off_h * BLOCK + tl.arange(0, CBLOCK) kv_index = tl.arange(0, CBLOCK) @@ -228,6 +228,12 @@ def _fwd_kv_parallel( # Load decay factor and compute weighted key-value outer product k_decay = tl.load(k_decay_ptr) + + # NOTE: Need to add the extra dim here due to AMD MLIR lowering error. + # Please don't move it back until issue is resolved. + # Issue: https://github.com/ROCm/triton/issues/907 + k_decay = k_decay[None, :] + kv += tl.dot(k_trans * k_decay, v) # Move to the next sub-block