[Kernel] [CPU] refactor cpu_attn.py:_run_sdpa_forward for better memory access (#24701)

Signed-off-by: ignaciosica <mignacio.sica@gmail.com>
This commit is contained in:
Ignacio Sica 2025-09-12 08:23:07 -03:00 committed by GitHub
parent 60a0951924
commit 7a1c4025f1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -641,10 +641,6 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
attn_metadata: TorchSDPAMetadata,
attn_type: str = AttentionType.DECODER,
) -> None:
if self.num_kv_heads != self.num_heads:
key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
value = value.repeat_interleave(self.num_queries_per_kv, dim=1)
attn_masks = attn_metadata.get_attn_bias(attn_type)
if attn_masks is None:
if self.alibi_slopes is not None:
@ -665,6 +661,10 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
key = key.movedim(0, key.dim() - 2)
value = value.movedim(0, value.dim() - 2)
if self.num_kv_heads != self.num_heads:
key = key.repeat_interleave(self.num_queries_per_kv, dim=-3)
value = value.repeat_interleave(self.num_queries_per_kv, dim=-3)
causal_attn = (attn_type == AttentionType.DECODER)
seq_lens_q, seq_lens_kv = attn_metadata.get_seq_lens(attn_type)