[Misc] [ROCm] Prevent surplus tensor reshape (#19803)

Signed-off-by: Zsolt Borbely <zsolt.borbely@htecgroup.com>
zsolt-borbely-htec 2025-06-19 07:57:16 +02:00 committed by GitHub
parent 2de12be428
commit aa20d10a91
GPG Key ID: B5690EEEBB952194

@@ -376,7 +376,7 @@ class TritonAttentionImpl(AttentionImpl):
                     query.reshape(
                         (num_tokens, num_heads * head_size)).contiguous(),
                     layer._q_scale)
-            query = query.reshape((num_tokens, num_heads, head_size))
+                query = query.reshape((num_tokens, num_heads, head_size))

         use_local_attn = \
             (self.use_irope and attn_metadata.local_attn_metadata is not None)