[CPU] Refine batch reorder of CPU attention backend (#26096)
Signed-off-by: jiang1.li <jiang1.li@intel.com>
parent ed3aeb25a4
commit 5c057e068f
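Summary of the change: the CPU attention backend (the TorchSDPAMetadata / TorchSDPAMetadataBuilderV1 classes) drops its hand-rolled prompt/decode batch reordering and relies on the shared reorder mechanism instead. The batch is kept decode-first, and the prefill/decode split is derived with split_decodes_and_prefills keyed on reorder_batch_threshold; the CPU model runner's special-cased reordering override is cut down accordingly. The sketch below is illustrative only (not vLLM code) and shows the kind of split the shared helper computes once decodes are packed at the front of the batch; all names in it are assumptions made for the example.

# Illustrative sketch only (not vLLM code): once decode requests are packed at
# the front of the batch, the prefill/decode split reduces to counting and
# slicing. All names below are assumptions made for the example.
import numpy as np

def split_decodes_first(query_lens: np.ndarray, decode_threshold: int = 1):
    """Requests with at most `decode_threshold` new tokens count as decodes.

    Assumes the batch is already ordered with all decodes first, which is what
    the reorder/threshold machinery guarantees.
    """
    num_decodes = int((query_lens <= decode_threshold).sum())
    num_prefills = len(query_lens) - num_decodes
    num_decode_tokens = int(query_lens[:num_decodes].sum())
    num_prefill_tokens = int(query_lens[num_decodes:].sum())
    return num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens

# Three decode requests (1 new token each) followed by two prefills.
print(split_decodes_first(np.array([1, 1, 1, 17, 9])))  # (3, 2, 3, 26)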
@@ -14,10 +14,9 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
-                                              CommonAttentionMetadata)
-from vllm.v1.core.sched.output import SchedulerOutput
+                                              CommonAttentionMetadata,
+                                              split_decodes_and_prefills)
 from vllm.v1.kv_cache_interface import AttentionSpec
-from vllm.v1.worker.gpu_input_batch import InputBatch

 try:
     import intel_extension_for_pytorch.llm.modules as ipex_modules
@@ -102,16 +101,16 @@ class TorchSDPAMetadata(AttentionMetadata):
     """Metadata for PagedAttention."""
     # (batch_size,). The length of sequences (entire tokens seen so far) per
     # sequence.
-    seq_lens_tensor: Optional[torch.Tensor]
+    decode_seq_lens_tensor: Optional[torch.Tensor]
     # Maximum sequence length in the batch. 0 if it is prefill-only batch.
-    max_decode_seq_len: int
+    decode_max_seq_len: int
     # (batch_size, max_blocks_per_seq).
     # Block addresses per sequence. (Seq id -> list of physical block)
     # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
     # in the kv cache. Each block can contain up to block_size tokens.
     # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
     # captured.
-    block_tables: Optional[torch.Tensor]
+    decode_block_tables: Optional[torch.Tensor]
     """Metadata for TorchSDPABackend.
     """
     # Currently, input sequences can only contain all prompts
@@ -121,9 +120,9 @@ class TorchSDPAMetadata(AttentionMetadata):

     # For chunked prefill only
     max_query_len: Optional[int] = None
-    max_kv_len: Optional[int] = None
+    prefill_max_seq_len: Optional[int] = None
     prefill_query_start_loc: Optional[torch.Tensor] = None
-    kv_start_loc: Optional[torch.Tensor] = None
+    prefill_seq_start_loc: Optional[torch.Tensor] = None
     prefill_block_tables: Optional[torch.Tensor] = None

     # For V1 logits index only
@@ -307,8 +306,8 @@ class TorchSDPAMetadata(AttentionMetadata):
                 or attn_type == AttentionType.ENCODER_ONLY):
             # Decoder self-attention
             # Choose max_seq_len based on whether we are in prompt_run
-            return (self.seq_lens_tensor, self.max_decode_seq_len,
-                    self.block_tables)
+            return (self.decode_seq_lens_tensor, self.decode_max_seq_len,
+                    self.decode_block_tables)
         elif attn_type == AttentionType.ENCODER_DECODER:
             # Enc/dec cross-attention KVs match encoder sequence length;
             # cross-attention utilizes special "cross" block tables
@@ -323,19 +322,14 @@ class TorchSDPAMetadata(AttentionMetadata):


 class TorchSDPAMetadataBuilderV1(AttentionMetadataBuilder[TorchSDPAMetadata]):
+    reorder_batch_threshold: int = 1

     def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                  vllm_config: VllmConfig, device: torch.device) -> None:
         super().__init__(kv_cache_spec, layer_names, vllm_config, device)

         self.scheduler_config = vllm_config.scheduler_config
-
-        # For reorder
-        self.reorder_prompt_req_index_list = np.empty(
-            vllm_config.scheduler_config.max_num_seqs, dtype=np.int64)
-        self.reorder_decode_req_index_list = np.empty(
-            vllm_config.scheduler_config.max_num_seqs, dtype=np.int64)
-        self.num_prompt_req: int = 0
+        self._init_reorder_batch_threshold(1, False)

         self.seq_start_loc_cpu = torch.zeros(
             vllm_config.scheduler_config.max_num_seqs + 1,
@@ -344,50 +338,6 @@ class TorchSDPAMetadataBuilderV1(AttentionMetadataBuilder[TorchSDPAMetadata]):
         )
         self.seq_start_loc_np = self.seq_start_loc_cpu.numpy()

-    def reorder_batch(self, input_batch: InputBatch,
-                      scheduler_output: SchedulerOutput) -> bool:
-        prompt_list_idx = 0
-        decode_list_idx = 0
-        for req_index in range(input_batch.num_reqs):
-            if input_batch.num_computed_tokens_cpu[
-                    req_index] < input_batch.num_prompt_tokens[req_index]:
-                # prompt stage
-                self.reorder_prompt_req_index_list[prompt_list_idx] = req_index
-                prompt_list_idx += 1
-            else:
-                # decode stage
-                self.reorder_decode_req_index_list[decode_list_idx] = req_index
-                decode_list_idx += 1
-        assert decode_list_idx + prompt_list_idx == input_batch.num_reqs
-
-        # Update prompt requests number
-        self.num_prompt_req = prompt_list_idx
-
-        reorder_req_num = 0
-        for req_index in range(decode_list_idx):
-            if self.reorder_decode_req_index_list[req_index] < prompt_list_idx:
-                reorder_req_num += 1
-            else:
-                break
-
-        if reorder_req_num == 0:
-            return False
-
-        reorder_prompt_list = (
-            self.reorder_prompt_req_index_list[:prompt_list_idx]
-            [-reorder_req_num:])
-        reorder_decode_list = (
-            self.reorder_decode_req_index_list[:decode_list_idx]
-            [:reorder_req_num])
-        assert reorder_decode_list.size == reorder_prompt_list.size
-
-        for idx in range(reorder_req_num):
-            prompt_req_index = reorder_prompt_list[idx].item()
-            decode_req_index = reorder_decode_list[idx].item()
-            input_batch.swap_states(prompt_req_index, decode_req_index)
-
-        return True
-
     def build(self,
               common_prefix_len: int,
               common_attn_metadata: CommonAttentionMetadata,
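With the backend-specific reorder_batch removed, moving decode requests ahead of prefills is left to the common builder machinery configured by _init_reorder_batch_threshold(1, False) above. Below is a minimal, illustrative sketch of such a swap-based "decodes first" pass; the names are assumptions for the example, and swap() stands in for per-request state swapping of the kind the removed code did via input_batch.swap_states.

# Hedged sketch of a generic "decodes first" reorder of the kind that replaces
# the removed method; the names here are illustrative, not the vLLM API.
def reorder_decodes_first(is_decode: list[bool], swap) -> bool:
    """Swap batch slots so every decode index precedes every prefill index.

    `swap(i, j)` exchanges the per-request state of slots i and j (analogous
    to what the removed code did via input_batch.swap_states). Returns True
    if any swap was performed.
    """
    left, right = 0, len(is_decode) - 1
    modified = False
    while left < right:
        if is_decode[left]:
            left += 1
        elif not is_decode[right]:
            right -= 1
        else:
            swap(left, right)
            is_decode[left], is_decode[right] = True, False
            modified = True
    return modified

# Example: slots 0 and 2 are prefills, slots 1, 3, 4 are decodes.
flags = [False, True, False, True, True]
reorder_decodes_first(flags, swap=lambda i, j: None)
print(flags)  # [True, True, True, False, False]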
@@ -397,41 +347,46 @@ class TorchSDPAMetadataBuilderV1(AttentionMetadataBuilder[TorchSDPAMetadata]):

         seq_lens_cpu = common_attn_metadata.seq_lens_cpu
         seq_lens_np = seq_lens_cpu.numpy()
-        num_prompt_req = self.num_prompt_req
-        max_prefill_seq_len = seq_lens_np[:num_prompt_req].max().item(
-        ) if num_prompt_req > 0 else 0
-        max_decode_seq_len = seq_lens_np[num_prompt_req:num_reqs].max().item(
-        ) if num_prompt_req < num_reqs else 0
+
+        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
+        query_start_loc_np = query_start_loc_cpu.numpy()
+
+        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens =\
+            split_decodes_and_prefills(common_attn_metadata,
+                                       decode_threshold=self.reorder_batch_threshold,
+                                       require_uniform=True)
+
+        max_prefill_seq_len = seq_lens_np[num_decodes:num_reqs].max().item(
+        ) if num_prefills > 0 else 0
+        max_decode_seq_len = seq_lens_np[:num_decodes].max().item(
+        ) if num_prefills < num_reqs else 0
         self.seq_start_loc_np[0] = 0
         np.cumsum(seq_lens_np, out=self.seq_start_loc_np[1:num_reqs + 1])

-        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
-        num_prefill_tokens = int(query_start_loc_cpu[num_prompt_req].item())
-        num_decode_tokens = int(query_start_loc_cpu[num_reqs].item() -
-                                num_prefill_tokens)
-
         slot_mapping = common_attn_metadata.slot_mapping.long()
         block_table_tensor = common_attn_metadata.block_table_tensor
+        query_start_loc_np = query_start_loc_cpu.numpy()
+        query_start_loc_np[num_decodes:num_reqs + 1] -= num_decode_tokens

         attn_metadata = TorchSDPAMetadata(
-            num_prefills=num_prompt_req,
+            num_prefills=num_prefills,
             num_prefill_tokens=num_prefill_tokens,
             num_decode_tokens=num_decode_tokens,
             slot_mapping=slot_mapping,
             # to ensure inference when chunked_prefill is disabled
             seq_lens=seq_lens_cpu.tolist(),
-            seq_lens_tensor=seq_lens_cpu[num_prompt_req:num_reqs],  # decode
-            max_decode_seq_len=max_decode_seq_len,  # decode
-            block_tables=block_table_tensor[num_prompt_req:num_reqs],  # decode
+            decode_seq_lens_tensor=seq_lens_cpu[:num_decodes],  # decode
+            decode_max_seq_len=max_decode_seq_len,  # decode
+            decode_block_tables=block_table_tensor[:num_decodes],  # decode
             chunked_prefill=self.scheduler_config.chunked_prefill_enabled,
             max_query_len=max_query_len,
-            max_kv_len=max_prefill_seq_len,
-            prefill_query_start_loc=query_start_loc_cpu[:num_prompt_req +
+            prefill_max_seq_len=max_prefill_seq_len,
+            prefill_query_start_loc=query_start_loc_cpu[num_decodes:num_reqs +
                                                         1],  # prefill
-            kv_start_loc=self.seq_start_loc_cpu[:num_prompt_req +
-                                                1],  # prefill
-            prefill_block_tables=block_table_tensor[:
-                                                    num_prompt_req],  # prefill
+            prefill_seq_start_loc=self.seq_start_loc_cpu[num_decodes:num_reqs +
+                                                         1],  # prefill
+            prefill_block_tables=block_table_tensor[
+                num_decodes:num_reqs],  # prefill
             query_start_loc=query_start_loc_cpu[:num_reqs +
                                                 1],  # for logits index
         )
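Because the batch is now ordered decode-first, build() can slice every per-phase tensor straight out of CommonAttentionMetadata: decodes occupy indices [0, num_decodes) and prefills [num_decodes, num_reqs). The one extra step is rebasing the prefill query_start_loc by the number of decode tokens, which is what query_start_loc_np[num_decodes:num_reqs + 1] -= num_decode_tokens does above. A small, self-contained illustration with made-up numbers:

# Minimal sketch (made-up numbers, not the real metadata classes): how the
# prefill query_start_loc is rebased once decode tokens occupy the head of the
# flattened token buffer.
import numpy as np

num_decodes, num_reqs = 3, 5
# Cumulative query starts for [decode, decode, decode, prefill, prefill]
# with query lengths [1, 1, 1, 17, 9].
query_start_loc = np.array([0, 1, 2, 3, 20, 29])
num_decode_tokens = int(query_start_loc[num_decodes])  # 3

# Prefill-local offsets: drop the decode tokens that precede them.
prefill_query_start_loc = query_start_loc[num_decodes:num_reqs + 1] - num_decode_tokens
print(prefill_query_start_loc)  # [ 0 17 26]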
@@ -596,14 +551,14 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
             import intel_extension_for_pytorch.llm.modules as ipex_modules
             output = torch.empty_like(query)
             ipex_modules.PagedAttention.flash_attn_varlen_func(
-                output[:prefill_meta.num_prefill_tokens, :, :],
-                query[:prefill_meta.num_prefill_tokens, :, :],
+                output[prefill_meta.num_decode_tokens:, :, :],
+                query[prefill_meta.num_decode_tokens:, :, :],
                 key_cache,
                 value_cache,
                 prefill_meta.prefill_query_start_loc,
-                prefill_meta.kv_start_loc,
+                prefill_meta.prefill_seq_start_loc,
                 prefill_meta.max_query_len,
-                prefill_meta.max_kv_len,
+                prefill_meta.prefill_max_seq_len,
                 self.scale,
                 True,
                 prefill_meta.prefill_block_tables,
@@ -621,8 +576,8 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
             ) = decode_meta.get_seq_len_block_table_args(attn_type)

             self.paged_attn_impl.forward_decode(
-                output[attn_metadata.num_prefill_tokens:, :, :],
-                query[attn_metadata.num_prefill_tokens:, :, :],
+                output[:attn_metadata.num_decode_tokens, :, :],
+                query[:attn_metadata.num_decode_tokens, :, :],
                 key_cache,
                 value_cache,
                 block_tables_arg,
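The kernel call sites change to match the new layout: decode tokens now sit at the head of the flattened token dimension and prefill tokens at the tail, so the prefill path indexes [num_decode_tokens:] and the decode path [:num_decode_tokens] (previously the prefill tokens came first). Illustrative layout check with assumed shapes, not the real call:

# Illustrative layout check (assumed shapes, not the real kernel call): decode
# tokens sit at the head of the token dimension, prefill tokens at the tail.
import torch

num_decode_tokens, num_prefill_tokens = 3, 26
num_heads, head_size = 4, 8
query = torch.randn(num_decode_tokens + num_prefill_tokens, num_heads, head_size)
output = torch.empty_like(query)

decode_query = query[:num_decode_tokens]    # one token per decode request
prefill_query = query[num_decode_tokens:]   # all prefill tokens
assert decode_query.shape[0] + prefill_query.shape[0] == output.shape[0]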
@@ -9,7 +9,6 @@ import torch.nn as nn
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader import get_model
-from vllm.v1.attention.backends.cpu_attn import TorchSDPAMetadataBuilderV1
 from vllm.v1.utils import CpuGpuBuffer
 from vllm.v1.worker.gpu_model_runner import GPUModelRunner

@@ -33,50 +32,12 @@ class CPUModelRunner(GPUModelRunner):

         self._postprocess_tensors()

     # Note: Remove the override after new attention backend finished
     def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
-        """
-        Update the order of requests in the batch based on the attention
-        backend's needs. For example, some attention backends (namely MLA) may
-        want to separate requests based on if the attention computation will be
-        compute-bound or memory-bound.
-
-        Args:
-            scheduler_output: The scheduler output.
-        """
-        # Attention free models have zero kv_cache_groups, however models
-        # like Mamba are also attention free but use the kv_cache for
-        # keeping its internal state. This is why we check the number
-        # of kv_cache groups instead of solely checking
-        # for self.model_config.is_attention_free.
-        if len(self.kv_cache_config.kv_cache_groups) == 0:
-            return
-
-        if len(self.kv_cache_config.kv_cache_groups) > 1:
-            raise ValueError("Multiple KVCacheGroups is not"
-                             "currently supported with CPU model runner.")
-
-        # Guard against encoder-only / pooling models where `attn_groups`
-        # may be empty or lack the expected metadata_builder.
-        # Without this check, accessing `attn_groups[0][0]` would trigger
-        # an AssertionError on CPU backend.
-        if not hasattr(self, "attn_groups") or not self.attn_groups:
-            return
-        if not self.attn_groups[0]:
-            return
-
-        mb = getattr(self.attn_groups[0][0], "metadata_builders", None)
-        if isinstance(mb, list):
-            if not isinstance(mb[0], TorchSDPAMetadataBuilderV1):
-                return
-            mb[0].reorder_batch(self.input_batch, scheduler_output)
-            return
-        elif not isinstance(mb, TorchSDPAMetadataBuilderV1):
-            # Encoder-only / rerank models do not benefit from reordering,
-            # so we safely skip here.
-            return
-
-        # Safe path for decoder/attention-heavy models
-        mb.reorder_batch(self.input_batch, scheduler_output)
+        super()._may_reorder_batch(scheduler_output)

     def _postprocess_tensors(self) -> None:
         # Note: replace device tensors with cpu tensors
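On the model-runner side, the CPU-specific _may_reorder_batch body, with its TorchSDPAMetadataBuilderV1 special-casing, goes away; batch reordering now defers to the shared GPUModelRunner path, which consults the attention backend's reorder_batch_threshold. A hedged sketch of that delegation pattern, with simplified, assumed names rather than the real classes:

# Hedged sketch of the delegation pattern (simplified, assumed names; the real
# classes are GPUModelRunner / CPUModelRunner).
class BaseRunner:
    def _may_reorder_batch(self, scheduler_output) -> None:
        # Shared path: reorder according to the attention backend's
        # reorder_batch_threshold; details omitted here.
        ...

class CPURunner(BaseRunner):
    def _may_reorder_batch(self, scheduler_output) -> None:
        # No CPU-specific handling remains; defer to the shared implementation.
        super()._may_reorder_batch(scheduler_output)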