mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-11 19:55:01 +08:00
[CI/Build][Kernel][AMD] Move extra dim to after load in _fwd_kv_parallel in lightning_attn.py (#29132)
Signed-off-by: Randall Smith <ransmith@amd.com> Co-authored-by: Randall Smith <ransmith@amd.com>
This commit is contained in:
parent
a42ab317ac
commit
e99e467384
@ -198,7 +198,7 @@ def _fwd_kv_parallel(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Load the decay factors for the current head and block
|
# Load the decay factors for the current head and block
|
||||||
k_decay_ptr = K_decay + off_h * BLOCK + tl.arange(0, CBLOCK)[None, :]
|
k_decay_ptr = K_decay + off_h * BLOCK + tl.arange(0, CBLOCK)
|
||||||
|
|
||||||
kv_index = tl.arange(0, CBLOCK)
|
kv_index = tl.arange(0, CBLOCK)
|
||||||
|
|
||||||
@ -228,6 +228,12 @@ def _fwd_kv_parallel(
|
|||||||
|
|
||||||
# Load decay factor and compute weighted key-value outer product
|
# Load decay factor and compute weighted key-value outer product
|
||||||
k_decay = tl.load(k_decay_ptr)
|
k_decay = tl.load(k_decay_ptr)
|
||||||
|
|
||||||
|
# NOTE: Need to add the extra dim here due to AMD MLIR lowering error.
|
||||||
|
# Please don't move it back until the issue is resolved.
|
||||||
|
# Issue: https://github.com/ROCm/triton/issues/907
|
||||||
|
k_decay = k_decay[None, :]
|
||||||
|
|
||||||
kv += tl.dot(k_trans * k_decay, v)
|
kv += tl.dot(k_trans * k_decay, v)
|
||||||
|
|
||||||
# Move to the next sub-block
|
# Move to the next sub-block
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user