[fix][cpu] fix prefill attention in CPU attention backend (#27035)

Signed-off-by: Fadi Arafeh <fadi.arafeh@arm.com>
Fadi Arafeh 2025-10-18 14:30:21 +01:00 committed by GitHub
parent 245e4f2c01
commit ab4be40fc5
2 changed files with 16 additions and 4 deletions


@@ -1293,7 +1293,8 @@ class EngineArgs:
         # Set default arguments for V1 Engine.
         self._set_default_args(usage_context, model_config)
-        # Disable chunked prefill for POWER (ppc64le)/ARM/s390x/RISCV CPUs in V1
+        # Disable chunked prefill and prefix caching for:
+        # POWER (ppc64le)/ARM/s390x/RISCV CPUs in V1
         if current_platform.is_cpu() and current_platform.get_cpu_architecture() in (
             CpuArchEnum.POWERPC,
             CpuArchEnum.S390X,
@@ -1306,6 +1307,13 @@
                 "disabling it for V1 backend."
             )
             self.enable_chunked_prefill = False
+            logger.info(
+                "Prefix caching is not supported on ARM, POWER, "
+                "S390X and RISC-V CPUs; "
+                "disabling it for V1 backend."
+            )
+            self.enable_prefix_caching = False
         assert self.enable_chunked_prefill is not None
         sliding_window: int | None = None
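For readers outside the vLLM tree, here is a minimal standalone sketch of the defaulting behaviour this hunk implements. The enum values mirror the diff, but EngineArgsSketch and apply_cpu_defaults are hypothetical names, not vLLM's actual API:

    from dataclasses import dataclass
    from enum import Enum, auto

    class CpuArchEnum(Enum):
        X86 = auto()
        POWERPC = auto()
        ARM = auto()
        S390X = auto()
        RISCV = auto()

    @dataclass
    class EngineArgsSketch:
        enable_chunked_prefill: bool | None = None
        enable_prefix_caching: bool | None = None

        def apply_cpu_defaults(self, arch: CpuArchEnum) -> None:
            # Mirrors the hunks above: both features are forced off on these CPUs.
            if arch in (CpuArchEnum.POWERPC, CpuArchEnum.ARM,
                        CpuArchEnum.S390X, CpuArchEnum.RISCV):
                self.enable_chunked_prefill = False
                self.enable_prefix_caching = False

    args = EngineArgsSketch()
    args.apply_cpu_defaults(CpuArchEnum.ARM)
    assert args.enable_chunked_prefill is False
    assert args.enable_prefix_caching is False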


@@ -412,7 +412,7 @@ class TorchSDPAMetadataBuilderV1(AttentionMetadataBuilder[TorchSDPAMetadata]):
             num_decode_tokens=num_decode_tokens,
             slot_mapping=slot_mapping,
             # to ensure inference when chunked_prefill is disabled
-            seq_lens=seq_lens_cpu.tolist(),
+            seq_lens=seq_lens_cpu.tolist()[num_decodes:],  # prefill
             decode_seq_lens_tensor=seq_lens_cpu[:num_decodes],  # decode
             decode_max_seq_len=max_decode_seq_len,  # decode
             decode_block_tables=block_table_tensor[:num_decodes],  # decode
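This hunk is the core of the fix and relies on the batch layout: decode requests are packed before prefill requests, so index num_decodes splits the sequence-length list the same way it already splits the block tables. Previously seq_lens kept the decode entries too, mis-sizing prefill attention when chunked prefill is disabled. A small sketch with made-up lengths:

    import torch

    seq_lens_cpu = torch.tensor([5, 9, 128, 256])  # 2 decodes first, then 2 prefills
    num_decodes = 2

    seq_lens = seq_lens_cpu.tolist()[num_decodes:]       # prefill -> [128, 256]
    decode_seq_lens_tensor = seq_lens_cpu[:num_decodes]  # decode  -> tensor([5, 9])

    assert seq_lens == [128, 256]
    assert decode_seq_lens_tensor.tolist() == [5, 9]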
@@ -617,7 +617,6 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
                 prefill_meta.prefill_block_tables,
                 self.alibi_slopes,
             )
-
         if decode_meta := attn_metadata.decode_metadata:
             assert attn_type != AttentionType.ENCODER_ONLY, (
                 "Encoder-only models should not have decode metadata."
@@ -686,7 +685,12 @@
         causal_attn = attn_type == AttentionType.DECODER

         seq_lens_q, seq_lens_kv = attn_metadata.get_seq_lens(attn_type)
-        start_q, start_kv = 0, 0
+        # Incoming Q and KV contain decode tokens as well, hence start at an
+        # offset of num_decode_tokens, since decode requests appear first.
+        start_q, start_kv = (
+            attn_metadata.num_decode_tokens,
+            attn_metadata.num_decode_tokens,
+        )
         for seq_len_q, seq_len_kv, mask in zip(seq_lens_q, seq_lens_kv, attn_masks):
             end_q = start_q + seq_len_q
             end_kv = start_kv + seq_len_kv
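Because the flat Q/KV buffers hold the decode tokens first, the per-sequence prefill loop must open its first window at num_decode_tokens rather than 0. A self-contained sketch of that bookkeeping (shapes and lengths are made up, and the real code also threads per-sequence attn_masks through):

    import torch
    import torch.nn.functional as F

    num_decode_tokens = 2        # one token per decode request, packed first
    seq_lens_q = [4, 3]          # prefill query lengths
    seq_lens_kv = [4, 3]         # prefill key/value lengths
    total = num_decode_tokens + sum(seq_lens_q)

    q = torch.randn(total, 8, 16)  # [tokens, heads, head_dim]
    k = torch.randn(total, 8, 16)
    v = torch.randn(total, 8, 16)

    outputs = []
    start_q, start_kv = num_decode_tokens, num_decode_tokens
    for seq_len_q, seq_len_kv in zip(seq_lens_q, seq_lens_kv):
        end_q, end_kv = start_q + seq_len_q, start_kv + seq_len_kv
        # Causal SDPA over just this request's prefill tokens, heads-first layout.
        outputs.append(F.scaled_dot_product_attention(
            q[start_q:end_q].transpose(0, 1),
            k[start_kv:end_kv].transpose(0, 1),
            v[start_kv:end_kv].transpose(0, 1),
            is_causal=True,
        ))
        start_q, start_kv = end_q, end_kv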