diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 21c2a9e44ba9..d373eb4906b1 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -1293,7 +1293,8 @@ class EngineArgs:
         # Set default arguments for V1 Engine.
         self._set_default_args(usage_context, model_config)
 
-        # Disable chunked prefill for POWER (ppc64le)/ARM/s390x/RISCV CPUs in V1
+        # Disable chunked prefill and prefix caching for:
+        # POWER (ppc64le)/ARM/s390x/RISCV CPUs in V1
         if current_platform.is_cpu() and current_platform.get_cpu_architecture() in (
             CpuArchEnum.POWERPC,
             CpuArchEnum.S390X,
@@ -1306,6 +1307,13 @@ class EngineArgs:
                 "disabling it for V1 backend."
             )
             self.enable_chunked_prefill = False
+            logger.info(
+                "Prefix caching is not supported for ARM, POWER, "
+                "S390X and RISC-V CPUs; "
+                "disabling it for V1 backend."
+            )
+            self.enable_prefix_caching = False
+
         assert self.enable_chunked_prefill is not None
 
         sliding_window: int | None = None
diff --git a/vllm/v1/attention/backends/cpu_attn.py b/vllm/v1/attention/backends/cpu_attn.py
index 211eefdb6c11..0d3e1729ff20 100644
--- a/vllm/v1/attention/backends/cpu_attn.py
+++ b/vllm/v1/attention/backends/cpu_attn.py
@@ -412,7 +412,7 @@ class TorchSDPAMetadataBuilderV1(AttentionMetadataBuilder[TorchSDPAMetadata]):
             num_decode_tokens=num_decode_tokens,
             slot_mapping=slot_mapping,
             # to ensure inference when chunked_prefill is disabled
-            seq_lens=seq_lens_cpu.tolist(),
+            seq_lens=seq_lens_cpu.tolist()[num_decodes:],  # prefill
             decode_seq_lens_tensor=seq_lens_cpu[:num_decodes],  # decode
             decode_max_seq_len=max_decode_seq_len,  # decode
             decode_block_tables=block_table_tensor[:num_decodes],  # decode
@@ -617,7 +617,6 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
                     prefill_meta.prefill_block_tables,
                     self.alibi_slopes,
                 )
-
         if decode_meta := attn_metadata.decode_metadata:
             assert attn_type != AttentionType.ENCODER_ONLY, (
                 "Encoder-only models should not have decode metadata."
@@ -686,7 +685,12 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
         causal_attn = attn_type == AttentionType.DECODER
 
         seq_lens_q, seq_lens_kv = attn_metadata.get_seq_lens(attn_type)
-        start_q, start_kv = 0, 0
+        # Incoming Q and KV contain decoded tokens as well, hence start at an offset
+        # equal to num_decode_tokens since decode requests appear first
+        start_q, start_kv = (
+            attn_metadata.num_decode_tokens,
+            attn_metadata.num_decode_tokens,
+        )
         for seq_len_q, seq_len_kv, mask in zip(seq_lens_q, seq_lens_kv, attn_masks):
             end_q = start_q + seq_len_q
             end_kv = start_kv + seq_len_kv
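
The offsets introduced above all follow from the V1 CPU backend placing decode requests ahead of prefill requests in a batch. Below is a minimal standalone sketch of that slicing, assuming one query token per decode request; it uses plain Python and a hypothetical `split_decode_first_batch` helper, not vLLM's actual metadata builder.

```python
# Illustrative sketch only (not vLLM code): shows why the prefill slices in this
# patch start at num_decodes / num_decode_tokens. It assumes the V1 CPU batch
# layout in which decode requests come first and contribute one query token each.

def split_decode_first_batch(seq_lens: list[int], num_decodes: int) -> dict:
    """Split a decode-first batch the same way the patched metadata builder does."""
    # One query token per decode request, so the flattened Q/KV buffers hold
    # `num_decodes` decode tokens before the first prefill token.
    num_decode_tokens = num_decodes

    decode_seq_lens = seq_lens[:num_decodes]    # mirrors seq_lens_cpu[:num_decodes]
    prefill_seq_lens = seq_lens[num_decodes:]   # mirrors seq_lens_cpu.tolist()[num_decodes:]

    # Walk the prefill sequences starting past the decode tokens, matching
    # `start_q, start_kv = num_decode_tokens, num_decode_tokens` in the patch.
    spans = []
    start = num_decode_tokens
    for seq_len in prefill_seq_lens:
        spans.append((start, start + seq_len))
        start += seq_len

    return {
        "decode_seq_lens": decode_seq_lens,
        "prefill_seq_lens": prefill_seq_lens,
        "prefill_token_spans": spans,
    }


if __name__ == "__main__":
    # Two decode requests (context lengths 17 and 9) followed by two prefills
    # of 5 and 3 prompt tokens: the prefill tokens occupy positions 2..9.
    result = split_decode_first_batch([17, 9, 5, 3], num_decodes=2)
    assert result["prefill_token_spans"] == [(2, 7), (7, 10)]
    print(result)
```

With chunked prefill disabled on these architectures, each prefill's query length should equal its full prompt length, so the same per-request sequence lengths drive both the Q and KV spans in the loop above.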