[fix][cpu] fix prefill attention in CPU attention backend (#27035)
Signed-off-by: Fadi Arafeh <fadi.arafeh@arm.com>
parent 245e4f2c01
commit ab4be40fc5
@@ -1293,7 +1293,8 @@ class EngineArgs:

        # Set default arguments for V1 Engine.
        self._set_default_args(usage_context, model_config)
-       # Disable chunked prefill for POWER (ppc64le)/ARM/s390x/RISCV CPUs in V1
+       # Disable chunked prefill and prefix caching for:
+       # POWER (ppc64le)/ARM/s390x/RISCV CPUs in V1
        if current_platform.is_cpu() and current_platform.get_cpu_architecture() in (
            CpuArchEnum.POWERPC,
            CpuArchEnum.S390X,
@@ -1306,6 +1307,13 @@ class EngineArgs:
                "disabling it for V1 backend."
            )
            self.enable_chunked_prefill = False
+           logger.info(
+               "Prefix caching is not supported for ARM and POWER, "
+               "S390X and RISC-V CPUs; "
+               "disabling it for V1 backend."
+           )
+           self.enable_prefix_caching = False

        assert self.enable_chunked_prefill is not None

        sliding_window: int | None = None
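Below is a minimal, self-contained sketch (not the actual vLLM code path) of the behavior the two EngineArgs hunks above introduce: on POWER, ARM, s390x, and RISC-V CPUs, both chunked prefill and prefix caching end up disabled for the V1 backend. The dataclass, enum, and helper function here are simplified, hypothetical stand-ins for vLLM's EngineArgs, CpuArchEnum, and platform check.

# Hypothetical, simplified stand-ins; only the two flag assignments mirror the diff.
from dataclasses import dataclass
from enum import Enum, auto


class CpuArch(Enum):
    X86 = auto()
    POWERPC = auto()
    ARM = auto()
    S390X = auto()
    RISCV = auto()


@dataclass
class EngineDefaults:
    enable_chunked_prefill: bool | None = None
    enable_prefix_caching: bool | None = None


def set_cpu_defaults(defaults: EngineDefaults, arch: CpuArch) -> EngineDefaults:
    """Disable features the CPU V1 backend does not support on these ISAs."""
    if arch in (CpuArch.POWERPC, CpuArch.ARM, CpuArch.S390X, CpuArch.RISCV):
        defaults.enable_chunked_prefill = False   # matches the existing hunk context
        defaults.enable_prefix_caching = False    # matches the newly added lines
    return defaults


print(set_cpu_defaults(EngineDefaults(), CpuArch.ARM))
# -> EngineDefaults(enable_chunked_prefill=False, enable_prefix_caching=False)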
@@ -412,7 +412,7 @@ class TorchSDPAMetadataBuilderV1(AttentionMetadataBuilder[TorchSDPAMetadata]):
            num_decode_tokens=num_decode_tokens,
            slot_mapping=slot_mapping,
            # to ensure inference when chunked_prefill is disabled
-           seq_lens=seq_lens_cpu.tolist(),
+           seq_lens=seq_lens_cpu.tolist()[num_decodes:],  # prefill
            decode_seq_lens_tensor=seq_lens_cpu[:num_decodes],  # decode
            decode_max_seq_len=max_decode_seq_len,  # decode
            decode_block_tables=block_table_tensor[:num_decodes],  # decode
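For context on the seq_lens change above: in this backend's batch layout, decode requests appear before prefill requests, so the first num_decodes entries of seq_lens_cpu belong to decodes and the remainder to prefills. A small illustrative sketch of the slicing (made-up lengths; torch is used only for the tensor indexing):

# Illustrative only; variable names mirror the hunk, the values are invented.
import torch

num_decodes = 2
# two decode requests (lengths 7 and 9) followed by two prefill requests (5 and 3)
seq_lens_cpu = torch.tensor([7, 9, 5, 3], dtype=torch.int32)

seq_lens = seq_lens_cpu.tolist()[num_decodes:]        # prefill lengths -> [5, 3]
decode_seq_lens_tensor = seq_lens_cpu[:num_decodes]   # decode lengths -> tensor([7, 9])

print(seq_lens, decode_seq_lens_tensor)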
@@ -617,7 +617,6 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
                    prefill_meta.prefill_block_tables,
                    self.alibi_slopes,
                )

        if decode_meta := attn_metadata.decode_metadata:
            assert attn_type != AttentionType.ENCODER_ONLY, (
                "Encoder-only models should not have decode metadata."
@@ -686,7 +685,12 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
        causal_attn = attn_type == AttentionType.DECODER

        seq_lens_q, seq_lens_kv = attn_metadata.get_seq_lens(attn_type)
-       start_q, start_kv = 0, 0
+       # Incoming Q and KV contain decoded tokens as well, hence start at an offset
+       # equal to num_decode_tokens since decode requests appear first
+       start_q, start_kv = (
+           attn_metadata.num_decode_tokens,
+           attn_metadata.num_decode_tokens,
+       )
        for seq_len_q, seq_len_kv, mask in zip(seq_lens_q, seq_lens_kv, attn_masks):
            end_q = start_q + seq_len_q
            end_kv = start_kv + seq_len_kv
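The offset change above follows from the same layout: the packed query (and the key/value it attends over in this path) holds the decode tokens first, followed by each prefill request's tokens back to back, so the per-request slicing for prefill attention must begin at num_decode_tokens rather than 0. A made-up sketch of that slicing:

# Illustrative only; shapes and lengths are invented, torch is used for slicing.
import torch

num_decode_tokens = 2        # one token per decode request in this toy example
prefill_seq_lens = [5, 3]    # query lengths of the two prefill requests
total_tokens = num_decode_tokens + sum(prefill_seq_lens)
query = torch.arange(total_tokens)   # stand-in for the packed query tensor

start_q = num_decode_tokens          # skip the decode tokens at the front
for seq_len_q in prefill_seq_lens:
    end_q = start_q + seq_len_q
    print(query[start_q:end_q])      # the query rows for this prefill request
    start_q = end_q
# tensor([2, 3, 4, 5, 6]) then tensor([7, 8, 9])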