[Hardware][CPU] Support chunked-prefill and prefix-caching on CPU (#10355)
Signed-off-by: jiang1.li <jiang1.li@intel.com>
This commit is contained in:
parent d5b28447e0
commit 63f1fde277
@@ -25,6 +25,7 @@ docker run -itd --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/hugg

function cpu_tests() {
  set -e
  export NUMA_NODE=$2

  # offline inference
  docker exec cpu-test-avx2-"$NUMA_NODE" bash -c "
@@ -57,6 +58,12 @@ function cpu_tests() {
    pytest -s -v \
      tests/quantization/test_ipex_quant.py"

  # Run chunked-prefill and prefix-cache test
  docker exec cpu-test-"$NUMA_NODE" bash -c "
    set -e
    pytest -s -v -k cpu_model \
      tests/basic_correctness/test_chunked_prefill.py"

  # online inference
  docker exec cpu-test-"$NUMA_NODE" bash -c "
    set -e
@@ -75,4 +82,4 @@ function cpu_tests() {

# All of CPU tests are expected to be finished less than 25 mins.
export -f cpu_tests
timeout 25m bash -c "cpu_tests $CORE_RANGE"
timeout 30m bash -c "cpu_tests $CORE_RANGE $NUMA_NODE"

@@ -5,11 +5,11 @@ Installation with CPU

vLLM initially supports basic model inferencing and serving on x86 CPU platform, with data types FP32, FP16 and BF16. vLLM CPU backend supports the following vLLM features:

- Tensor Parallel (``-tp = N``)
- Quantization (``INT8 W8A8, AWQ``)

.. note::
    More advanced features on `chunked-prefill`, `prefix-caching` and `FP8 KV cache` are under development and will be available soon.
- Tensor Parallel
- Model Quantization (``INT8 W8A8, AWQ``)
- Chunked-prefill
- Prefix-caching
- FP8-E5M2 KV-Caching (TODO)
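With chunked-prefill and prefix-caching now listed for the CPU backend, both features can be switched on through the usual engine arguments. The following sketch is illustrative only and is not part of this commit; the argument names follow vLLM's EngineArgs, and the model and chunk size are arbitrary choices.

# Illustrative sketch (not part of this commit): offline inference on a CPU build of vLLM.
from vllm import LLM, SamplingParams

llm = LLM(
    model="facebook/opt-125m",       # small model, also used by the CPU tests in this commit
    dtype="bfloat16",                # the CPU chunked-prefill path currently expects bf16
    enable_chunked_prefill=True,     # split long prompts into scheduler-sized chunks
    enable_prefix_caching=True,      # reuse KV-cache blocks shared by identical prompt prefixes
    max_num_batched_tokens=32,       # upper bound on tokens scheduled per step (the chunk size)
)

outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)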

Table of contents:

@@ -344,7 +344,7 @@ Feature x Hardware
      - ✅
      - ✅
      - ✅
      - ✗
      - ✅
      - ✅
    * - :ref:`APC <apc>`
      - `✗ <https://github.com/vllm-project/vllm/issues/3687>`__
@@ -352,7 +352,7 @@ Feature x Hardware
      - ✅
      - ✅
      - ✅
      - ✗
      - ✅
      - ✅
    * - :ref:`LoRA <lora>`
      - ✅

@@ -12,6 +12,7 @@ from contextlib import nullcontext

import pytest

from tests.kernels.utils import override_backend_env_variable
from vllm.platforms import current_platform

from ..models.utils import check_logprobs_close, check_outputs_equal
from ..utils import multi_gpu_test
@@ -206,12 +207,14 @@ def test_models_with_fp8_kv_cache(
# NOTE: Increasing this in this suite will fail CI because we currently cannot
# reset distributed env properly. Use a value > 1 just when you test.
@pytest.mark.parametrize("tensor_parallel_size", [1])
@pytest.mark.parametrize("dtype", ["half"])
def test_with_prefix_caching(
    vllm_runner,
    max_tokens: int,
    enforce_eager: bool,
    chunk_size: int,
    tensor_parallel_size: int,
    dtype: str,
) -> None:
    """
    Checks exact match decode with and without prefix caching
@@ -233,7 +236,7 @@ def test_with_prefix_caching(
    for enable in (True, False):
        with vllm_runner(
                model,
                dtype="half",
                dtype=dtype,
                max_num_batched_tokens=max_num_batched_tokens,
                enable_chunked_prefill=True,
                enable_prefix_caching=enable,
@@ -260,3 +263,61 @@ def test_with_prefix_caching(
        name_0="w/o prefix caching",
        name_1="with prefix caching",
    )


@pytest.mark.parametrize("model", ["facebook/opt-125m"])
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16])
@pytest.mark.parametrize("enforce_eager", [False])
@pytest.mark.parametrize("attention_backend", ["TORCH_SDPA"])
@pytest.mark.cpu_model
@pytest.mark.skipif(not current_platform.is_cpu(), reason="CPU only")
def test_models_cpu(
    hf_runner,
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
    chunked_prefill_token_size: int,
    enforce_eager: bool,
    attention_backend: str,
    monkeypatch,
) -> None:
    test_models(
        hf_runner,
        vllm_runner,
        example_prompts,
        model,
        dtype,
        max_tokens,
        chunked_prefill_token_size,
        enforce_eager,
        1,
        attention_backend,
        monkeypatch,
    )


@pytest.mark.parametrize("max_tokens", [16])
@pytest.mark.parametrize("enforce_eager", [False])
@pytest.mark.parametrize("chunk_size", [30, 32])
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.cpu_model
@pytest.mark.skipif(not current_platform.is_cpu(), reason="CPU only")
def test_with_prefix_caching_cpu(
    vllm_runner,
    max_tokens: int,
    enforce_eager: bool,
    chunk_size: int,
    dtype: str,
) -> None:
    test_with_prefix_caching(
        vllm_runner,
        max_tokens,
        enforce_eager,
        chunk_size,
        1,
        dtype,
    )

@@ -7,18 +7,14 @@ import torch
from torch.nn.functional import scaled_dot_product_attention

from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                              AttentionMetadata, AttentionType)
                                              AttentionMetadata,
                                              AttentionMetadataBuilder,
                                              AttentionType)
from vllm.attention.backends.utils import CommonAttentionState
from vllm.attention.ops.paged_attn import PagedAttentionMetadata
from vllm.platforms import current_platform

if current_platform.is_cpu():
    try:
        from vllm.attention.ops.ipex_attn import PagedAttention
    except ImportError:
        from vllm.attention.ops.paged_attn import PagedAttention
else:
    from vllm.attention.ops.paged_attn import PagedAttention
from vllm.attention.ops.paged_attn import PagedAttentionMetadata
from vllm.utils import make_tensor_with_pad
from vllm.worker.cpu_model_runner import ModelInputForCPUBuilder


class TorchSDPABackend(AttentionBackend):
@@ -39,6 +35,10 @@ class TorchSDPABackend(AttentionBackend):
    def get_state_cls() -> Type["CommonAttentionState"]:
        return CommonAttentionState

    @staticmethod
    def get_builder_cls() -> Type["TorchSDPAMetadataBuilder"]:
        return TorchSDPAMetadataBuilder

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
@@ -71,9 +71,15 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
    """
    # Currently, input sequences can only contain all prompts
    # or all decoding. True if all sequences are prompts.
    is_prompt: bool
    slot_mapping: torch.Tensor
    seq_lens: Optional[List[int]]
    chunked_prefill: bool
    seq_lens: Optional[List[int]] = None  # For non-chunked prefill

    # For chunked prefill only
    max_query_len: Optional[int] = None
    max_kv_len: Optional[int] = None
    query_start_loc: Optional[torch.Tensor] = None
    kv_start_loc: Optional[torch.Tensor] = None
    prefill_block_tables: Optional[torch.Tensor] = None

    # Begin encoder attn & enc/dec cross-attn fields...
    # Encoder sequence lengths representation
@@ -123,20 +129,14 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):

    @property
    def prefill_metadata(self) -> Optional["TorchSDPAMetadata"]:
        # Currently chunked prefill is not supported
        if self.num_decode_tokens == 0:
            assert self.num_prefills > 0
            return self

        if self.num_prefill_tokens == 0:
            return None
        return self

    @property
    def decode_metadata(self) -> Optional["TorchSDPAMetadata"]:
        # Currently chunked prefill is not supported
        if self.num_prefills > 0:
            assert self.num_decode_tokens == 0
        if self.num_decode_tokens == 0:
            return None

        return self

    def get_seq_lens(
@@ -274,6 +274,105 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
        raise AttributeError(f"Invalid attention type {str(attn_type)}")


class TorchSDPAMetadataBuilder(AttentionMetadataBuilder[TorchSDPAMetadata]):

    def __init__(self, input_builder: ModelInputForCPUBuilder) -> None:
        self.chunked_prefill = input_builder.chunked_prefill
        self.input_data = input_builder.input_data

    def build(self, seq_lens: List[int], query_lens: List[int],
              cuda_graph_pad_size: int, batch_size: int) -> TorchSDPAMetadata:
        input_data = self.input_data
        prefill_seq_lens = seq_lens[0:input_data.num_prefills]
        prefill_query_lens = query_lens[0:input_data.num_prefills]
        slot_mapping = torch.tensor(input_data.slot_mapping,
                                    dtype=torch.long,
                                    device="cpu")

        # For chunked-prefill
        if self.chunked_prefill and input_data.num_prefill_tokens != 0:
            prefill_block_tables = make_tensor_with_pad(
                self.input_data.prefill_block_tables,
                pad=0,
                dtype=torch.int32,
                device="cpu",
            )
            query_lens_tensor = torch.tensor(prefill_query_lens,
                                             dtype=torch.int32,
                                             device="cpu")
            kv_lens_tensor = torch.tensor(prefill_seq_lens,
                                          dtype=torch.int32,
                                          device="cpu")
            query_start_loc = torch.zeros(input_data.num_prefills + 1,
                                          dtype=torch.int32,
                                          device="cpu")
            kv_start_loc = torch.zeros(input_data.num_prefills + 1,
                                       dtype=torch.int32,
                                       device="cpu")
            torch.cumsum(query_lens_tensor,
                         dim=0,
                         dtype=torch.int32,
                         out=query_start_loc[1:])
            torch.cumsum(kv_lens_tensor,
                         dim=0,
                         dtype=torch.int32,
                         out=kv_start_loc[1:])
            max_query_len = max(prefill_query_lens)
            max_kv_len = max(prefill_seq_lens)
        else:
            prefill_block_tables = None
            query_start_loc = None
            kv_start_loc = None
            max_query_len = None
            max_kv_len = None

        # For paged attention
        if input_data.num_decode_tokens != 0:
            seq_lens_tensor = torch.tensor(
                input_data.seq_lens[input_data.num_prefills:],
                dtype=torch.int32,
                device="cpu",
            )
            block_tables = make_tensor_with_pad(
                self.input_data.decode_block_tables,
                pad=0,
                dtype=torch.int32,
                device="cpu",
            )
        else:
            block_tables = torch.tensor([])
            seq_lens_tensor = torch.tensor([])

        # For multi-modal models
        placeholder_index_maps = None
        if len(input_data.multi_modal_inputs_list) != 0:
            placeholder_index_maps = {
                modality: placeholder_map.index_map()
                for modality, placeholder_map in
                input_data.multi_modal_placeholder_maps.items()
            }

        attn_metadata = TorchSDPAMetadata(
            chunked_prefill=self.chunked_prefill,
            seq_lens=prefill_seq_lens,
            seq_lens_tensor=seq_lens_tensor,
            max_query_len=max_query_len,
            max_kv_len=max_kv_len,
            query_start_loc=query_start_loc,
            kv_start_loc=kv_start_loc,
            max_decode_seq_len=input_data.max_decode_seq_len,
            num_prefills=input_data.num_prefills,
            num_prefill_tokens=input_data.num_prefill_tokens,
            num_decode_tokens=input_data.num_decode_tokens,
            block_tables=block_tables,
            prefill_block_tables=prefill_block_tables,
            slot_mapping=slot_mapping,
            multi_modal_placeholder_index_maps=placeholder_index_maps,
        )

        return attn_metadata
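The builder above converts per-sequence prefill lengths into flat start-location tensors so the varlen prefill kernel can index a packed token batch. A small worked example of how query_start_loc and kv_start_loc come out of the cumulative sum (illustrative only, not part of the diff):

# Illustrative sketch (not part of this commit).
import torch

prefill_query_lens = [4, 2, 3]   # new tokens to compute for each prefill sequence this step
prefill_seq_lens = [10, 2, 7]    # total KV length per sequence (cached prefix + new tokens)

query_start_loc = torch.zeros(len(prefill_query_lens) + 1, dtype=torch.int32)
kv_start_loc = torch.zeros(len(prefill_seq_lens) + 1, dtype=torch.int32)
query_start_loc[1:] = torch.cumsum(torch.tensor(prefill_query_lens, dtype=torch.int32), dim=0)
kv_start_loc[1:] = torch.cumsum(torch.tensor(prefill_seq_lens, dtype=torch.int32), dim=0)

# Sequence i owns query rows [query_start_loc[i], query_start_loc[i+1]) of the packed batch
# and KV positions [kv_start_loc[i], kv_start_loc[i+1]) of its padded block tables.
print(query_start_loc.tolist())  # [0, 4, 6, 9]
print(kv_start_loc.tolist())     # [0, 10, 12, 19]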


class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):

    def __init__(
@@ -409,19 +508,35 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
        assert key.shape[0] == num_prefill_tokens + num_decode_tokens
        assert value.shape[0] == num_prefill_tokens + num_decode_tokens

        output = torch.empty_like(query)
        if prefill_meta := attn_metadata.prefill_metadata:
            assert attn_metadata.seq_lens is not None
            if (kv_cache.numel() == 0
                    or prefill_meta.block_tables.numel() == 0):
                output = self._run_sdpa_forward(query,
            if not prefill_meta.prefill_metadata.chunked_prefill:  # type: ignore
                self._run_sdpa_forward(output,
                                       query,
                                       key,
                                       value,
                                       prefill_meta,
                                       attn_type=attn_type)
            else:
                # prefix-enabled attention
                raise RuntimeError(
                    "Torch SDPA backend doesn't support prefix decoding.")
                assert not self.need_mask
                import intel_extension_for_pytorch.llm.modules as ipex_modules
                output = torch.empty_like(query)
                ipex_modules.PagedAttention.flash_attn_varlen_func(
                    output[:prefill_meta.num_prefill_tokens, :, :],
                    query[:prefill_meta.num_prefill_tokens, :, :],
                    key_cache,
                    value_cache,
                    prefill_meta.query_start_loc,
                    prefill_meta.kv_start_loc,
                    prefill_meta.max_query_len,
                    prefill_meta.max_kv_len,
                    self.scale,
                    True,
                    prefill_meta.prefill_block_tables,
                    self.alibi_slopes,
                )

        if decode_meta := attn_metadata.decode_metadata:
            assert attn_type != AttentionType.ENCODER_ONLY, (
@@ -433,8 +548,9 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
                block_tables_arg,
            ) = decode_meta.get_seq_len_block_table_args(attn_type)

            output = PagedAttention.forward_decode(
                query,
            PagedAttention.forward_decode(
                output[attn_metadata.num_prefill_tokens:, :, :],
                query[attn_metadata.num_prefill_tokens:, :, :],
                key_cache,
                value_cache,
                block_tables_arg,
@@ -453,12 +569,13 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):

    def _run_sdpa_forward(
        self,
        output: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_metadata: TorchSDPAMetadata,
        attn_type: AttentionType = AttentionType.DECODER,
    ):
    ) -> None:
        if self.num_kv_heads != self.num_heads:
            key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
            value = value.repeat_interleave(self.num_queries_per_kv, dim=1)
@@ -479,7 +596,6 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
            attn_masks = [None] * len(seq_lens)
            attn_metadata.set_attn_bias(attn_masks, attn_type)

        output = torch.empty_like(query)
        query = query.movedim(0, query.dim() - 2)
        key = key.movedim(0, key.dim() - 2)
        value = value.movedim(0, value.dim() - 2)
@@ -502,7 +618,6 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
                scale=self.scale).squeeze(0).movedim(query.dim() - 2, 0)
            output[start_q:end_q, :, :] = sub_out
            start_q, start_kv = end_q, end_kv
        return output
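In the new forward path a single output tensor is allocated for the whole flattened batch, and the prefill and decode kernels each write their slice of it in place instead of returning fresh tensors. A rough shape sketch with made-up sizes (illustrative only, not from the diff):

# Illustrative sketch (not part of this commit).
import torch

num_prefill_tokens, num_decode_tokens, num_heads, head_size = 9, 3, 4, 8
query = torch.randn(num_prefill_tokens + num_decode_tokens, num_heads, head_size)
output = torch.empty_like(query)

prefill_view = output[:num_prefill_tokens]   # filled by SDPA or the varlen prefill kernel
decode_view = output[num_prefill_tokens:]    # filled in place by paged-attention decode

assert prefill_view.shape[0] + decode_view.shape[0] == query.shape[0]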


def _make_alibi_bias(

@@ -1,12 +1,17 @@
from typing import Dict, List, Optional, Tuple

try:
    import intel_extension_for_pytorch.llm.modules as ipex_modules
    _use_ipex = True
except ImportError:
    _use_ipex = False

import torch

from vllm import _custom_ops as ops


class PagedAttention:
class _PagedAttention:

    @staticmethod
    def get_supported_head_sizes() -> List[int]:
@@ -22,6 +27,105 @@ class PagedAttention:
    ) -> Tuple[int, ...]:
        return (2, num_blocks, block_size * num_kv_heads * head_size)
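The CPU KV cache produced by this shape is a single tensor whose first dimension separates keys from values. A quick worked example with made-up sizes (illustrative only, not from the diff):

# Illustrative sketch (not part of this commit).
num_blocks, block_size, num_kv_heads, head_size = 128, 16, 8, 64
kv_cache_shape = (2, num_blocks, block_size * num_kv_heads * head_size)
print(kv_cache_shape)  # (2, 128, 8192): index 0 holds key blocks, index 1 holds value blocks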

    @staticmethod
    def split_kv_cache(
        kv_cache: torch.Tensor,
        num_kv_heads: int,
        head_size: int,
        *args,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        x = 16 // kv_cache.element_size()
        num_blocks = kv_cache.shape[1]

        key_cache = kv_cache[0]
        key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x,
                                   -1, x)
        value_cache = kv_cache[1]
        value_cache = value_cache.view(num_blocks, num_kv_heads, head_size, -1)
        return key_cache, value_cache

    @staticmethod
    def write_to_paged_cache(
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        slot_mapping: torch.Tensor,
        kv_cache_dtype: str,
        k_scale: float,
        v_scale: float,
        *args,
    ) -> None:
        ops.reshape_and_cache(
            key,
            value,
            key_cache,
            value_cache,
            slot_mapping.flatten(),
            kv_cache_dtype,
            k_scale,
            v_scale,
        )

    @staticmethod
    def forward_decode(
        output: torch.Tensor,
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        block_tables: torch.Tensor,
        context_lens: torch.Tensor,
        max_context_len: int,
        kv_cache_dtype: str,
        num_kv_heads: int,
        scale: float,
        alibi_slopes: Optional[torch.Tensor],
        k_scale: float,
        v_scale: float,
        *args,
    ) -> None:
        tp_rank: int = 0
        blocksparse_local_blocks: int = 0
        blocksparse_vert_stride: int = 0
        blocksparse_block_size: int = 64
        blocksparse_head_sliding_step: int = 0
        block_size = value_cache.shape[3]

        ops.paged_attention_v1(
            output,
            query,
            key_cache,
            value_cache,
            num_kv_heads,
            scale,
            block_tables,
            context_lens,
            block_size,
            max_context_len,
            alibi_slopes,
            kv_cache_dtype,
            k_scale,
            v_scale,
            tp_rank,
            blocksparse_local_blocks,
            blocksparse_vert_stride,
            blocksparse_block_size,
            blocksparse_head_sliding_step,
        )

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: Dict[int, List[int]],
        *args,
    ) -> None:
        key_caches = [kv_cache[0] for kv_cache in kv_caches]
        value_caches = [kv_cache[1] for kv_cache in kv_caches]
        ops.copy_blocks(key_caches, value_caches, src_to_dists)


class _IPEXPagedAttention(_PagedAttention):

    @staticmethod
    def split_kv_cache(
        kv_cache: torch.Tensor,
@@ -55,6 +159,7 @@ class PagedAttention:

    @staticmethod
    def forward_decode(
        output: torch.Tensor,
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
@@ -68,8 +173,7 @@ class PagedAttention:
        k_scale: float,
        v_scale: float,
        *args,
    ) -> torch.Tensor:
        output = torch.empty_like(query)
    ) -> None:
        block_size = value_cache.shape[2]
        head_mapping = torch.arange(
            0,
@@ -83,41 +187,5 @@ class PagedAttention:
            scale, block_tables, context_lens, block_size, max_context_len,
            alibi_slopes)

        return output

    @staticmethod
    def forward_prefix(
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache_dtype: str,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        block_tables: torch.Tensor,
        subquery_start_loc: torch.Tensor,
        prompt_lens_tensor: torch.Tensor,
        context_lens: torch.Tensor,
        max_subquery_len: int,
        alibi_slopes: Optional[torch.Tensor],
        *args,
    ) -> torch.Tensor:
        raise NotImplementedError

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: Dict[int, int],
        *args,
    ) -> None:
        raise NotImplementedError

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: Dict[int, List[int]],
        *args,
    ) -> None:
        key_caches = [kv_cache[0] for kv_cache in kv_caches]
        value_caches = [kv_cache[1] for kv_cache in kv_caches]
        ops.copy_blocks(key_caches, value_caches, src_to_dists)
PagedAttention = _IPEXPagedAttention if _use_ipex else _PagedAttention

@@ -53,11 +53,6 @@ class CpuPlatform(Platform):

        cache_config = vllm_config.cache_config

        if cache_config.enable_prefix_caching:
            logger.warning(
                "Prefix caching is not supported on CPU, disable it.")
            cache_config.enable_prefix_caching = False

        kv_cache_space = envs.VLLM_CPU_KVCACHE_SPACE

        if kv_cache_space >= 0:
@@ -74,10 +69,12 @@ class CpuPlatform(Platform):
                f" {kv_cache_space}, expect a positive integer value.")

        scheduler_config = vllm_config.scheduler_config
        if scheduler_config.chunked_prefill_enabled:
            logger.warning(
                "Chunked prefill is not supported on CPU, disable it.")
            scheduler_config.chunked_prefill_enabled = False
        if ((scheduler_config.chunked_prefill_enabled
             or cache_config.enable_prefix_caching)
                and model_config.dtype == torch.half):
            logger.warning("Chunked-prefill on the CPU backend only does not"
                           " support fp16 for now, cast to bf16.")
            model_config.dtype = torch.bfloat16

        parallel_config = vllm_config.parallel_config
        if (parallel_config.distributed_executor_backend is not None
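The platform hook above no longer force-disables prefix caching and chunked prefill on CPU; it only rewrites the model dtype, since the CPU chunked-prefill and prefix-caching path currently runs in bf16. A condensed sketch of that adjustment, using hypothetical stand-in objects rather than vLLM's real config classes:

# Illustrative sketch (not part of this commit).
import torch

class _ModelConfig:          # hypothetical stand-in for vllm's ModelConfig
    dtype = torch.half

model_config = _ModelConfig()
chunked_prefill_enabled = True
enable_prefix_caching = False

# Mirrors the check above: fp16 requests are promoted to bf16 whenever either
# feature is active on the CPU backend.
if (chunked_prefill_enabled or enable_prefix_caching) and model_config.dtype == torch.half:
    model_config.dtype = torch.bfloat16

print(model_config.dtype)  # torch.bfloat16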

@@ -2,8 +2,8 @@ import dataclasses
import weakref
from collections import defaultdict
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type,
                    TypeVar, Union)
from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Type, TypeVar,
                    Union)

import torch
from torch import nn
@@ -19,7 +19,6 @@ from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
                             MultiModalKwargs, MultiModalPlaceholderMap)
from vllm.sequence import (IntermediateTensors, SequenceData,
                           SequenceGroupMetadata)
from vllm.utils import make_tensor_with_pad
from vllm.worker.model_runner_base import (
    ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
    _add_attn_metadata_broadcastable_dict,
@@ -104,65 +103,223 @@ class ModelInputForCPUWithSamplingMetadata(ModelInputForCPU):

class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):

    class ModelInputData:

        def __init__(self, use_mrope: bool):
            self.use_mrope = use_mrope
            self.input_tokens: List[int] = []
            self.input_positions: Optional[
                List[int]] = [] if not self.use_mrope else None
            self.seq_lens: List[int] = []
            self.query_lens: List[int] = []
            self.prefill_block_tables: List[List[int]] = []
            self.decode_block_tables: List[List[int]] = []
            self.max_decode_seq_len: int = 0
            self.num_prefills: int = 0
            self.num_prefill_tokens: int = 0
            self.num_decode_tokens: int = 0
            self.slot_mapping: List[int] = []
            self.multi_modal_inputs_list: List[MultiModalKwargs] = []
            self.multi_modal_placeholder_maps: Dict[
                str, MultiModalPlaceholderMap] = defaultdict(
                    MultiModalPlaceholderMap)
            self.input_mrope_positions: Optional[List[List[int]]] = [
                [] for _ in range(3)
            ] if self.use_mrope else None

    def __init__(self,
                 runner: "CPUModelRunner",
                 finished_requests_ids: Optional[List[str]] = None) -> None:
        super().__init__()
        self.seq_group_metadata_list: List[SequenceGroupMetadata] = []
        self.runner = runner

        self.chunked_prefill = (runner.scheduler_config.chunked_prefill_enabled
                                or runner.cache_config.enable_prefix_caching)
        self.model_input_cls = self.runner._model_input_cls
        self.attn_backend = self.runner.attn_backend
        self.sliding_window = self.runner.sliding_window
        self.block_size = self.runner.block_size
        self.device = self.runner.device
        self.multi_modal_input_mapper = self.runner.multi_modal_input_mapper
        self.input_data = ModelInputForCPUBuilder.ModelInputData(
            self.runner.model_config.uses_mrope)
        self.att_metadata_builder = self.runner.attn_backend.get_builder_cls()(
            self)

    def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
        self.seq_group_metadata_list.append(seq_group_metadata)

    def set_seq_group_list(
            self, seq_group_metadata_list: List[SequenceGroupMetadata]):
        self.seq_group_metadata_list = seq_group_metadata_list

    def build(self) -> ModelInputForCPU:
        self._build_input_data()

        input_data = self.input_data
        input_tokens = torch.tensor(input_data.input_tokens,
                                    dtype=torch.long,
                                    device="cpu")
        input_positions = torch.tensor(
            input_data.input_positions
            if not input_data.use_mrope else input_data.input_mrope_positions,
            dtype=torch.long,
            device="cpu")

        # For multi-modal models
        multi_modal_kwargs = None
        # NOTE: We assume that all sequences in the group are all prompts or
        # all decodes.
        is_prompt = self.seq_group_metadata_list[0].is_prompt
        # Prepare input tensors.
        if is_prompt:
            (input_tokens, input_positions, attn_metadata, seq_lens,
             multi_modal_kwargs) = self._prepare_prompt(
                 self.seq_group_metadata_list)
        else:
            (input_tokens, input_positions,
             attn_metadata) = self._prepare_decode(
                 self.seq_group_metadata_list)
            seq_lens = None
        if len(input_data.multi_modal_inputs_list) != 0:
            multi_modal_kwargs = MultiModalKwargs.batch(
                input_data.multi_modal_inputs_list)

        attn_metadata = self.att_metadata_builder.build(
            input_data.seq_lens, input_data.query_lens, -1, -1)

        return self.model_input_cls(
            input_tokens=input_tokens,
            input_positions=input_positions,
            seq_lens=input_data.seq_lens,
            query_lens=input_data.query_lens,
            attn_metadata=attn_metadata,
            multi_modal_kwargs=multi_modal_kwargs,
            # query_lens is not needed if chunked prefill is not
            # supported. Since CPU worker doesn't support chunked prefill
            # just use seq_lens instead.
            seq_lens=seq_lens,
            query_lens=seq_lens,
        )

    def _compute_multi_modal_input(
        self,
        seq_data: SequenceData,
        computed_len: int,
    def _build_input_data(self):
        for seq_group_metadata in self.seq_group_metadata_list:
            for seq_id, seq_data in seq_group_metadata.seq_data.items():
                if seq_group_metadata.is_prompt:
                    self._compute_prompt_input_tokens(self.input_data,
                                                      seq_group_metadata,
                                                      seq_data, seq_id)
                    if seq_group_metadata.multi_modal_data:
                        self._compute_multi_modal_input(
                            seq_group_metadata, seq_data)
                else:
                    self._compute_decode_input_tokens(self.input_data,
                                                      seq_group_metadata,
                                                      seq_data, seq_id)

    def _compute_decode_input_tokens(self, data: ModelInputData,
                                     seq_group_metadata: SequenceGroupMetadata,
    ):
                                     seq_data: SequenceData, seq_id: int):
        """
        Compute decode input tokens, positions, block table and slot mapping.
        """
        block_size = self.runner.block_size

        block_table = seq_group_metadata.block_tables[seq_id]
        seq_len = seq_data.get_len()
        context_len = seq_data.get_num_computed_tokens()

        tokens = seq_data.get_last_token_id()
        token_positions = seq_len - 1
        block_number = block_table[token_positions // block_size]
        block_offset = token_positions % block_size
        slot = block_number * block_size + block_offset

        # For paged_attention kernel
        if self.runner.sliding_window:
            start_idx = max(0, seq_len - self.runner.sliding_window)
            start_block = start_idx // block_size
            start_idx = start_block * block_size
            seq_len = seq_len - start_idx
            block_table = block_table[start_block:]

        # For MRotaryEmbedding
        if data.input_positions is None:
            next_pos = MRotaryEmbedding.get_next_input_positions(
                seq_data.mrope_position_delta,
                context_len,
                seq_len,
            )
            for idx in range(3):
                data.input_mrope_positions[idx].extend(  # type: ignore
                    next_pos[idx])
        else:
            data.input_positions.append(token_positions)  # type: ignore

        # Update fields
        data.input_tokens.append(tokens)
        data.max_decode_seq_len = max(data.max_decode_seq_len, seq_len)
        data.num_decode_tokens += 1
        data.slot_mapping.append(slot)
        data.decode_block_tables.append(block_table)
        data.query_lens.append(1)
        data.seq_lens.append(seq_len)
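The decode path above maps the single new token of each sequence to one physical cache slot through its block table. The same arithmetic with made-up numbers (illustrative only, not from the diff):

# Illustrative sketch (not part of this commit).
block_size = 16
block_table = [7, 3, 9]            # physical block ids assigned to this sequence
token_position = 37                # zero-based position of the token being generated

block_number = block_table[token_position // block_size]   # 37 // 16 == 2 -> block id 9
block_offset = token_position % block_size                 # 37 % 16 == 5
slot = block_number * block_size + block_offset            # 9 * 16 + 5 == 149
print(slot)  # 149: row of the paged KV cache that receives this token's K/V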

    def _compute_prompt_input_tokens(self, data: ModelInputData,
                                     seq_group_metadata: SequenceGroupMetadata,
                                     seq_data: SequenceData, seq_id: int):
        """
        Compute prompt input tokens, positions, block table and slot mapping.
        """
        token_chunk_size = seq_group_metadata.token_chunk_size
        block_size = self.runner.block_size

        block_table = seq_group_metadata.block_tables[seq_id]
        seq_len = seq_data.get_len()
        context_len = seq_data.get_num_computed_tokens()
        seq_len = min(seq_len, context_len + token_chunk_size)

        # For prefix caching
        prefix_cache_block_num = len(seq_group_metadata.computed_block_nums)
        if prefix_cache_block_num > 0:
            prefix_cache_len = (prefix_cache_block_num *
                                self.runner.block_size)
            if prefix_cache_len <= context_len:
                # We already passed the cache hit region,
                # so do normal computation.
                pass
            elif context_len < prefix_cache_len < seq_len:
                # Partial hit. Compute the missing part.
                context_len = prefix_cache_len
                token_chunk_size = seq_len - context_len
            elif seq_len <= prefix_cache_len:
                # Full hit. Only compute the last token to avoid
                # erroneous behavior. FIXME: Ideally we should directly
                # mark all tokens as computed in the scheduler and do not
                # schedule this sequence, so this case should not happen.
                context_len = seq_len - 1
                token_chunk_size = 1

        tokens = seq_data.get_token_ids()
        tokens = tokens[context_len:seq_len]
        token_positions = range(context_len, seq_len)

        # For encoder-only models, the block_table is None,
        # and there is no need to initialize the slot_mapping.
        if block_table is not None:
            slot_mapping = [_PAD_SLOT_ID] * len(token_positions)
            for i, pos in enumerate(token_positions):
                block_number = block_table[pos // block_size]
                block_offset = pos % block_size
                slot = block_number * block_size + block_offset
                slot_mapping[i] = slot
            data.slot_mapping.extend(slot_mapping)

        # The MROPE positions are prepared in _compute_multi_modal_input
        if data.input_positions is not None:
            data.input_positions.extend(token_positions)

        # Update fields
        data.input_tokens.extend(tokens)
        data.num_prefills += 1
        data.num_prefill_tokens += len(tokens)
        data.query_lens.append(len(tokens))
        data.prefill_block_tables.append(block_table)
        data.seq_lens.append(seq_len)
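The prefix-caching branch above trims the scheduled chunk so that tokens already covered by cached blocks are skipped, while a fully cached prompt still recomputes its last token. Plugging in made-up numbers (illustrative only, not from the diff):

# Illustrative sketch (not part of this commit).
block_size = 16
seq_len = 40                  # 40-token prompt scheduled in a single chunk
context_len = 0
token_chunk_size = 40
computed_block_nums = [0, 1]  # two prompt blocks already present in the prefix cache

prefix_cache_len = len(computed_block_nums) * block_size    # 32 cached tokens

if prefix_cache_len <= context_len:
    pass                                       # cache hit already consumed in an earlier chunk
elif context_len < prefix_cache_len < seq_len:
    context_len = prefix_cache_len             # partial hit: skip the 32 cached tokens
    token_chunk_size = seq_len - context_len   # ...and compute only the remaining 8
elif seq_len <= prefix_cache_len:
    context_len = seq_len - 1                  # full hit: recompute just the last token
    token_chunk_size = 1

print(context_len, token_chunk_size)  # 32 8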

    def _compute_multi_modal_input(self,
                                   seq_group_metadata: SequenceGroupMetadata,
                                   seq_data: SequenceData):
        computed_len = seq_data.get_num_computed_tokens()
        seq_len = self.input_data.seq_lens[-1]

        # NOTE: mm_data only includes the subset of multi-modal items that
        # intersect with the current prefill positions.
        mm_data, placeholder_maps = MultiModalPlaceholderMap.from_seq_group(
            seq_group_metadata,
            range(computed_len, len(seq_data.get_token_ids())),
        )
            seq_group_metadata, range(computed_len, seq_len))

        if not mm_data:
            return None, None, None
            return

        if self.runner.mm_registry.has_processor(self.runner.model_config):
            mm_kwargs = mm_data
@@ -173,8 +330,10 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
        )

        # special processing for mrope position deltas.
        mrope_positions = None
        if self.runner.model_config.uses_mrope:
            assert not self.chunked_prefill, \
                "MROPE on CPU does not support chunked-prefill."

            image_grid_thw = mm_kwargs.get("image_grid_thw", None)
            video_grid_thw = mm_kwargs.get("video_grid_thw", None)
            assert image_grid_thw is not None or video_grid_thw is not None, (
@@ -198,227 +357,16 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
                context_len=computed_len,
            )
            seq_data.mrope_position_delta = mrope_position_delta
        return mm_kwargs, placeholder_maps, mrope_positions

    def _prepare_prompt(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int],
               BatchedTensorInputs]:
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[int] = []
        input_positions: List[int] = []
        input_mrope_positions: List[List[int]] = [[] for _ in range(3)]
            for i in range(3):
                self.input_data.input_mrope_positions[  # type: ignore
                    i].extend(mrope_positions[i])

        slot_mapping: List[int] = []
        seq_lens: List[int] = []
        multi_modal_kwargs_list: List[MultiModalKwargs] = []
        multi_modal_placeholder_maps: Dict[
            str,
            MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)

        for seq_group_metadata in seq_group_metadata_list:
            assert seq_group_metadata.is_prompt
            seq_ids = list(seq_group_metadata.seq_data.keys())
            assert len(seq_ids) == 1
            seq_id = seq_ids[0]

            seq_data = seq_group_metadata.seq_data[seq_id]
            prompt_tokens = seq_data.get_token_ids()
            computed_len = seq_data.get_num_computed_tokens()
            seq_len = len(prompt_tokens)

            seq_lens.append(seq_len)  # Prompt token num
            input_tokens.extend(prompt_tokens)  # Token ids

            mrope_positions = None
            if seq_group_metadata.multi_modal_data:
                (
                    mm_kwargs,
                    placeholder_maps,
                    mrope_positions,
                ) = self._compute_multi_modal_input(seq_data, computed_len,
                                                    seq_group_metadata)

                multi_modal_kwargs_list.append(mm_kwargs)
        self.input_data.multi_modal_inputs_list.append(mm_kwargs)
                for modality, placeholder_map in placeholder_maps.items():
                    multi_modal_placeholder_maps[modality].extend(
            self.input_data.multi_modal_placeholder_maps[modality].extend(
                        placeholder_map)

            # Token position ids
            # NOTE(woosuk): Here we assume that the first token in the prompt
            # is always the first token in the sequence.
            if mrope_positions:
                for idx in range(3):
                    input_mrope_positions[idx].extend(mrope_positions[idx])
            else:
                input_positions.extend(list(range(computed_len, seq_len)))

            # Compute the slot mapping.
            block_table = seq_group_metadata.block_tables[seq_id]
            # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
            # where start_idx is max(0, seq_len - sliding_window).
            # For example, if the prompt len is 10, sliding window is 8, and
            # block size is 4, the first two tokens are masked and the slot
            # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
            start_idx = 0
            if self.sliding_window is not None:
                start_idx = max(0, seq_len - self.sliding_window)

            for i in range(computed_len, seq_len):
                if i < start_idx:
                    slot_mapping.append(_PAD_SLOT_ID)
                    continue

                # For encoder-only models, the block_table is None,
                # and there is no need to initialize the slot_mapping.
                if block_table is not None:
                    block_number = block_table[i //
                                               self.block_size]  # type: ignore
                    block_offset = i % self.block_size  # type: ignore
                    slot = block_number * self.block_size + block_offset
                    slot_mapping.append(slot)

        if any(input_mrope_positions):
            input_positions = None  # type: ignore
        else:
            input_mrope_positions = None  # type: ignore

        num_prompt_tokens = len(input_tokens)

        input_tokens = torch.tensor(input_tokens,
                                    dtype=torch.long,
                                    device=self.device)  # type: ignore
        input_positions = torch.tensor(input_positions
                                       or input_mrope_positions,
                                       dtype=torch.long,
                                       device=self.device)  # type: ignore
        slot_mapping = torch.tensor(slot_mapping,
                                    dtype=torch.long,
                                    device=self.device)  # type: ignore
        placeholder_index_maps = {
            modality: placeholder_map.index_map()
            for modality, placeholder_map in
            multi_modal_placeholder_maps.items()
        }

        attn_metadata = self.attn_backend.make_metadata(
            is_prompt=True,
            seq_lens=seq_lens,
            seq_lens_tensor=torch.tensor([]),
            max_decode_seq_len=0,
            num_prefills=len(seq_lens),
            num_prefill_tokens=num_prompt_tokens,
            num_decode_tokens=0,
            block_tables=torch.tensor([]),
            slot_mapping=slot_mapping,
            multi_modal_placeholder_index_maps=placeholder_index_maps,
        )

        multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list)

        return (input_tokens, input_positions, attn_metadata, seq_lens,
                multi_modal_kwargs)

    def _prepare_decode(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata]:
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[int] = []
        input_positions: List[int] = []
        input_mrope_positions: List[List[int]] = [[] for _ in range(3)]
        slot_mapping: List[int] = []
        seq_lens: List[int] = []
        block_tables: List[List[int]] = []

        for seq_group_metadata in seq_group_metadata_list:
            assert not seq_group_metadata.is_prompt
            assert seq_group_metadata.token_chunk_size == 1

            seq_ids = list(seq_group_metadata.seq_data.keys())

            for seq_id in seq_ids:
                seq_data = seq_group_metadata.seq_data[seq_id]
                generation_token = seq_data.get_last_token_id()
                input_tokens.append(generation_token)

                seq_len = seq_data.get_len()
                position = seq_len - 1
                if seq_data.mrope_position_delta is not None:
                    context_len = seq_data.get_num_computed_tokens()
                    next_pos = MRotaryEmbedding.get_next_input_positions(
                        seq_data.mrope_position_delta,
                        context_len,
                        seq_len,
                    )
                    for idx in range(3):
                        input_mrope_positions[idx].extend(next_pos[idx])
                else:
                    input_positions.append(position)

                seq_len = seq_len if self.sliding_window is None else min(
                    seq_len, self.sliding_window)
                seq_lens.append(seq_len)

                block_table = seq_group_metadata.block_tables[seq_id]
                block_number = block_table[position // self.block_size]
                block_offset = position % self.block_size
                slot = block_number * self.block_size + block_offset
                slot_mapping.append(slot)

                if self.sliding_window is not None:
                    sliding_window_blocks = (self.sliding_window //
                                             self.block_size)
                    block_table = block_table[-sliding_window_blocks:]
                block_tables.append(block_table)

        if any(input_mrope_positions):
            input_positions = None  # type: ignore
        else:
            input_mrope_positions = None  # type: ignore

        max_decode_seq_len = max(seq_lens)

        input_tokens = torch.tensor(input_tokens,
                                    dtype=torch.long,
                                    device=self.device)
        input_positions = torch.tensor(input_positions
                                       or input_mrope_positions,
                                       dtype=torch.long,
                                       device=self.device)
        slot_mapping = torch.tensor(slot_mapping,
                                    dtype=torch.long,
                                    device=self.device)
        seq_lens_tensor = torch.tensor(seq_lens,
                                       dtype=torch.int,
                                       device=self.device)

        block_tables = make_tensor_with_pad(
            block_tables,
            pad=0,
            dtype=torch.int,
            device=self.device,
        )

        attn_metadata = self.attn_backend.make_metadata(
            is_prompt=False,
            slot_mapping=slot_mapping,
            multi_modal_placeholder_index_maps=None,
            seq_lens=seq_lens,
            seq_lens_tensor=seq_lens_tensor,
            max_decode_seq_len=max_decode_seq_len,
            num_prefill_tokens=0,
            num_decode_tokens=len(input_tokens),
            num_prefills=0,
            block_tables=block_tables,
        )
        return (
            input_tokens,
            input_positions,
            attn_metadata,
        )


class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]):
    """
@@ -436,8 +384,6 @@ class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]):
        **kwargs,
    ):
        ModelRunnerBase.__init__(self, vllm_config)
        # Currently, CPU worker doesn't support chunked prefill.
        assert self.scheduler_config.chunked_prefill_enabled is False
        model_config = self.model_config
        cache_config = self.cache_config

@@ -479,8 +425,7 @@ class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]):

        """
        builder = self._builder_cls(weakref.proxy(self), finished_requests_ids)
        for seq_group_metadata in seq_group_metadata_list:
            builder.add_seq_group(seq_group_metadata)
        builder.set_seq_group_list(seq_group_metadata_list)

        return builder.build()  # type: ignore

@@ -537,22 +482,19 @@ class CPUModelRunner(CPUModelRunnerBase[ModelInputForCPUWithSamplingMetadata]):
                "CPU worker does not support multi-step execution.")

        model_executable = self.model
        execute_model_kwargs = {
            "input_ids":
            model_input.input_tokens,
            "positions":
            model_input.input_positions,
            "kv_caches":
            kv_caches,
            "attn_metadata":
            model_input.attn_metadata,
            **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {},
                                         device=self.device),
            "intermediate_tensors":
            intermediate_tensors,
        }
        multimodal_kwargs = {}
        if model_input.multi_modal_kwargs is not None:
            multimodal_kwargs = MultiModalKwargs.as_kwargs(
                model_input.multi_modal_kwargs, device=self.device)

        hidden_states = model_executable(**execute_model_kwargs)
        hidden_states = model_executable(
            input_ids=model_input.input_tokens,
            positions=model_input.input_positions,
            kv_caches=kv_caches,
            attn_metadata=model_input.attn_metadata,
            intermediate_tensors=intermediate_tensors,
            **multimodal_kwargs,
        )

        # Compute the logits.
        logits = self.model.compute_logits(hidden_states,