From 7a1c4025f1e2879fa398888d70c596e5818026cb Mon Sep 17 00:00:00 2001 From: Ignacio Sica Date: Fri, 12 Sep 2025 08:23:07 -0300 Subject: [PATCH] [Kernel] [CPU] refactor `cpu_attn.py:_run_sdpa_forward` for better memory access (#24701) Signed-off-by: ignaciosica --- vllm/v1/attention/backends/cpu_attn.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/v1/attention/backends/cpu_attn.py b/vllm/v1/attention/backends/cpu_attn.py index ab87f3bb4e3c..6627164c9879 100644 --- a/vllm/v1/attention/backends/cpu_attn.py +++ b/vllm/v1/attention/backends/cpu_attn.py @@ -641,10 +641,6 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): attn_metadata: TorchSDPAMetadata, attn_type: str = AttentionType.DECODER, ) -> None: - if self.num_kv_heads != self.num_heads: - key = key.repeat_interleave(self.num_queries_per_kv, dim=1) - value = value.repeat_interleave(self.num_queries_per_kv, dim=1) - attn_masks = attn_metadata.get_attn_bias(attn_type) if attn_masks is None: if self.alibi_slopes is not None: @@ -665,6 +661,10 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): key = key.movedim(0, key.dim() - 2) value = value.movedim(0, value.dim() - 2) + if self.num_kv_heads != self.num_heads: + key = key.repeat_interleave(self.num_queries_per_kv, dim=-3) + value = value.repeat_interleave(self.num_queries_per_kv, dim=-3) + causal_attn = (attn_type == AttentionType.DECODER) seq_lens_q, seq_lens_kv = attn_metadata.get_seq_lens(attn_type)