[Hardware][CPU] Support chunked-prefill and prefix-caching on CPU (#10355)

Signed-off-by: jiang1.li <jiang1.li@intel.com>
This commit is contained in:
Li, Jiang 2024-11-20 18:57:39 +08:00 committed by GitHub
parent d5b28447e0
commit 63f1fde277
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 558 additions and 368 deletions

View File

@ -25,6 +25,7 @@ docker run -itd --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/hugg
function cpu_tests() {
set -e
export NUMA_NODE=$2
# offline inference
docker exec cpu-test-avx2-"$NUMA_NODE" bash -c "
@ -57,6 +58,12 @@ function cpu_tests() {
pytest -s -v \
tests/quantization/test_ipex_quant.py"
# Run chunked-prefill and prefix-cache test
docker exec cpu-test-"$NUMA_NODE" bash -c "
set -e
pytest -s -v -k cpu_model \
tests/basic_correctness/test_chunked_prefill.py"
# online inference
docker exec cpu-test-"$NUMA_NODE" bash -c "
set -e
@ -75,4 +82,4 @@ function cpu_tests() {
# All of CPU tests are expected to be finished less than 25 mins.
export -f cpu_tests
timeout 25m bash -c "cpu_tests $CORE_RANGE"
timeout 30m bash -c "cpu_tests $CORE_RANGE $NUMA_NODE"

View File

@ -5,11 +5,11 @@ Installation with CPU
vLLM initially supports basic model inferencing and serving on x86 CPU platform, with data types FP32, FP16 and BF16. vLLM CPU backend supports the following vLLM features:
- Tensor Parallel (``-tp = N``)
- Quantization (``INT8 W8A8, AWQ``)
.. note::
More advanced features on `chunked-prefill`, `prefix-caching` and `FP8 KV cache` are under development and will be available soon.
- Tensor Parallel
- Model Quantization (``INT8 W8A8, AWQ``)
- Chunked-prefill
- Prefix-caching
- FP8-E5M2 KV-Caching (TODO)
Table of contents:

View File

@ -344,7 +344,7 @@ Feature x Hardware
- ✅
- ✅
- ✅
-
-
- ✅
* - :ref:`APC <apc>`
- `✗ <https://github.com/vllm-project/vllm/issues/3687>`__
@ -352,7 +352,7 @@ Feature x Hardware
- ✅
- ✅
- ✅
-
-
- ✅
* - :ref:`LoRA <lora>`
- ✅

View File

@ -12,6 +12,7 @@ from contextlib import nullcontext
import pytest
from tests.kernels.utils import override_backend_env_variable
from vllm.platforms import current_platform
from ..models.utils import check_logprobs_close, check_outputs_equal
from ..utils import multi_gpu_test
@ -206,12 +207,14 @@ def test_models_with_fp8_kv_cache(
# NOTE: Increasing this in this suite will fail CI because we currently cannot
# reset distributed env properly. Use a value > 1 just when you test.
@pytest.mark.parametrize("tensor_parallel_size", [1])
@pytest.mark.parametrize("dtype", ["half"])
def test_with_prefix_caching(
vllm_runner,
max_tokens: int,
enforce_eager: bool,
chunk_size: int,
tensor_parallel_size: int,
dtype: str,
) -> None:
"""
Checks exact match decode with and without prefix caching
@ -233,7 +236,7 @@ def test_with_prefix_caching(
for enable in (True, False):
with vllm_runner(
model,
dtype="half",
dtype=dtype,
max_num_batched_tokens=max_num_batched_tokens,
enable_chunked_prefill=True,
enable_prefix_caching=enable,
@ -260,3 +263,61 @@ def test_with_prefix_caching(
name_0="w/o prefix caching",
name_1="with prefix caching",
)
@pytest.mark.parametrize("model", ["facebook/opt-125m"])
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16])
@pytest.mark.parametrize("enforce_eager", [False])
@pytest.mark.parametrize("attention_backend", ["TORCH_SDPA"])
@pytest.mark.cpu_model
@pytest.mark.skipif(not current_platform.is_cpu(), reason="CPU only")
def test_models_cpu(
hf_runner,
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
chunked_prefill_token_size: int,
enforce_eager: bool,
attention_backend: str,
monkeypatch,
) -> None:
test_models(
hf_runner,
vllm_runner,
example_prompts,
model,
dtype,
max_tokens,
chunked_prefill_token_size,
enforce_eager,
1,
attention_backend,
monkeypatch,
)
@pytest.mark.parametrize("max_tokens", [16])
@pytest.mark.parametrize("enforce_eager", [False])
@pytest.mark.parametrize("chunk_size", [30, 32])
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.cpu_model
@pytest.mark.skipif(not current_platform.is_cpu(), reason="CPU only")
def test_with_prefix_caching_cpu(
vllm_runner,
max_tokens: int,
enforce_eager: bool,
chunk_size: int,
dtype: str,
) -> None:
test_with_prefix_caching(
vllm_runner,
max_tokens,
enforce_eager,
chunk_size,
1,
dtype,
)

View File

@ -7,18 +7,14 @@ import torch
from torch.nn.functional import scaled_dot_product_attention
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType)
AttentionMetadata,
AttentionMetadataBuilder,
AttentionType)
from vllm.attention.backends.utils import CommonAttentionState
from vllm.attention.ops.ipex_attn import PagedAttention
from vllm.attention.ops.paged_attn import PagedAttentionMetadata
from vllm.platforms import current_platform
if current_platform.is_cpu():
try:
from vllm.attention.ops.ipex_attn import PagedAttention
except ImportError:
from vllm.attention.ops.paged_attn import PagedAttention
else:
from vllm.attention.ops.paged_attn import PagedAttention
from vllm.utils import make_tensor_with_pad
from vllm.worker.cpu_model_runner import ModelInputForCPUBuilder
class TorchSDPABackend(AttentionBackend):
@ -39,6 +35,10 @@ class TorchSDPABackend(AttentionBackend):
def get_state_cls() -> Type["CommonAttentionState"]:
return CommonAttentionState
@staticmethod
def get_builder_cls() -> Type["TorchSDPAMetadataBuilder"]:
return TorchSDPAMetadataBuilder
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
@ -71,9 +71,15 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
"""
# Currently, input sequences can only contain all prompts
# or all decoding. True if all sequences are prompts.
is_prompt: bool
slot_mapping: torch.Tensor
seq_lens: Optional[List[int]]
chunked_prefill: bool
seq_lens: Optional[List[int]] = None # For non-chunked prefill
# For chunked prefill only
max_query_len: Optional[int] = None
max_kv_len: Optional[int] = None
query_start_loc: Optional[torch.Tensor] = None
kv_start_loc: Optional[torch.Tensor] = None
prefill_block_tables: Optional[torch.Tensor] = None
# Begin encoder attn & enc/dec cross-attn fields...
# Encoder sequence lengths representation
@ -123,20 +129,14 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
@property
def prefill_metadata(self) -> Optional["TorchSDPAMetadata"]:
# Currently chunked prefill is not supported
if self.num_decode_tokens == 0:
assert self.num_prefills > 0
return self
return None
if self.num_prefill_tokens == 0:
return None
return self
@property
def decode_metadata(self) -> Optional["TorchSDPAMetadata"]:
# Currently chunked prefill is not supported
if self.num_prefills > 0:
assert self.num_decode_tokens == 0
if self.num_decode_tokens == 0:
return None
return self
def get_seq_lens(
@ -274,6 +274,105 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
raise AttributeError(f"Invalid attention type {str(attn_type)}")
class TorchSDPAMetadataBuilder(AttentionMetadataBuilder[TorchSDPAMetadata]):
def __init__(self, input_builder: ModelInputForCPUBuilder) -> None:
self.chunked_prefill = input_builder.chunked_prefill
self.input_data = input_builder.input_data
def build(self, seq_lens: List[int], query_lens: List[int],
cuda_graph_pad_size: int, batch_size: int) -> TorchSDPAMetadata:
input_data = self.input_data
prefill_seq_lens = seq_lens[0:input_data.num_prefills]
prefill_query_lens = query_lens[0:input_data.num_prefills]
slot_mapping = torch.tensor(input_data.slot_mapping,
dtype=torch.long,
device="cpu")
# For chunked-prefill
if self.chunked_prefill and input_data.num_prefill_tokens != 0:
prefill_block_tables = make_tensor_with_pad(
self.input_data.prefill_block_tables,
pad=0,
dtype=torch.int32,
device="cpu",
)
query_lens_tensor = torch.tensor(prefill_query_lens,
dtype=torch.int32,
device="cpu")
kv_lens_tensor = torch.tensor(prefill_seq_lens,
dtype=torch.int32,
device="cpu")
query_start_loc = torch.zeros(input_data.num_prefills + 1,
dtype=torch.int32,
device="cpu")
kv_start_loc = torch.zeros(input_data.num_prefills + 1,
dtype=torch.int32,
device="cpu")
torch.cumsum(query_lens_tensor,
dim=0,
dtype=torch.int32,
out=query_start_loc[1:])
torch.cumsum(kv_lens_tensor,
dim=0,
dtype=torch.int32,
out=kv_start_loc[1:])
max_query_len = max(prefill_query_lens)
max_kv_len = max(prefill_seq_lens)
else:
prefill_block_tables = None
query_start_loc = None
kv_start_loc = None
max_query_len = None
max_kv_len = None
# For paged attention
if input_data.num_decode_tokens != 0:
seq_lens_tensor = torch.tensor(
input_data.seq_lens[input_data.num_prefills:],
dtype=torch.int32,
device="cpu",
)
block_tables = make_tensor_with_pad(
self.input_data.decode_block_tables,
pad=0,
dtype=torch.int32,
device="cpu",
)
else:
block_tables = torch.tensor([])
seq_lens_tensor = torch.tensor([])
# For multi-modal models
placeholder_index_maps = None
if len(input_data.multi_modal_inputs_list) != 0:
placeholder_index_maps = {
modality: placeholder_map.index_map()
for modality, placeholder_map in
input_data.multi_modal_placeholder_maps.items()
}
attn_metadata = TorchSDPAMetadata(
chunked_prefill=self.chunked_prefill,
seq_lens=prefill_seq_lens,
seq_lens_tensor=seq_lens_tensor,
max_query_len=max_query_len,
max_kv_len=max_kv_len,
query_start_loc=query_start_loc,
kv_start_loc=kv_start_loc,
max_decode_seq_len=input_data.max_decode_seq_len,
num_prefills=input_data.num_prefills,
num_prefill_tokens=input_data.num_prefill_tokens,
num_decode_tokens=input_data.num_decode_tokens,
block_tables=block_tables,
prefill_block_tables=prefill_block_tables,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=placeholder_index_maps,
)
return attn_metadata
class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
def __init__(
@ -409,19 +508,35 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
assert key.shape[0] == num_prefill_tokens + num_decode_tokens
assert value.shape[0] == num_prefill_tokens + num_decode_tokens
output = torch.empty_like(query)
if prefill_meta := attn_metadata.prefill_metadata:
assert attn_metadata.seq_lens is not None
if (kv_cache.numel() == 0
or prefill_meta.block_tables.numel() == 0):
output = self._run_sdpa_forward(query,
key,
value,
prefill_meta,
attn_type=attn_type)
if not prefill_meta.prefill_metadata.chunked_prefill: # type: ignore
self._run_sdpa_forward(output,
query,
key,
value,
prefill_meta,
attn_type=attn_type)
else:
# prefix-enabled attention
raise RuntimeError(
"Torch SDPA backend doesn't support prefix decoding.")
assert not self.need_mask
import intel_extension_for_pytorch.llm.modules as ipex_modules
output = torch.empty_like(query)
ipex_modules.PagedAttention.flash_attn_varlen_func(
output[:prefill_meta.num_prefill_tokens, :, :],
query[:prefill_meta.num_prefill_tokens, :, :],
key_cache,
value_cache,
prefill_meta.query_start_loc,
prefill_meta.kv_start_loc,
prefill_meta.max_query_len,
prefill_meta.max_kv_len,
self.scale,
True,
prefill_meta.prefill_block_tables,
self.alibi_slopes,
)
if decode_meta := attn_metadata.decode_metadata:
assert attn_type != AttentionType.ENCODER_ONLY, (
@ -433,8 +548,9 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
block_tables_arg,
) = decode_meta.get_seq_len_block_table_args(attn_type)
output = PagedAttention.forward_decode(
query,
PagedAttention.forward_decode(
output[attn_metadata.num_prefill_tokens:, :, :],
query[attn_metadata.num_prefill_tokens:, :, :],
key_cache,
value_cache,
block_tables_arg,
@ -453,12 +569,13 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
def _run_sdpa_forward(
self,
output: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_metadata: TorchSDPAMetadata,
attn_type: AttentionType = AttentionType.DECODER,
):
) -> None:
if self.num_kv_heads != self.num_heads:
key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
value = value.repeat_interleave(self.num_queries_per_kv, dim=1)
@ -479,7 +596,6 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
attn_masks = [None] * len(seq_lens)
attn_metadata.set_attn_bias(attn_masks, attn_type)
output = torch.empty_like(query)
query = query.movedim(0, query.dim() - 2)
key = key.movedim(0, key.dim() - 2)
value = value.movedim(0, value.dim() - 2)
@ -502,7 +618,6 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
scale=self.scale).squeeze(0).movedim(query.dim() - 2, 0)
output[start_q:end_q, :, :] = sub_out
start_q, start_kv = end_q, end_kv
return output
def _make_alibi_bias(

View File

@ -1,12 +1,17 @@
from typing import Dict, List, Optional, Tuple
import intel_extension_for_pytorch.llm.modules as ipex_modules
try:
import intel_extension_for_pytorch.llm.modules as ipex_modules
_use_ipex = True
except ImportError:
_use_ipex = False
import torch
from vllm import _custom_ops as ops
class PagedAttention:
class _PagedAttention:
@staticmethod
def get_supported_head_sizes() -> List[int]:
@ -22,6 +27,105 @@ class PagedAttention:
) -> Tuple[int, ...]:
return (2, num_blocks, block_size * num_kv_heads * head_size)
@staticmethod
def split_kv_cache(
kv_cache: torch.Tensor,
num_kv_heads: int,
head_size: int,
*args,
) -> Tuple[torch.Tensor, torch.Tensor]:
x = 16 // kv_cache.element_size()
num_blocks = kv_cache.shape[1]
key_cache = kv_cache[0]
key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x,
-1, x)
value_cache = kv_cache[1]
value_cache = value_cache.view(num_blocks, num_kv_heads, head_size, -1)
return key_cache, value_cache
@staticmethod
def write_to_paged_cache(
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
slot_mapping: torch.Tensor,
kv_cache_dtype: str,
k_scale: float,
v_scale: float,
*args,
) -> None:
ops.reshape_and_cache(
key,
value,
key_cache,
value_cache,
slot_mapping.flatten(),
kv_cache_dtype,
k_scale,
v_scale,
)
@staticmethod
def forward_decode(
output: torch.Tensor,
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
block_tables: torch.Tensor,
context_lens: torch.Tensor,
max_context_len: int,
kv_cache_dtype: str,
num_kv_heads: int,
scale: float,
alibi_slopes: Optional[torch.Tensor],
k_scale: float,
v_scale: float,
*args,
) -> None:
tp_rank: int = 0
blocksparse_local_blocks: int = 0
blocksparse_vert_stride: int = 0
blocksparse_block_size: int = 64
blocksparse_head_sliding_step: int = 0
block_size = value_cache.shape[3]
ops.paged_attention_v1(
output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
context_lens,
block_size,
max_context_len,
alibi_slopes,
kv_cache_dtype,
k_scale,
v_scale,
tp_rank,
blocksparse_local_blocks,
blocksparse_vert_stride,
blocksparse_block_size,
blocksparse_head_sliding_step,
)
@staticmethod
def copy_blocks(
kv_caches: List[torch.Tensor],
src_to_dists: Dict[int, List[int]],
*args,
) -> None:
key_caches = [kv_cache[0] for kv_cache in kv_caches]
value_caches = [kv_cache[1] for kv_cache in kv_caches]
ops.copy_blocks(key_caches, value_caches, src_to_dists)
class _IPEXPagedAttention(_PagedAttention):
@staticmethod
def split_kv_cache(
kv_cache: torch.Tensor,
@ -55,6 +159,7 @@ class PagedAttention:
@staticmethod
def forward_decode(
output: torch.Tensor,
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
@ -68,8 +173,7 @@ class PagedAttention:
k_scale: float,
v_scale: float,
*args,
) -> torch.Tensor:
output = torch.empty_like(query)
) -> None:
block_size = value_cache.shape[2]
head_mapping = torch.arange(
0,
@ -83,41 +187,5 @@ class PagedAttention:
scale, block_tables, context_lens, block_size, max_context_len,
alibi_slopes)
return output
@staticmethod
def forward_prefix(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache_dtype: str,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
block_tables: torch.Tensor,
subquery_start_loc: torch.Tensor,
prompt_lens_tensor: torch.Tensor,
context_lens: torch.Tensor,
max_subquery_len: int,
alibi_slopes: Optional[torch.Tensor],
*args,
) -> torch.Tensor:
raise NotImplementedError
@staticmethod
def swap_blocks(
src_kv_cache: torch.Tensor,
dst_kv_cache: torch.Tensor,
src_to_dst: Dict[int, int],
*args,
) -> None:
raise NotImplementedError
@staticmethod
def copy_blocks(
kv_caches: List[torch.Tensor],
src_to_dists: Dict[int, List[int]],
*args,
) -> None:
key_caches = [kv_cache[0] for kv_cache in kv_caches]
value_caches = [kv_cache[1] for kv_cache in kv_caches]
ops.copy_blocks(key_caches, value_caches, src_to_dists)
PagedAttention = _IPEXPagedAttention if _use_ipex else _PagedAttention

View File

@ -53,11 +53,6 @@ class CpuPlatform(Platform):
cache_config = vllm_config.cache_config
if cache_config.enable_prefix_caching:
logger.warning(
"Prefix caching is not supported on CPU, disable it.")
cache_config.enable_prefix_caching = False
kv_cache_space = envs.VLLM_CPU_KVCACHE_SPACE
if kv_cache_space >= 0:
@ -74,10 +69,12 @@ class CpuPlatform(Platform):
f" {kv_cache_space}, expect a positive integer value.")
scheduler_config = vllm_config.scheduler_config
if scheduler_config.chunked_prefill_enabled:
logger.warning(
"Chunked prefill is not supported on CPU, disable it.")
scheduler_config.chunked_prefill_enabled = False
if ((scheduler_config.chunked_prefill_enabled
or cache_config.enable_prefix_caching)
and model_config.dtype == torch.half):
logger.warning("Chunked-prefill on the CPU backend only does not"
" support fp16 for now, cast to bf16.")
model_config.dtype = torch.bfloat16
parallel_config = vllm_config.parallel_config
if (parallel_config.distributed_executor_backend is not None

View File

@ -2,8 +2,8 @@ import dataclasses
import weakref
from collections import defaultdict
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type,
TypeVar, Union)
from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Type, TypeVar,
Union)
import torch
from torch import nn
@ -19,7 +19,6 @@ from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
MultiModalKwargs, MultiModalPlaceholderMap)
from vllm.sequence import (IntermediateTensors, SequenceData,
SequenceGroupMetadata)
from vllm.utils import make_tensor_with_pad
from vllm.worker.model_runner_base import (
ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
_add_attn_metadata_broadcastable_dict,
@ -104,65 +103,223 @@ class ModelInputForCPUWithSamplingMetadata(ModelInputForCPU):
class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
class ModelInputData:
def __init__(self, use_mrope: bool):
self.use_mrope = use_mrope
self.input_tokens: List[int] = []
self.input_positions: Optional[
List[int]] = [] if not self.use_mrope else None
self.seq_lens: List[int] = []
self.query_lens: List[int] = []
self.prefill_block_tables: List[List[int]] = []
self.decode_block_tables: List[List[int]] = []
self.max_decode_seq_len: int = 0
self.num_prefills: int = 0
self.num_prefill_tokens: int = 0
self.num_decode_tokens: int = 0
self.slot_mapping: List[int] = []
self.multi_modal_inputs_list: List[MultiModalKwargs] = []
self.multi_modal_placeholder_maps: Dict[
str, MultiModalPlaceholderMap] = defaultdict(
MultiModalPlaceholderMap)
self.input_mrope_positions: Optional[List[List[int]]] = [
[] for _ in range(3)
] if self.use_mrope else None
def __init__(self,
runner: "CPUModelRunner",
finished_requests_ids: Optional[List[str]] = None) -> None:
super().__init__()
self.seq_group_metadata_list: List[SequenceGroupMetadata] = []
self.runner = runner
self.chunked_prefill = (runner.scheduler_config.chunked_prefill_enabled
or runner.cache_config.enable_prefix_caching)
self.model_input_cls = self.runner._model_input_cls
self.attn_backend = self.runner.attn_backend
self.sliding_window = self.runner.sliding_window
self.block_size = self.runner.block_size
self.device = self.runner.device
self.multi_modal_input_mapper = self.runner.multi_modal_input_mapper
self.input_data = ModelInputForCPUBuilder.ModelInputData(
self.runner.model_config.uses_mrope)
self.att_metadata_builder = self.runner.attn_backend.get_builder_cls()(
self)
def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
self.seq_group_metadata_list.append(seq_group_metadata)
def set_seq_group_list(
self, seq_group_metadata_list: List[SequenceGroupMetadata]):
self.seq_group_metadata_list = seq_group_metadata_list
def build(self) -> ModelInputForCPU:
self._build_input_data()
input_data = self.input_data
input_tokens = torch.tensor(input_data.input_tokens,
dtype=torch.long,
device="cpu")
input_positions = torch.tensor(
input_data.input_positions
if not input_data.use_mrope else input_data.input_mrope_positions,
dtype=torch.long,
device="cpu")
# For multi-modal models
multi_modal_kwargs = None
# NOTE: We assume that all sequences in the group are all prompts or
# all decodes.
is_prompt = self.seq_group_metadata_list[0].is_prompt
# Prepare input tensors.
if is_prompt:
(input_tokens, input_positions, attn_metadata, seq_lens,
multi_modal_kwargs) = self._prepare_prompt(
self.seq_group_metadata_list)
else:
(input_tokens, input_positions,
attn_metadata) = self._prepare_decode(
self.seq_group_metadata_list)
seq_lens = None
if len(input_data.multi_modal_inputs_list) != 0:
multi_modal_kwargs = MultiModalKwargs.batch(
input_data.multi_modal_inputs_list)
attn_metadata = self.att_metadata_builder.build(
input_data.seq_lens, input_data.query_lens, -1, -1)
return self.model_input_cls(
input_tokens=input_tokens,
input_positions=input_positions,
seq_lens=input_data.seq_lens,
query_lens=input_data.query_lens,
attn_metadata=attn_metadata,
multi_modal_kwargs=multi_modal_kwargs,
# query_lens is not needed if chunked prefill is not
# supported. Since CPU worker doesn't support chunked prefill
# just use seq_lens instead.
seq_lens=seq_lens,
query_lens=seq_lens,
)
def _compute_multi_modal_input(
self,
seq_data: SequenceData,
computed_len: int,
seq_group_metadata: SequenceGroupMetadata,
):
def _build_input_data(self):
for seq_group_metadata in self.seq_group_metadata_list:
for seq_id, seq_data in seq_group_metadata.seq_data.items():
if seq_group_metadata.is_prompt:
self._compute_prompt_input_tokens(self.input_data,
seq_group_metadata,
seq_data, seq_id)
if seq_group_metadata.multi_modal_data:
self._compute_multi_modal_input(
seq_group_metadata, seq_data)
else:
self._compute_decode_input_tokens(self.input_data,
seq_group_metadata,
seq_data, seq_id)
def _compute_decode_input_tokens(self, data: ModelInputData,
seq_group_metadata: SequenceGroupMetadata,
seq_data: SequenceData, seq_id: int):
"""
Compute decode input tokens, positions, block table and slot mapping.
"""
block_size = self.runner.block_size
block_table = seq_group_metadata.block_tables[seq_id]
seq_len = seq_data.get_len()
context_len = seq_data.get_num_computed_tokens()
tokens = seq_data.get_last_token_id()
token_positions = seq_len - 1
block_number = block_table[token_positions // block_size]
block_offset = token_positions % block_size
slot = block_number * block_size + block_offset
# For paged_attention kernel
if self.runner.sliding_window:
start_idx = max(0, seq_len - self.runner.sliding_window)
start_block = start_idx // block_size
start_idx = start_block * block_size
seq_len = seq_len - start_idx
block_table = block_table[start_block:]
# For MRotaryEmbedding
if data.input_positions is None:
next_pos = MRotaryEmbedding.get_next_input_positions(
seq_data.mrope_position_delta,
context_len,
seq_len,
)
for idx in range(3):
data.input_mrope_positions[idx].extend( # type: ignore
next_pos[idx])
else:
data.input_positions.append(token_positions) # type: ignore
# Update fields
data.input_tokens.append(tokens)
data.max_decode_seq_len = max(data.max_decode_seq_len, seq_len)
data.num_decode_tokens += 1
data.slot_mapping.append(slot)
data.decode_block_tables.append(block_table)
data.query_lens.append(1)
data.seq_lens.append(seq_len)
def _compute_prompt_input_tokens(self, data: ModelInputData,
seq_group_metadata: SequenceGroupMetadata,
seq_data: SequenceData, seq_id: int):
"""
Compute prompt input tokens, positions, block table and slot mapping.
"""
token_chunk_size = seq_group_metadata.token_chunk_size
block_size = self.runner.block_size
block_table = seq_group_metadata.block_tables[seq_id]
seq_len = seq_data.get_len()
context_len = seq_data.get_num_computed_tokens()
seq_len = min(seq_len, context_len + token_chunk_size)
# For prefix caching
prefix_cache_block_num = len(seq_group_metadata.computed_block_nums)
if prefix_cache_block_num > 0:
prefix_cache_len = (prefix_cache_block_num *
self.runner.block_size)
if prefix_cache_len <= context_len:
# We already passed the cache hit region,
# so do normal computation.
pass
elif context_len < prefix_cache_len < seq_len:
# Partial hit. Compute the missing part.
context_len = prefix_cache_len
token_chunk_size = seq_len - context_len
elif seq_len <= prefix_cache_len:
# Full hit. Only compute the last token to avoid
# erroneous behavior. FIXME: Ideally we should directly
# mark all tokens as computed in the scheduler and do not
# schedule this sequence, so this case should not happen.
context_len = seq_len - 1
token_chunk_size = 1
tokens = seq_data.get_token_ids()
tokens = tokens[context_len:seq_len]
token_positions = range(context_len, seq_len)
# For encoder-only models, the block_table is None,
# and there is no need to initialize the slot_mapping.
if block_table is not None:
slot_mapping = [_PAD_SLOT_ID] * len(token_positions)
for i, pos in enumerate(token_positions):
block_number = block_table[pos // block_size]
block_offset = pos % block_size
slot = block_number * block_size + block_offset
slot_mapping[i] = slot
data.slot_mapping.extend(slot_mapping)
# The MROPE positions are prepared in _compute_multi_modal_input
if data.input_positions is not None:
data.input_positions.extend(token_positions)
# Update fields
data.input_tokens.extend(tokens)
data.num_prefills += 1
data.num_prefill_tokens += len(tokens)
data.query_lens.append(len(tokens))
data.prefill_block_tables.append(block_table)
data.seq_lens.append(seq_len)
def _compute_multi_modal_input(self,
seq_group_metadata: SequenceGroupMetadata,
seq_data: SequenceData):
computed_len = seq_data.get_num_computed_tokens()
seq_len = self.input_data.seq_lens[-1]
# NOTE: mm_data only includes the subset of multi-modal items that
# intersect with the current prefill positions.
mm_data, placeholder_maps = MultiModalPlaceholderMap.from_seq_group(
seq_group_metadata,
range(computed_len, len(seq_data.get_token_ids())),
)
seq_group_metadata, range(computed_len, seq_len))
if not mm_data:
return None, None, None
return
if self.runner.mm_registry.has_processor(self.runner.model_config):
mm_kwargs = mm_data
@ -173,8 +330,10 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
)
# special processing for mrope position deltas.
mrope_positions = None
if self.runner.model_config.uses_mrope:
assert not self.chunked_prefill, \
"MROPE on CPU does not support chunked-prefill."
image_grid_thw = mm_kwargs.get("image_grid_thw", None)
video_grid_thw = mm_kwargs.get("video_grid_thw", None)
assert image_grid_thw is not None or video_grid_thw is not None, (
@ -198,226 +357,15 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
context_len=computed_len,
)
seq_data.mrope_position_delta = mrope_position_delta
return mm_kwargs, placeholder_maps, mrope_positions
def _prepare_prompt(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int],
BatchedTensorInputs]:
assert len(seq_group_metadata_list) > 0
input_tokens: List[int] = []
input_positions: List[int] = []
input_mrope_positions: List[List[int]] = [[] for _ in range(3)]
for i in range(3):
self.input_data.input_mrope_positions[ # type: ignore
i].extend(mrope_positions[i])
slot_mapping: List[int] = []
seq_lens: List[int] = []
multi_modal_kwargs_list: List[MultiModalKwargs] = []
multi_modal_placeholder_maps: Dict[
str,
MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)
for seq_group_metadata in seq_group_metadata_list:
assert seq_group_metadata.is_prompt
seq_ids = list(seq_group_metadata.seq_data.keys())
assert len(seq_ids) == 1
seq_id = seq_ids[0]
seq_data = seq_group_metadata.seq_data[seq_id]
prompt_tokens = seq_data.get_token_ids()
computed_len = seq_data.get_num_computed_tokens()
seq_len = len(prompt_tokens)
seq_lens.append(seq_len) # Prompt token num
input_tokens.extend(prompt_tokens) # Token ids
mrope_positions = None
if seq_group_metadata.multi_modal_data:
(
mm_kwargs,
placeholder_maps,
mrope_positions,
) = self._compute_multi_modal_input(seq_data, computed_len,
seq_group_metadata)
multi_modal_kwargs_list.append(mm_kwargs)
for modality, placeholder_map in placeholder_maps.items():
multi_modal_placeholder_maps[modality].extend(
placeholder_map)
# Token position ids
# NOTE(woosuk): Here we assume that the first token in the prompt
# is always the first token in the sequence.
if mrope_positions:
for idx in range(3):
input_mrope_positions[idx].extend(mrope_positions[idx])
else:
input_positions.extend(list(range(computed_len, seq_len)))
# Compute the slot mapping.
block_table = seq_group_metadata.block_tables[seq_id]
# Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
# where start_idx is max(0, seq_len - sliding_window).
# For example, if the prompt len is 10, sliding window is 8, and
# block size is 4, the first two tokens are masked and the slot
# mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
start_idx = 0
if self.sliding_window is not None:
start_idx = max(0, seq_len - self.sliding_window)
for i in range(computed_len, seq_len):
if i < start_idx:
slot_mapping.append(_PAD_SLOT_ID)
continue
# For encoder-only models, the block_table is None,
# and there is no need to initialize the slot_mapping.
if block_table is not None:
block_number = block_table[i //
self.block_size] # type: ignore
block_offset = i % self.block_size # type: ignore
slot = block_number * self.block_size + block_offset
slot_mapping.append(slot)
if any(input_mrope_positions):
input_positions = None # type: ignore
else:
input_mrope_positions = None # type: ignore
num_prompt_tokens = len(input_tokens)
input_tokens = torch.tensor(input_tokens,
dtype=torch.long,
device=self.device) # type: ignore
input_positions = torch.tensor(input_positions
or input_mrope_positions,
dtype=torch.long,
device=self.device) # type: ignore
slot_mapping = torch.tensor(slot_mapping,
dtype=torch.long,
device=self.device) # type: ignore
placeholder_index_maps = {
modality: placeholder_map.index_map()
for modality, placeholder_map in
multi_modal_placeholder_maps.items()
}
attn_metadata = self.attn_backend.make_metadata(
is_prompt=True,
seq_lens=seq_lens,
seq_lens_tensor=torch.tensor([]),
max_decode_seq_len=0,
num_prefills=len(seq_lens),
num_prefill_tokens=num_prompt_tokens,
num_decode_tokens=0,
block_tables=torch.tensor([]),
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=placeholder_index_maps,
)
multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list)
return (input_tokens, input_positions, attn_metadata, seq_lens,
multi_modal_kwargs)
def _prepare_decode(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata]:
assert len(seq_group_metadata_list) > 0
input_tokens: List[int] = []
input_positions: List[int] = []
input_mrope_positions: List[List[int]] = [[] for _ in range(3)]
slot_mapping: List[int] = []
seq_lens: List[int] = []
block_tables: List[List[int]] = []
for seq_group_metadata in seq_group_metadata_list:
assert not seq_group_metadata.is_prompt
assert seq_group_metadata.token_chunk_size == 1
seq_ids = list(seq_group_metadata.seq_data.keys())
for seq_id in seq_ids:
seq_data = seq_group_metadata.seq_data[seq_id]
generation_token = seq_data.get_last_token_id()
input_tokens.append(generation_token)
seq_len = seq_data.get_len()
position = seq_len - 1
if seq_data.mrope_position_delta is not None:
context_len = seq_data.get_num_computed_tokens()
next_pos = MRotaryEmbedding.get_next_input_positions(
seq_data.mrope_position_delta,
context_len,
seq_len,
)
for idx in range(3):
input_mrope_positions[idx].extend(next_pos[idx])
else:
input_positions.append(position)
seq_len = seq_len if self.sliding_window is None else min(
seq_len, self.sliding_window)
seq_lens.append(seq_len)
block_table = seq_group_metadata.block_tables[seq_id]
block_number = block_table[position // self.block_size]
block_offset = position % self.block_size
slot = block_number * self.block_size + block_offset
slot_mapping.append(slot)
if self.sliding_window is not None:
sliding_window_blocks = (self.sliding_window //
self.block_size)
block_table = block_table[-sliding_window_blocks:]
block_tables.append(block_table)
if any(input_mrope_positions):
input_positions = None # type: ignore
else:
input_mrope_positions = None # type: ignore
max_decode_seq_len = max(seq_lens)
input_tokens = torch.tensor(input_tokens,
dtype=torch.long,
device=self.device)
input_positions = torch.tensor(input_positions
or input_mrope_positions,
dtype=torch.long,
device=self.device)
slot_mapping = torch.tensor(slot_mapping,
dtype=torch.long,
device=self.device)
seq_lens_tensor = torch.tensor(seq_lens,
dtype=torch.int,
device=self.device)
block_tables = make_tensor_with_pad(
block_tables,
pad=0,
dtype=torch.int,
device=self.device,
)
attn_metadata = self.attn_backend.make_metadata(
is_prompt=False,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor,
max_decode_seq_len=max_decode_seq_len,
num_prefill_tokens=0,
num_decode_tokens=len(input_tokens),
num_prefills=0,
block_tables=block_tables,
)
return (
input_tokens,
input_positions,
attn_metadata,
)
self.input_data.multi_modal_inputs_list.append(mm_kwargs)
for modality, placeholder_map in placeholder_maps.items():
self.input_data.multi_modal_placeholder_maps[modality].extend(
placeholder_map)
class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]):
@ -436,8 +384,6 @@ class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]):
**kwargs,
):
ModelRunnerBase.__init__(self, vllm_config)
# Currently, CPU worker doesn't support chunked prefill.
assert self.scheduler_config.chunked_prefill_enabled is False
model_config = self.model_config
cache_config = self.cache_config
@ -479,8 +425,7 @@ class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]):
"""
builder = self._builder_cls(weakref.proxy(self), finished_requests_ids)
for seq_group_metadata in seq_group_metadata_list:
builder.add_seq_group(seq_group_metadata)
builder.set_seq_group_list(seq_group_metadata_list)
return builder.build() # type: ignore
@ -537,22 +482,19 @@ class CPUModelRunner(CPUModelRunnerBase[ModelInputForCPUWithSamplingMetadata]):
"CPU worker does not support multi-step execution.")
model_executable = self.model
execute_model_kwargs = {
"input_ids":
model_input.input_tokens,
"positions":
model_input.input_positions,
"kv_caches":
kv_caches,
"attn_metadata":
model_input.attn_metadata,
**MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {},
device=self.device),
"intermediate_tensors":
intermediate_tensors,
}
multimodal_kwargs = {}
if model_input.multi_modal_kwargs is not None:
multimodal_kwargs = MultiModalKwargs.as_kwargs(
model_input.multi_modal_kwargs, device=self.device)
hidden_states = model_executable(**execute_model_kwargs)
hidden_states = model_executable(
input_ids=model_input.input_tokens,
positions=model_input.input_positions,
kv_caches=kv_caches,
attn_metadata=model_input.attn_metadata,
intermediate_tensors=intermediate_tensors,
**multimodal_kwargs,
)
# Compute the logits.
logits = self.model.compute_logits(hidden_states,