[CPU] Refine batch reorder of CPU attention backend (#26096)

Signed-off-by: jiang1.li <jiang1.li@intel.com>
Li, Jiang 2025-10-04 21:54:35 +08:00 committed by GitHub
parent ed3aeb25a4
commit 5c057e068f
2 changed files with 44 additions and 128 deletions

vllm/v1/attention/backends/cpu_attn.py

@@ -14,10 +14,9 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata)
from vllm.v1.core.sched.output import SchedulerOutput
CommonAttentionMetadata,
split_decodes_and_prefills)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.gpu_input_batch import InputBatch
try:
import intel_extension_for_pytorch.llm.modules as ipex_modules
@@ -102,16 +101,16 @@ class TorchSDPAMetadata(AttentionMetadata):
"""Metadata for PagedAttention."""
# (batch_size,). The length of sequences (entire tokens seen so far) per
# sequence.
seq_lens_tensor: Optional[torch.Tensor]
decode_seq_lens_tensor: Optional[torch.Tensor]
# Maximum sequence length in the batch. 0 if it is prefill-only batch.
max_decode_seq_len: int
decode_max_seq_len: int
# (batch_size, max_blocks_per_seq).
# Block addresses per sequence. (Seq id -> list of physical block)
# E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
# in the kv cache. Each block can contain up to block_size tokens.
# 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
# captured.
block_tables: Optional[torch.Tensor]
decode_block_tables: Optional[torch.Tensor]
"""Metadata for TorchSDPABackend.
"""
# Currently, input sequences can only contain all prompts
@@ -121,9 +120,9 @@ class TorchSDPAMetadata(AttentionMetadata):
# For chunked prefill only
max_query_len: Optional[int] = None
max_kv_len: Optional[int] = None
prefill_max_seq_len: Optional[int] = None
prefill_query_start_loc: Optional[torch.Tensor] = None
kv_start_loc: Optional[torch.Tensor] = None
prefill_seq_start_loc: Optional[torch.Tensor] = None
prefill_block_tables: Optional[torch.Tensor] = None
# For V1 logits index only
@@ -307,8 +306,8 @@ class TorchSDPAMetadata(AttentionMetadata):
or attn_type == AttentionType.ENCODER_ONLY):
# Decoder self-attention
# Choose max_seq_len based on whether we are in prompt_run
return (self.seq_lens_tensor, self.max_decode_seq_len,
self.block_tables)
return (self.decode_seq_lens_tensor, self.decode_max_seq_len,
self.decode_block_tables)
elif attn_type == AttentionType.ENCODER_DECODER:
# Enc/dec cross-attention KVs match encoder sequence length;
# cross-attention utilizes special "cross" block tables
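The decode_*/prefill_* renames encode the new batch layout: after reordering, decode requests sit at the front of the batch, so every decode-side tensor is a plain prefix slice of the per-request data. A minimal sketch with invented values (not taken from this PR) of what the accessor above hands to the decode path:

    import torch

    # Hypothetical reordered batch: 2 decode requests followed by 1 prefill.
    seq_lens_cpu = torch.tensor([20, 17, 512], dtype=torch.int32)
    block_table_tensor = torch.arange(12, dtype=torch.int32).reshape(3, 4)
    num_decodes = 2

    # The decode_* metadata fields are prefix slices of the full tensors:
    decode_seq_lens_tensor = seq_lens_cpu[:num_decodes]      # tensor([20, 17])
    decode_max_seq_len = int(decode_seq_lens_tensor.max())   # 20
    decode_block_tables = block_table_tensor[:num_decodes]   # first two rows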
@@ -323,19 +322,14 @@
class TorchSDPAMetadataBuilderV1(AttentionMetadataBuilder[TorchSDPAMetadata]):
reorder_batch_threshold: int = 1
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device) -> None:
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
self.scheduler_config = vllm_config.scheduler_config
# For reorder
self.reorder_prompt_req_index_list = np.empty(
vllm_config.scheduler_config.max_num_seqs, dtype=np.int64)
self.reorder_decode_req_index_list = np.empty(
vllm_config.scheduler_config.max_num_seqs, dtype=np.int64)
self.num_prompt_req: int = 0
self._init_reorder_batch_threshold(1, False)
self.seq_start_loc_cpu = torch.zeros(
vllm_config.scheduler_config.max_num_seqs + 1,
@@ -344,50 +338,6 @@ class TorchSDPAMetadataBuilderV1(AttentionMetadataBuilder[TorchSDPAMetadata]):
)
self.seq_start_loc_np = self.seq_start_loc_cpu.numpy()
def reorder_batch(self, input_batch: InputBatch,
scheduler_output: SchedulerOutput) -> bool:
prompt_list_idx = 0
decode_list_idx = 0
for req_index in range(input_batch.num_reqs):
if input_batch.num_computed_tokens_cpu[
req_index] < input_batch.num_prompt_tokens[req_index]:
# prompt stage
self.reorder_prompt_req_index_list[prompt_list_idx] = req_index
prompt_list_idx += 1
else:
# decode stage
self.reorder_decode_req_index_list[decode_list_idx] = req_index
decode_list_idx += 1
assert decode_list_idx + prompt_list_idx == input_batch.num_reqs
# Update prompt requests number
self.num_prompt_req = prompt_list_idx
reorder_req_num = 0
for req_index in range(decode_list_idx):
if self.reorder_decode_req_index_list[req_index] < prompt_list_idx:
reorder_req_num += 1
else:
break
if reorder_req_num == 0:
return False
reorder_prompt_list = (
self.reorder_prompt_req_index_list[:prompt_list_idx]
[-reorder_req_num:])
reorder_decode_list = (
self.reorder_decode_req_index_list[:decode_list_idx]
[:reorder_req_num])
assert reorder_decode_list.size == reorder_prompt_list.size
for idx in range(reorder_req_num):
prompt_req_index = reorder_prompt_list[idx].item()
decode_req_index = reorder_decode_list[idx].item()
input_batch.swap_states(prompt_req_index, decode_req_index)
return True
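The removed method above classified prompt and decode requests by hand and swapped their slots one pair at a time. The refactor delegates this to the shared threshold-based reorder path: requests whose scheduled query length is at most reorder_batch_threshold (1 here) count as decodes and end up at the front of the batch. A simplified stand-in for that classification (not the vLLM implementation):

    def split_decode_first(query_lens: list[int], threshold: int = 1) -> list[int]:
        """Return request indices with decodes first, preserving relative order."""
        decodes = [i for i, q in enumerate(query_lens) if q <= threshold]
        prefills = [i for i, q in enumerate(query_lens) if q > threshold]
        return decodes + prefills

    # e.g. scheduled query lengths [5, 1, 1, 7] -> request order [1, 2, 0, 3],
    # so later code can slice decodes as [:2] and prefills as [2:].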
def build(self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
@@ -397,41 +347,46 @@ class TorchSDPAMetadataBuilderV1(AttentionMetadataBuilder[TorchSDPAMetadata]):
seq_lens_cpu = common_attn_metadata.seq_lens_cpu
seq_lens_np = seq_lens_cpu.numpy()
num_prompt_req = self.num_prompt_req
max_prefill_seq_len = seq_lens_np[:num_prompt_req].max().item(
) if num_prompt_req > 0 else 0
max_decode_seq_len = seq_lens_np[num_prompt_req:num_reqs].max().item(
) if num_prompt_req < num_reqs else 0
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
query_start_loc_np = query_start_loc_cpu.numpy()
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens =\
split_decodes_and_prefills(common_attn_metadata,
decode_threshold=self.reorder_batch_threshold,
require_uniform=True)
max_prefill_seq_len = seq_lens_np[num_decodes:num_reqs].max().item(
) if num_prefills > 0 else 0
max_decode_seq_len = seq_lens_np[:num_decodes].max().item(
) if num_prefills < num_reqs else 0
self.seq_start_loc_np[0] = 0
np.cumsum(seq_lens_np, out=self.seq_start_loc_np[1:num_reqs + 1])
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
num_prefill_tokens = int(query_start_loc_cpu[num_prompt_req].item())
num_decode_tokens = int(query_start_loc_cpu[num_reqs].item() -
num_prefill_tokens)
slot_mapping = common_attn_metadata.slot_mapping.long()
block_table_tensor = common_attn_metadata.block_table_tensor
query_start_loc_np = query_start_loc_cpu.numpy()
query_start_loc_np[num_decodes:num_reqs + 1] -= num_decode_tokens
attn_metadata = TorchSDPAMetadata(
num_prefills=num_prompt_req,
num_prefills=num_prefills,
num_prefill_tokens=num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
slot_mapping=slot_mapping,
# to ensure inference when chunked_prefill is disabled
seq_lens=seq_lens_cpu.tolist(),
seq_lens_tensor=seq_lens_cpu[num_prompt_req:num_reqs], # decode
max_decode_seq_len=max_decode_seq_len, # decode
block_tables=block_table_tensor[num_prompt_req:num_reqs], # decode
decode_seq_lens_tensor=seq_lens_cpu[:num_decodes], # decode
decode_max_seq_len=max_decode_seq_len, # decode
decode_block_tables=block_table_tensor[:num_decodes], # decode
chunked_prefill=self.scheduler_config.chunked_prefill_enabled,
max_query_len=max_query_len,
max_kv_len=max_prefill_seq_len,
prefill_query_start_loc=query_start_loc_cpu[:num_prompt_req +
prefill_max_seq_len=max_prefill_seq_len,
prefill_query_start_loc=query_start_loc_cpu[num_decodes:num_reqs +
1], # prefill
kv_start_loc=self.seq_start_loc_cpu[:num_prompt_req +
1], # prefill
prefill_block_tables=block_table_tensor[:
num_prompt_req], # prefill
prefill_seq_start_loc=self.seq_start_loc_cpu[num_decodes:num_reqs +
1], # prefill
prefill_block_tables=block_table_tensor[
num_decodes:num_reqs], # prefill
query_start_loc=query_start_loc_cpu[:num_reqs +
1], # for logits index
)
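To make the index arithmetic in build() concrete, here is a small worked example with invented numbers, assuming the decode-first layout (in the real code the four counts come from split_decodes_and_prefills):

    import numpy as np

    # Hypothetical reordered batch: 2 decodes then 2 prefills,
    # scheduled query lengths [1, 1, 8, 6], total seen tokens per request below.
    seq_lens_np = np.array([20, 17, 8, 6])
    query_start_loc = np.array([0, 1, 2, 10, 16])
    num_reqs, num_decodes, num_prefills = 4, 2, 2

    num_decode_tokens = int(query_start_loc[num_decodes])                    # 2
    num_prefill_tokens = int(query_start_loc[num_reqs]) - num_decode_tokens  # 14

    max_decode_seq_len = int(seq_lens_np[:num_decodes].max())                # 20
    max_prefill_seq_len = int(seq_lens_np[num_decodes:num_reqs].max())       # 8

    # Shifting the prefill entries by num_decode_tokens turns them into
    # prefill-local offsets, so prefill_query_start_loc starts at 0:
    query_start_loc[num_decodes:num_reqs + 1] -= num_decode_tokens
    print(query_start_loc[num_decodes:num_reqs + 1])                         # [ 0  8 14]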
@@ -596,14 +551,14 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
import intel_extension_for_pytorch.llm.modules as ipex_modules
output = torch.empty_like(query)
ipex_modules.PagedAttention.flash_attn_varlen_func(
output[:prefill_meta.num_prefill_tokens, :, :],
query[:prefill_meta.num_prefill_tokens, :, :],
output[prefill_meta.num_decode_tokens:, :, :],
query[prefill_meta.num_decode_tokens:, :, :],
key_cache,
value_cache,
prefill_meta.prefill_query_start_loc,
prefill_meta.kv_start_loc,
prefill_meta.prefill_seq_start_loc,
prefill_meta.max_query_len,
prefill_meta.max_kv_len,
prefill_meta.prefill_max_seq_len,
self.scale,
True,
prefill_meta.prefill_block_tables,
@@ -621,8 +576,8 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
) = decode_meta.get_seq_len_block_table_args(attn_type)
self.paged_attn_impl.forward_decode(
output[attn_metadata.num_prefill_tokens:, :, :],
query[attn_metadata.num_prefill_tokens:, :, :],
output[:attn_metadata.num_decode_tokens, :, :],
query[:attn_metadata.num_decode_tokens, :, :],
key_cache,
value_cache,
block_tables_arg,
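With decode tokens packed at the front of the token dimension, the prefill and decode kernels above consume complementary slices of the same query/output buffers. A tiny illustration with hypothetical shapes:

    import torch

    num_decode_tokens, num_prefill_tokens = 2, 14
    num_heads, head_size = 8, 64
    query = torch.randn(num_decode_tokens + num_prefill_tokens, num_heads, head_size)
    output = torch.empty_like(query)

    decode_q = query[:num_decode_tokens]    # rows fed to forward_decode
    prefill_q = query[num_decode_tokens:]   # rows fed to the varlen prefill kernel
    assert decode_q.shape[0] + prefill_q.shape[0] == query.shape[0]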

vllm/v1/worker/cpu_model_runner.py

@@ -9,7 +9,6 @@ import torch.nn as nn
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.v1.attention.backends.cpu_attn import TorchSDPAMetadataBuilderV1
from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
@@ -33,50 +32,12 @@ class CPUModelRunner(GPUModelRunner):
self._postprocess_tensors()
# Note: Remove the override after new attention backend finished
def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
"""
Update the order of requests in the batch based on the attention
backend's needs. For example, some attention backends (namely MLA) may
want to separate requests based on if the attention computation will be
compute-bound or memory-bound.
Args:
scheduler_output: The scheduler output.
"""
# Attention free models have zero kv_cache_groups, however models
# like Mamba are also attention free but use the kv_cache for
# keeping its internal state. This is why we check the number
# of kv_cache groups instead of solely checking
# for self.model_config.is_attention_free.
if len(self.kv_cache_config.kv_cache_groups) == 0:
return
if len(self.kv_cache_config.kv_cache_groups) > 1:
raise ValueError("Multiple KVCacheGroups is not"
"currently supported with CPU model runner.")
# Guard against encoder-only / pooling models where `attn_groups`
# may be empty or lack the expected metadata_builder.
# Without this check, accessing `attn_groups[0][0]` would trigger
# an AssertionError on CPU backend.
if not hasattr(self, "attn_groups") or not self.attn_groups:
return
if not self.attn_groups[0]:
return
mb = getattr(self.attn_groups[0][0], "metadata_builders", None)
if isinstance(mb, list):
if not isinstance(mb[0], TorchSDPAMetadataBuilderV1):
return
mb[0].reorder_batch(self.input_batch, scheduler_output)
return
elif not isinstance(mb, TorchSDPAMetadataBuilderV1):
# Encoder-only / rerank models do not benefit from reordering,
# so we safely skip here.
return
# Safe path for decoder/attention-heavy models
mb.reorder_batch(self.input_batch, scheduler_output)
super()._may_reorder_batch(scheduler_output)
def _postprocess_tensors(self) -> None:
# Note: replace device tensors with cpu tensors