diff --git a/vllm/v1/attention/backends/cpu_attn.py b/vllm/v1/attention/backends/cpu_attn.py
index 6ca0c63f6b59..369f7062005c 100644
--- a/vllm/v1/attention/backends/cpu_attn.py
+++ b/vllm/v1/attention/backends/cpu_attn.py
@@ -14,10 +14,9 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
-                                              CommonAttentionMetadata)
-from vllm.v1.core.sched.output import SchedulerOutput
+                                              CommonAttentionMetadata,
+                                              split_decodes_and_prefills)
 from vllm.v1.kv_cache_interface import AttentionSpec
-from vllm.v1.worker.gpu_input_batch import InputBatch
 
 try:
     import intel_extension_for_pytorch.llm.modules as ipex_modules
@@ -102,16 +101,16 @@ class TorchSDPAMetadata(AttentionMetadata):
     """Metadata for PagedAttention."""
     # (batch_size,). The length of sequences (entire tokens seen so far) per
     # sequence.
-    seq_lens_tensor: Optional[torch.Tensor]
+    decode_seq_lens_tensor: Optional[torch.Tensor]
     # Maximum sequence length in the batch. 0 if it is prefill-only batch.
-    max_decode_seq_len: int
+    decode_max_seq_len: int
     # (batch_size, max_blocks_per_seq).
     # Block addresses per sequence. (Seq id -> list of physical block)
     # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
     # in the kv cache. Each block can contain up to block_size tokens.
     # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
     # captured.
-    block_tables: Optional[torch.Tensor]
+    decode_block_tables: Optional[torch.Tensor]
     """Metadata for TorchSDPABackend.
     """
     # Currently, input sequences can only contain all prompts
@@ -121,9 +120,9 @@ class TorchSDPAMetadata(AttentionMetadata):
 
     # For chunked prefill only
     max_query_len: Optional[int] = None
-    max_kv_len: Optional[int] = None
+    prefill_max_seq_len: Optional[int] = None
     prefill_query_start_loc: Optional[torch.Tensor] = None
-    kv_start_loc: Optional[torch.Tensor] = None
+    prefill_seq_start_loc: Optional[torch.Tensor] = None
     prefill_block_tables: Optional[torch.Tensor] = None
 
     # For V1 logits index only
@@ -307,8 +306,8 @@ class TorchSDPAMetadata(AttentionMetadata):
                 or attn_type == AttentionType.ENCODER_ONLY):
             # Decoder self-attention
             # Choose max_seq_len based on whether we are in prompt_run
-            return (self.seq_lens_tensor, self.max_decode_seq_len,
-                    self.block_tables)
+            return (self.decode_seq_lens_tensor, self.decode_max_seq_len,
+                    self.decode_block_tables)
         elif attn_type == AttentionType.ENCODER_DECODER:
             # Enc/dec cross-attention KVs match encoder sequence length;
             # cross-attention utilizes special "cross" block tables
@@ -323,19 +322,14 @@ class TorchSDPAMetadata(AttentionMetadata):
 
 
 class TorchSDPAMetadataBuilderV1(AttentionMetadataBuilder[TorchSDPAMetadata]):
+    reorder_batch_threshold: int = 1
 
     def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                  vllm_config: VllmConfig, device: torch.device) -> None:
         super().__init__(kv_cache_spec, layer_names, vllm_config, device)
         self.scheduler_config = vllm_config.scheduler_config
-
-        # For reorder
-        self.reorder_prompt_req_index_list = np.empty(
-            vllm_config.scheduler_config.max_num_seqs, dtype=np.int64)
-        self.reorder_decode_req_index_list = np.empty(
-            vllm_config.scheduler_config.max_num_seqs, dtype=np.int64)
-        self.num_prompt_req: int = 0
+        self._init_reorder_batch_threshold(1, False)
 
         self.seq_start_loc_cpu = torch.zeros(
             vllm_config.scheduler_config.max_num_seqs + 1,
@@ -344,50 +338,6 @@ class TorchSDPAMetadataBuilderV1(AttentionMetadataBuilder[TorchSDPAMetadata]):
         )
         self.seq_start_loc_np = self.seq_start_loc_cpu.numpy()
 
-    def reorder_batch(self, input_batch: InputBatch,
-                      scheduler_output: SchedulerOutput) -> bool:
-        prompt_list_idx = 0
-        decode_list_idx = 0
-        for req_index in range(input_batch.num_reqs):
-            if input_batch.num_computed_tokens_cpu[
-                    req_index] < input_batch.num_prompt_tokens[req_index]:
-                # prompt stage
-                self.reorder_prompt_req_index_list[prompt_list_idx] = req_index
-                prompt_list_idx += 1
-            else:
-                # decode stage
-                self.reorder_decode_req_index_list[decode_list_idx] = req_index
-                decode_list_idx += 1
-        assert decode_list_idx + prompt_list_idx == input_batch.num_reqs
-
-        # Update prompt requests number
-        self.num_prompt_req = prompt_list_idx
-
-        reorder_req_num = 0
-        for req_index in range(decode_list_idx):
-            if self.reorder_decode_req_index_list[req_index] < prompt_list_idx:
-                reorder_req_num += 1
-            else:
-                break
-
-        if reorder_req_num == 0:
-            return False
-
-        reorder_prompt_list = (
-            self.reorder_prompt_req_index_list[:prompt_list_idx]
-            [-reorder_req_num:])
-        reorder_decode_list = (
-            self.reorder_decode_req_index_list[:decode_list_idx]
-            [:reorder_req_num])
-        assert reorder_decode_list.size == reorder_prompt_list.size
-
-        for idx in range(reorder_req_num):
-            prompt_req_index = reorder_prompt_list[idx].item()
-            decode_req_index = reorder_decode_list[idx].item()
-            input_batch.swap_states(prompt_req_index, decode_req_index)
-
-        return True
-
     def build(self,
               common_prefix_len: int,
               common_attn_metadata: CommonAttentionMetadata,
@@ -397,41 +347,46 @@ class TorchSDPAMetadataBuilderV1(AttentionMetadataBuilder[TorchSDPAMetadata]):
         seq_lens_cpu = common_attn_metadata.seq_lens_cpu
         seq_lens_np = seq_lens_cpu.numpy()
-        num_prompt_req = self.num_prompt_req
-        max_prefill_seq_len = seq_lens_np[:num_prompt_req].max().item(
-        ) if num_prompt_req > 0 else 0
-        max_decode_seq_len = seq_lens_np[num_prompt_req:num_reqs].max().item(
-        ) if num_prompt_req < num_reqs else 0
+
+        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
+        query_start_loc_np = query_start_loc_cpu.numpy()
+
+        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens =\
+            split_decodes_and_prefills(common_attn_metadata,
+                decode_threshold=self.reorder_batch_threshold,
+                require_uniform=True)
+
+        max_prefill_seq_len = seq_lens_np[num_decodes:num_reqs].max().item(
+        ) if num_prefills > 0 else 0
+        max_decode_seq_len = seq_lens_np[:num_decodes].max().item(
+        ) if num_prefills < num_reqs else 0
 
         self.seq_start_loc_np[0] = 0
         np.cumsum(seq_lens_np, out=self.seq_start_loc_np[1:num_reqs + 1])
 
-        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
-        num_prefill_tokens = int(query_start_loc_cpu[num_prompt_req].item())
-        num_decode_tokens = int(query_start_loc_cpu[num_reqs].item() -
-                                num_prefill_tokens)
-
         slot_mapping = common_attn_metadata.slot_mapping.long()
         block_table_tensor = common_attn_metadata.block_table_tensor
 
+        query_start_loc_np = query_start_loc_cpu.numpy()
+        query_start_loc_np[num_decodes:num_reqs + 1] -= num_decode_tokens
         attn_metadata = TorchSDPAMetadata(
-            num_prefills=num_prompt_req,
+            num_prefills=num_prefills,
            num_prefill_tokens=num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            slot_mapping=slot_mapping,
            # to ensure inference when chunked_prefill is disabled
            seq_lens=seq_lens_cpu.tolist(),
-            seq_lens_tensor=seq_lens_cpu[num_prompt_req:num_reqs],  # decode
-            max_decode_seq_len=max_decode_seq_len,  # decode
-            block_tables=block_table_tensor[num_prompt_req:num_reqs],  # decode
+            decode_seq_lens_tensor=seq_lens_cpu[:num_decodes],  # decode
+            decode_max_seq_len=max_decode_seq_len,  # decode
+            decode_block_tables=block_table_tensor[:num_decodes],  # decode
            chunked_prefill=self.scheduler_config.chunked_prefill_enabled,
            max_query_len=max_query_len,
-            max_kv_len=max_prefill_seq_len,
-            prefill_query_start_loc=query_start_loc_cpu[:num_prompt_req +
-                                                        1],  # prefill
-            kv_start_loc=self.seq_start_loc_cpu[:num_prompt_req +
-                                                1],  # prefill
-            prefill_block_tables=block_table_tensor[:
-                                                    num_prompt_req],  # prefill
+            prefill_max_seq_len=max_prefill_seq_len,
+            prefill_query_start_loc=query_start_loc_cpu[num_decodes:num_reqs +
                                                         1],  # prefill
+            prefill_seq_start_loc=self.seq_start_loc_cpu[num_decodes:num_reqs +
                                                          1],  # prefill
+            prefill_block_tables=block_table_tensor[
                num_decodes:num_reqs],  # prefill
            query_start_loc=query_start_loc_cpu[:num_reqs +
                                                1],  # for logits index
        )
@@ -596,14 +551,14 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
                 import intel_extension_for_pytorch.llm.modules as ipex_modules
                 output = torch.empty_like(query)
                 ipex_modules.PagedAttention.flash_attn_varlen_func(
-                    output[:prefill_meta.num_prefill_tokens, :, :],
-                    query[:prefill_meta.num_prefill_tokens, :, :],
+                    output[prefill_meta.num_decode_tokens:, :, :],
+                    query[prefill_meta.num_decode_tokens:, :, :],
                     key_cache,
                     value_cache,
                     prefill_meta.prefill_query_start_loc,
-                    prefill_meta.kv_start_loc,
+                    prefill_meta.prefill_seq_start_loc,
                     prefill_meta.max_query_len,
-                    prefill_meta.max_kv_len,
+                    prefill_meta.prefill_max_seq_len,
                     self.scale,
                     True,
                     prefill_meta.prefill_block_tables,
@@ -621,8 +576,8 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
             ) = decode_meta.get_seq_len_block_table_args(attn_type)
 
             self.paged_attn_impl.forward_decode(
-                output[attn_metadata.num_prefill_tokens:, :, :],
-                query[attn_metadata.num_prefill_tokens:, :, :],
+                output[:attn_metadata.num_decode_tokens, :, :],
+                query[:attn_metadata.num_decode_tokens, :, :],
                 key_cache,
                 value_cache,
                 block_tables_arg,
diff --git a/vllm/v1/worker/cpu_model_runner.py b/vllm/v1/worker/cpu_model_runner.py
index 6a97f7ebc3fc..964e4c6b2383 100644
--- a/vllm/v1/worker/cpu_model_runner.py
+++ b/vllm/v1/worker/cpu_model_runner.py
@@ -9,7 +9,6 @@ import torch.nn as nn
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader import get_model
-from vllm.v1.attention.backends.cpu_attn import TorchSDPAMetadataBuilderV1
 from vllm.v1.utils import CpuGpuBuffer
 from vllm.v1.worker.gpu_model_runner import GPUModelRunner
 
@@ -33,50 +32,12 @@ class CPUModelRunner(GPUModelRunner):
 
         self._postprocess_tensors()
 
+    # Note: Remove the override after new attention backend finished
     def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
-        """
-        Update the order of requests in the batch based on the attention
-        backend's needs. For example, some attention backends (namely MLA) may
-        want to separate requests based on if the attention computation will be
-        compute-bound or memory-bound.
-
-        Args:
-            scheduler_output: The scheduler output.
-        """
-        # Attention free models have zero kv_cache_groups, however models
-        # like Mamba are also attention free but use the kv_cache for
-        # keeping its internal state. This is why we check the number
-        # of kv_cache groups instead of solely checking
-        # for self.model_config.is_attention_free.
-        if len(self.kv_cache_config.kv_cache_groups) == 0:
-            return
-
         if len(self.kv_cache_config.kv_cache_groups) > 1:
             raise ValueError("Multiple KVCacheGroups is not"
                              "currently supported with CPU model runner.")
-
-        # Guard against encoder-only / pooling models where `attn_groups`
-        # may be empty or lack the expected metadata_builder.
-        # Without this check, accessing `attn_groups[0][0]` would trigger
-        # an AssertionError on CPU backend.
-        if not hasattr(self, "attn_groups") or not self.attn_groups:
-            return
-        if not self.attn_groups[0]:
-            return
-
-        mb = getattr(self.attn_groups[0][0], "metadata_builders", None)
-        if isinstance(mb, list):
-            if not isinstance(mb[0], TorchSDPAMetadataBuilderV1):
-                return
-            mb[0].reorder_batch(self.input_batch, scheduler_output)
-            return
-        elif not isinstance(mb, TorchSDPAMetadataBuilderV1):
-            # Encoder-only / rerank models do not benefit from reordering,
-            # so we safely skip here.
-            return
-
-        # Safe path for decoder/attention-heavy models
-        mb.reorder_batch(self.input_batch, scheduler_output)
+        super()._may_reorder_batch(scheduler_output)
 
     def _postprocess_tensors(self) -> None:
         # Note: replace device tensors with cpu tensors
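A quick illustration of the new split logic for reviewers: with `reorder_batch_threshold = 1` the batch is reordered so that decode requests (query length 1) come first and prefill requests follow, and `split_decodes_and_prefills` then only has to derive the four counts from the per-request query lengths. The sketch below is an illustrative re-implementation of that partitioning, not the vLLM helper itself; `toy_split_decodes_and_prefills` and the sample numbers are made up for this note, and the uniformity check implied by `require_uniform=True` is omitted.

```python
import numpy as np


def toy_split_decodes_and_prefills(query_start_loc: np.ndarray,
                                   decode_threshold: int = 1):
    """Illustrative partitioning of a decode-first batch.

    Assumes the batch was already reordered so every request with a query
    length <= decode_threshold (a decode) precedes every longer request
    (a prefill). The real vLLM helper also supports a uniformity
    requirement that this toy omits.
    """
    # Per-request query lengths, e.g. [1, 1, 1, 7, 12] for 3 decodes + 2 prefills.
    query_lens = np.diff(query_start_loc)
    is_prefill = query_lens > decode_threshold

    num_reqs = len(query_lens)
    # The first request whose query exceeds the threshold marks the boundary.
    num_decodes = int(is_prefill.argmax()) if is_prefill.any() else num_reqs
    num_prefills = num_reqs - num_decodes

    num_decode_tokens = int(query_start_loc[num_decodes])
    num_prefill_tokens = int(query_start_loc[num_reqs]) - num_decode_tokens
    return num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens


# Decode-first layout: three decodes (1 token each), two prefills (7 and 12 tokens).
qsl = np.array([0, 1, 2, 3, 10, 22])
print(toy_split_decodes_and_prefills(qsl))  # (3, 2, 3, 19)
```

This decode-first layout is also why the decode slices in `build()` became `[:num_decodes]`, the prefill slices became `[num_decodes:num_reqs]`, and `query_start_loc_np[num_decodes:num_reqs + 1] -= num_decode_tokens` rebases the prefill query offsets to start at zero before they reach `flash_attn_varlen_func`.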