mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 02:44:57 +08:00
[CI/Build][Kernel][AMD] Move extra dim to after load in _fwd_kv_parallel in lighting_attn.py (#29132)
Signed-off-by: Randall Smith <ransmith@amd.com> Co-authored-by: Randall Smith <ransmith@amd.com>
This commit is contained in:
parent
a42ab317ac
commit
e99e467384
@ -198,7 +198,7 @@ def _fwd_kv_parallel(
|
||||
)
|
||||
|
||||
# Load the decay factors for the current head and block
|
||||
k_decay_ptr = K_decay + off_h * BLOCK + tl.arange(0, CBLOCK)[None, :]
|
||||
k_decay_ptr = K_decay + off_h * BLOCK + tl.arange(0, CBLOCK)
|
||||
|
||||
kv_index = tl.arange(0, CBLOCK)
|
||||
|
||||
@ -228,6 +228,12 @@ def _fwd_kv_parallel(
|
||||
|
||||
# Load decay factor and compute weighted key-value outer product
|
||||
k_decay = tl.load(k_decay_ptr)
|
||||
|
||||
# NOTE: Need to add the extra dim here due to AMD MLIR lowering error.
|
||||
# Please don't move it back until issue is resolved.
|
||||
# Issue: https://github.com/ROCm/triton/issues/907
|
||||
k_decay = k_decay[None, :]
|
||||
|
||||
kv += tl.dot(k_trans * k_decay, v)
|
||||
|
||||
# Move to the next sub-block
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user