[CPU] Refine batch reorder of CPU attention backend (#26096)

Signed-off-by: jiang1.li <jiang1.li@intel.com>

commit 5c057e068f
parent ed3aeb25a4
@@ -14,10 +14,9 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
-                                              CommonAttentionMetadata)
-from vllm.v1.core.sched.output import SchedulerOutput
+                                              CommonAttentionMetadata,
+                                              split_decodes_and_prefills)
 from vllm.v1.kv_cache_interface import AttentionSpec
-from vllm.v1.worker.gpu_input_batch import InputBatch
 
 try:
     import intel_extension_for_pytorch.llm.modules as ipex_modules
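
This import change swaps the builder's private reorder bookkeeping (and the InputBatch/SchedulerOutput imports it needed) for the shared `split_decodes_and_prefills` helper from `vllm.v1.attention.backends.utils`, which assumes the batch has already been reordered so that decode requests come first. The snippet below is a simplified, standalone sketch of the kind of split that helper computes, written against a plain NumPy array of per-request query lengths rather than a real `CommonAttentionMetadata`; it is an illustration, not the vllm implementation.

```python
import numpy as np


def toy_split_decodes_and_prefills(query_lens: np.ndarray,
                                   decode_threshold: int = 1):
    """Toy decode/prefill split over a decode-first batch.

    Assumes the batch is already ordered so that every decode request
    (at most `decode_threshold` new tokens this step) precedes every
    prefill request, which is also what the real helper relies on.
    """
    # In a decode-first batch the "is prefill" mask is sorted False -> True.
    num_decodes = int(np.searchsorted(query_lens > decode_threshold, True))
    num_prefills = len(query_lens) - num_decodes
    num_decode_tokens = int(query_lens[:num_decodes].sum())
    num_prefill_tokens = int(query_lens[num_decodes:].sum())
    return num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens


# Two decodes (1 token each) followed by two prefills (8 and 5 tokens).
print(toy_split_decodes_and_prefills(np.array([1, 1, 8, 5])))
# -> (2, 2, 2, 13)
```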
@@ -102,16 +101,16 @@ class TorchSDPAMetadata(AttentionMetadata):
     """Metadata for PagedAttention."""
     # (batch_size,). The length of sequences (entire tokens seen so far) per
     # sequence.
-    seq_lens_tensor: Optional[torch.Tensor]
+    decode_seq_lens_tensor: Optional[torch.Tensor]
     # Maximum sequence length in the batch. 0 if it is prefill-only batch.
-    max_decode_seq_len: int
+    decode_max_seq_len: int
     # (batch_size, max_blocks_per_seq).
     # Block addresses per sequence. (Seq id -> list of physical block)
     # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
     # in the kv cache. Each block can contain up to block_size tokens.
     # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
     # captured.
-    block_tables: Optional[torch.Tensor]
+    decode_block_tables: Optional[torch.Tensor]
     """Metadata for TorchSDPABackend.
     """
     # Currently, input sequences can only contain all prompts
@@ -121,9 +120,9 @@ class TorchSDPAMetadata(AttentionMetadata):
 
     # For chunked prefill only
     max_query_len: Optional[int] = None
-    max_kv_len: Optional[int] = None
+    prefill_max_seq_len: Optional[int] = None
     prefill_query_start_loc: Optional[torch.Tensor] = None
-    kv_start_loc: Optional[torch.Tensor] = None
+    prefill_seq_start_loc: Optional[torch.Tensor] = None
     prefill_block_tables: Optional[torch.Tensor] = None
 
     # For V1 logits index only
@@ -307,8 +306,8 @@ class TorchSDPAMetadata(AttentionMetadata):
                 or attn_type == AttentionType.ENCODER_ONLY):
             # Decoder self-attention
             # Choose max_seq_len based on whether we are in prompt_run
-            return (self.seq_lens_tensor, self.max_decode_seq_len,
-                    self.block_tables)
+            return (self.decode_seq_lens_tensor, self.decode_max_seq_len,
+                    self.decode_block_tables)
         elif attn_type == AttentionType.ENCODER_DECODER:
             # Enc/dec cross-attention KVs match encoder sequence length;
             # cross-attention utilizes special "cross" block tables
@@ -323,19 +322,14 @@ class TorchSDPAMetadata(AttentionMetadata):
 
 
 class TorchSDPAMetadataBuilderV1(AttentionMetadataBuilder[TorchSDPAMetadata]):
+    reorder_batch_threshold: int = 1
 
     def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                  vllm_config: VllmConfig, device: torch.device) -> None:
         super().__init__(kv_cache_spec, layer_names, vllm_config, device)
 
         self.scheduler_config = vllm_config.scheduler_config
-        # For reorder
-        self.reorder_prompt_req_index_list = np.empty(
-            vllm_config.scheduler_config.max_num_seqs, dtype=np.int64)
-        self.reorder_decode_req_index_list = np.empty(
-            vllm_config.scheduler_config.max_num_seqs, dtype=np.int64)
-        self.num_prompt_req: int = 0
+        self._init_reorder_batch_threshold(1, False)
 
         self.seq_start_loc_cpu = torch.zeros(
             vllm_config.scheduler_config.max_num_seqs + 1,
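
The builder now advertises its reordering policy through `reorder_batch_threshold` and `_init_reorder_batch_threshold(1, False)` instead of keeping per-request index buffers of its own. Under this convention a request that only needs `reorder_batch_threshold` or fewer new tokens in the current step is bucketed as a decode, and anything larger is treated as a prefill. The helper below is a hypothetical illustration of that rule; the names `classify_request` and `num_scheduled_tokens` are ours, and the meaning of the second argument to `_init_reorder_batch_threshold` is not shown in this diff.

```python
def classify_request(num_scheduled_tokens: int,
                     reorder_batch_threshold: int = 1) -> str:
    """Hypothetical helper illustrating the threshold rule.

    A request scheduled for at most `reorder_batch_threshold` new tokens
    this step behaves like a decode; larger chunks behave like prefills.
    """
    return ("decode" if num_scheduled_tokens <= reorder_batch_threshold
            else "prefill")


assert classify_request(1) == "decode"     # ordinary single-token decode
assert classify_request(256) == "prefill"  # prompt or chunked-prefill step
```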
@@ -344,50 +338,6 @@ class TorchSDPAMetadataBuilderV1(AttentionMetadataBuilder[TorchSDPAMetadata]):
         )
         self.seq_start_loc_np = self.seq_start_loc_cpu.numpy()
 
-    def reorder_batch(self, input_batch: InputBatch,
-                      scheduler_output: SchedulerOutput) -> bool:
-        prompt_list_idx = 0
-        decode_list_idx = 0
-        for req_index in range(input_batch.num_reqs):
-            if input_batch.num_computed_tokens_cpu[
-                    req_index] < input_batch.num_prompt_tokens[req_index]:
-                # prompt stage
-                self.reorder_prompt_req_index_list[prompt_list_idx] = req_index
-                prompt_list_idx += 1
-            else:
-                # decode stage
-                self.reorder_decode_req_index_list[decode_list_idx] = req_index
-                decode_list_idx += 1
-        assert decode_list_idx + prompt_list_idx == input_batch.num_reqs
-
-        # Update prompt requests number
-        self.num_prompt_req = prompt_list_idx
-
-        reorder_req_num = 0
-        for req_index in range(decode_list_idx):
-            if self.reorder_decode_req_index_list[req_index] < prompt_list_idx:
-                reorder_req_num += 1
-            else:
-                break
-
-        if reorder_req_num == 0:
-            return False
-
-        reorder_prompt_list = (
-            self.reorder_prompt_req_index_list[:prompt_list_idx]
-            [-reorder_req_num:])
-        reorder_decode_list = (
-            self.reorder_decode_req_index_list[:decode_list_idx]
-            [:reorder_req_num])
-        assert reorder_decode_list.size == reorder_prompt_list.size
-
-        for idx in range(reorder_req_num):
-            prompt_req_index = reorder_prompt_list[idx].item()
-            decode_req_index = reorder_decode_list[idx].item()
-            input_batch.swap_states(prompt_req_index, decode_req_index)
-
-        return True
-
     def build(self,
               common_prefix_len: int,
               common_attn_metadata: CommonAttentionMetadata,
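
The hand-written `reorder_batch` above is deleted. It scanned the batch, collected prompt and decode indices into the preallocated buffers, and swapped mismatched pairs with `InputBatch.swap_states` until prompts and decodes formed contiguous regions. With the threshold registered in `__init__`, that reordering is now expected to come from the shared builder/runner machinery rather than a per-backend copy. As a reminder of the invariant the rest of the file relies on, here is a small self-contained sketch (not vllm code) that produces the decode-first order using the same swap idea; `swap(i, j)` stands in for `InputBatch.swap_states(i, j)`.

```python
def reorder_decodes_first(query_lens: list[int],
                          decode_threshold: int = 1,
                          swap=lambda i, j: None) -> list[int]:
    """Sketch: two-pointer swap pass that moves decode requests to the front.

    `swap(i, j)` is a no-op hook here so the function stays self-contained;
    in the real batch it would swap the persistent per-request state.
    """
    order = list(range(len(query_lens)))
    left, right = 0, len(query_lens) - 1
    while left < right:
        if query_lens[order[left]] <= decode_threshold:
            left += 1          # already a decode, keep it at the front
        elif query_lens[order[right]] > decode_threshold:
            right -= 1         # already a prefill, keep it at the back
        else:
            order[left], order[right] = order[right], order[left]
            swap(left, right)  # mirror the swap in the persistent batch
            left += 1
            right -= 1
    return order


# Prefill/decode mix [8, 1, 1, 5]: the decode requests move to the front.
print(reorder_decodes_first([8, 1, 1, 5]))  # -> [2, 1, 0, 3]
```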
@@ -397,41 +347,46 @@ class TorchSDPAMetadataBuilderV1(AttentionMetadataBuilder[TorchSDPAMetadata]):
 
         seq_lens_cpu = common_attn_metadata.seq_lens_cpu
         seq_lens_np = seq_lens_cpu.numpy()
-        num_prompt_req = self.num_prompt_req
-        max_prefill_seq_len = seq_lens_np[:num_prompt_req].max().item(
-        ) if num_prompt_req > 0 else 0
-        max_decode_seq_len = seq_lens_np[num_prompt_req:num_reqs].max().item(
-        ) if num_prompt_req < num_reqs else 0
+        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
+        query_start_loc_np = query_start_loc_cpu.numpy()
+
+        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens =\
+            split_decodes_and_prefills(common_attn_metadata,
+                                       decode_threshold=self.reorder_batch_threshold,
+                                       require_uniform=True)
+
+        max_prefill_seq_len = seq_lens_np[num_decodes:num_reqs].max().item(
+        ) if num_prefills > 0 else 0
+        max_decode_seq_len = seq_lens_np[:num_decodes].max().item(
+        ) if num_prefills < num_reqs else 0
         self.seq_start_loc_np[0] = 0
         np.cumsum(seq_lens_np, out=self.seq_start_loc_np[1:num_reqs + 1])
 
-        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
-        num_prefill_tokens = int(query_start_loc_cpu[num_prompt_req].item())
-        num_decode_tokens = int(query_start_loc_cpu[num_reqs].item() -
-                                num_prefill_tokens)
-
         slot_mapping = common_attn_metadata.slot_mapping.long()
         block_table_tensor = common_attn_metadata.block_table_tensor
+        query_start_loc_np = query_start_loc_cpu.numpy()
+        query_start_loc_np[num_decodes:num_reqs + 1] -= num_decode_tokens
+
         attn_metadata = TorchSDPAMetadata(
-            num_prefills=num_prompt_req,
+            num_prefills=num_prefills,
             num_prefill_tokens=num_prefill_tokens,
             num_decode_tokens=num_decode_tokens,
             slot_mapping=slot_mapping,
             # to ensure inference when chunked_prefill is disabled
             seq_lens=seq_lens_cpu.tolist(),
-            seq_lens_tensor=seq_lens_cpu[num_prompt_req:num_reqs],  # decode
-            max_decode_seq_len=max_decode_seq_len,  # decode
-            block_tables=block_table_tensor[num_prompt_req:num_reqs],  # decode
+            decode_seq_lens_tensor=seq_lens_cpu[:num_decodes],  # decode
+            decode_max_seq_len=max_decode_seq_len,  # decode
+            decode_block_tables=block_table_tensor[:num_decodes],  # decode
             chunked_prefill=self.scheduler_config.chunked_prefill_enabled,
             max_query_len=max_query_len,
-            max_kv_len=max_prefill_seq_len,
-            prefill_query_start_loc=query_start_loc_cpu[:num_prompt_req +
+            prefill_max_seq_len=max_prefill_seq_len,
+            prefill_query_start_loc=query_start_loc_cpu[num_decodes:num_reqs +
                                                         1],  # prefill
-            kv_start_loc=self.seq_start_loc_cpu[:num_prompt_req +
+            prefill_seq_start_loc=self.seq_start_loc_cpu[num_decodes:num_reqs +
                                                 1],  # prefill
-            prefill_block_tables=block_table_tensor[:
-                                                    num_prompt_req],  # prefill
+            prefill_block_tables=block_table_tensor[
+                num_decodes:num_reqs],  # prefill
             query_start_loc=query_start_loc_cpu[:num_reqs +
                                                 1],  # for logits index
         )
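
`build()` now derives everything from the decode-first layout: decode requests occupy indices `[:num_decodes]` of the sequence lengths and the block table, prefills occupy `[num_decodes:num_reqs]`, and the prefill slice of `query_start_loc` is shifted down by `num_decode_tokens` so that it starts at zero within the prefill-only token range. A small worked example with toy NumPy arrays (not vllm objects):

```python
import numpy as np

# Toy decode-first batch: 2 decodes (1 new token each), 2 prefills (8 and 5).
query_lens = np.array([1, 1, 8, 5])
seq_lens = np.array([17, 33, 8, 5])       # tokens seen so far per request
num_decodes, num_reqs = 2, 4
num_decode_tokens = int(query_lens[:num_decodes].sum())         # 2

# Cumulative start offsets into the flattened token dimension.
query_start_loc = np.concatenate(([0], np.cumsum(query_lens)))  # [0 1 2 10 15]

# Decode-side metadata slices the head of the batch...
decode_seq_lens = seq_lens[:num_decodes]                        # [17 33]
# ...while prefill-side metadata slices the tail, rebased to token 0.
prefill_query_start_loc = query_start_loc[num_decodes:num_reqs + 1].copy()
prefill_query_start_loc -= num_decode_tokens                    # [0 8 13]

print(decode_seq_lens, prefill_query_start_loc)
```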
@@ -596,14 +551,14 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
             import intel_extension_for_pytorch.llm.modules as ipex_modules
             output = torch.empty_like(query)
             ipex_modules.PagedAttention.flash_attn_varlen_func(
-                output[:prefill_meta.num_prefill_tokens, :, :],
-                query[:prefill_meta.num_prefill_tokens, :, :],
+                output[prefill_meta.num_decode_tokens:, :, :],
+                query[prefill_meta.num_decode_tokens:, :, :],
                 key_cache,
                 value_cache,
                 prefill_meta.prefill_query_start_loc,
-                prefill_meta.kv_start_loc,
+                prefill_meta.prefill_seq_start_loc,
                 prefill_meta.max_query_len,
-                prefill_meta.max_kv_len,
+                prefill_meta.prefill_max_seq_len,
                 self.scale,
                 True,
                 prefill_meta.prefill_block_tables,
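
With decodes packed at the front of the flattened token dimension, the prefill path of the IPEX varlen kernel now reads and writes the tail of `output`/`query` (everything from `num_decode_tokens` onward) instead of the head. A tiny indexing sketch on plain tensors (not the real forward pass):

```python
import torch

num_decode_tokens, num_prefill_tokens = 2, 13
num_heads, head_size = 4, 8
query = torch.randn(num_decode_tokens + num_prefill_tokens, num_heads, head_size)

# Old layout: prefill tokens first  -> prefill slice was query[:num_prefill_tokens]
# New layout: decode tokens first   -> prefill slice is query[num_decode_tokens:]
prefill_query = query[num_decode_tokens:, :, :]
decode_query = query[:num_decode_tokens, :, :]
assert prefill_query.shape[0] == num_prefill_tokens
assert decode_query.shape[0] == num_decode_tokens
```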
@@ -621,8 +576,8 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
             ) = decode_meta.get_seq_len_block_table_args(attn_type)
 
             self.paged_attn_impl.forward_decode(
-                output[attn_metadata.num_prefill_tokens:, :, :],
-                query[attn_metadata.num_prefill_tokens:, :, :],
+                output[:attn_metadata.num_decode_tokens, :, :],
+                query[:attn_metadata.num_decode_tokens, :, :],
                 key_cache,
                 value_cache,
                 block_tables_arg,
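
The remaining hunks are in the CPU model runner. CPUModelRunner no longer imports TorchSDPAMetadataBuilderV1 or walks `attn_groups` to call the builder's `reorder_batch` itself; its `_may_reorder_batch` override keeps only the multiple-KV-cache-group guard and then defers to the base GPUModelRunner implementation. The classes below are toy stand-ins (not vllm classes) that illustrate the resulting guard-then-delegate shape.

```python
class _BaseRunner:
    def _may_reorder_batch(self, scheduler_output) -> None:
        # Stand-in for GPUModelRunner._may_reorder_batch: asks the attention
        # metadata builder(s) to reorder the batch as they see fit.
        print("base reorder for", scheduler_output)


class _CPURunner(_BaseRunner):
    """Toy mirror of the simplified CPU-side override."""

    def __init__(self, num_kv_cache_groups: int) -> None:
        self.num_kv_cache_groups = num_kv_cache_groups

    def _may_reorder_batch(self, scheduler_output) -> None:
        if self.num_kv_cache_groups > 1:
            raise ValueError("Multiple KV cache groups are not currently "
                             "supported with the CPU model runner.")
        super()._may_reorder_batch(scheduler_output)


_CPURunner(num_kv_cache_groups=1)._may_reorder_batch("step-0")
```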
@@ -9,7 +9,6 @@ import torch.nn as nn
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader import get_model
-from vllm.v1.attention.backends.cpu_attn import TorchSDPAMetadataBuilderV1
 from vllm.v1.utils import CpuGpuBuffer
 from vllm.v1.worker.gpu_model_runner import GPUModelRunner
 
@@ -33,50 +32,12 @@ class CPUModelRunner(GPUModelRunner):
 
         self._postprocess_tensors()
 
+    # Note: Remove the override after new attention backend finished
     def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
-        """
-        Update the order of requests in the batch based on the attention
-        backend's needs. For example, some attention backends (namely MLA) may
-        want to separate requests based on if the attention computation will be
-        compute-bound or memory-bound.
-
-        Args:
-            scheduler_output: The scheduler output.
-        """
-        # Attention free models have zero kv_cache_groups, however models
-        # like Mamba are also attention free but use the kv_cache for
-        # keeping its internal state. This is why we check the number
-        # of kv_cache groups instead of solely checking
-        # for self.model_config.is_attention_free.
-        if len(self.kv_cache_config.kv_cache_groups) == 0:
-            return
-
         if len(self.kv_cache_config.kv_cache_groups) > 1:
             raise ValueError("Multiple KVCacheGroups is not"
                              "currently supported with CPU model runner.")
+        super()._may_reorder_batch(scheduler_output)
-        # Guard against encoder-only / pooling models where `attn_groups`
-        # may be empty or lack the expected metadata_builder.
-        # Without this check, accessing `attn_groups[0][0]` would trigger
-        # an AssertionError on CPU backend.
-        if not hasattr(self, "attn_groups") or not self.attn_groups:
-            return
-        if not self.attn_groups[0]:
-            return
-
-        mb = getattr(self.attn_groups[0][0], "metadata_builders", None)
-        if isinstance(mb, list):
-            if not isinstance(mb[0], TorchSDPAMetadataBuilderV1):
-                return
-            mb[0].reorder_batch(self.input_batch, scheduler_output)
-            return
-        elif not isinstance(mb, TorchSDPAMetadataBuilderV1):
-            # Encoder-only / rerank models do not benefit from reordering,
-            # so we safely skip here.
-            return
-
-        # Safe path for decoder/attention-heavy models
-        mb.reorder_batch(self.input_batch, scheduler_output)
 
     def _postprocess_tensors(self) -> None:
         # Note: replace device tensors with cpu tensors