add mixin

Signed-off-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
Chen Zhang 2025-09-17 22:02:15 -07:00
parent 4ac510f484
commit 808fa43d76
2 changed files with 386 additions and 323 deletions

View File

@ -7,7 +7,6 @@ import time
from collections import defaultdict
from collections.abc import Iterator
from contextlib import contextmanager
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Optional, Union, cast
import numpy as np
@ -27,9 +26,7 @@ from vllm.compilation.monitor import set_cudagraph_capturing_enabled
from vllm.config import (CompilationLevel, CUDAGraphMode, VllmConfig,
get_layers_from_vllm_config, update_config)
from vllm.distributed.eplb.eplb_state import EplbState
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
has_kv_transfer_group)
from vllm.distributed.kv_transfer.kv_connector.utils import copy_kv_blocks
from vllm.distributed.kv_transfer import has_kv_transfer_group
from vllm.distributed.parallel_state import (
get_pp_group, get_tp_group, graph_capture, is_global_first_rank,
prepare_communication_buffer_for_model)
@ -54,7 +51,7 @@ from vllm.sampling_params import SamplingType
from vllm.sequence import IntermediateTensors, PoolerOutput
from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
GiB_bytes, LazyLoader, check_use_alibi, get_dtype_size,
GiB_bytes, LazyLoader, check_use_alibi,
is_pin_memory_available, round_up, supports_dynamo)
from vllm.v1.attention.backends.flash_attn import AttentionMetadata
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
@ -70,8 +67,8 @@ from vllm.v1.kv_cache_interface import (AttentionSpec,
CrossAttentionSpec,
EncoderOnlyAttentionSpec,
FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec, KVCacheSpec,
MambaSpec, SlidingWindowSpec)
KVCacheSpec, MambaSpec,
SlidingWindowSpec)
# yapf: enable
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
DraftTokenIds, LogprobsLists, LogprobsTensors,
@ -88,6 +85,7 @@ from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.utils import CpuGpuBuffer, record_function_or_nullcontext
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.gpu_ubatch_wrapper import UBatchWrapper
from vllm.v1.worker.kv_cache_initializer_mixin import KVCacheInitializerMixin
from vllm.v1.worker.kv_connector_model_runner_mixin import (
KVConnectorModelRunnerMixin)
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
@ -95,10 +93,8 @@ from vllm.v1.worker.ubatch_splitting import get_dp_padding_ubatch, ubatch_split
from vllm.v1.worker.ubatch_utils import UBatchSlice, UBatchSlices
from vllm.v1.worker.utils import is_residual_scattered_for_sp
from .utils import (AttentionGroup, MultiModalBudget,
add_kv_sharing_layers_to_kv_cache_groups, bind_kv_cache,
gather_mm_placeholders, sanity_check_mm_encoder_outputs,
scatter_mm_placeholders)
from .utils import (AttentionGroup, MultiModalBudget, gather_mm_placeholders,
sanity_check_mm_encoder_outputs, scatter_mm_placeholders)
if TYPE_CHECKING:
import xgrammar as xgr
@ -163,7 +159,8 @@ class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput):
return output
class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
class GPUModelRunner(KVCacheInitializerMixin, LoRAModelRunnerMixin,
KVConnectorModelRunnerMixin):
def __init__(
self,
@ -255,7 +252,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.kv_caches: list[torch.Tensor] = []
# indexes: [kv_cache_group_id][attn_group]
self.attn_groups: list[list[AttentionGroup]] = []
# self.kv_cache_config: KVCacheConfig
# a fake value to satisfy the type checker
self.kv_cache_config: KVCacheConfig = cast(KVCacheConfig, None)
# mm_hash -> encoder_output
self.encoder_cache: dict[str, torch.Tensor] = {}
@ -3529,319 +3527,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
else:
self.reorder_batch_threshold = reorder_batch_threshold_i
def may_reinitialize_input_batch(self,
kv_cache_config: KVCacheConfig) -> None:
"""
Re-initialize the input batch if the block sizes are different from
`[self.cache_config.block_size]`. This usually happens when there
are multiple KV cache groups.
Args:
kv_cache_config: The KV cache configuration.
"""
block_sizes = [
kv_cache_group.kv_cache_spec.block_size
for kv_cache_group in kv_cache_config.kv_cache_groups
]
if block_sizes != [self.cache_config.block_size]:
assert self.cache_config.cpu_offload_gb == 0, (
"Cannot re-initialize the input batch when CPU weight "
"offloading is enabled. See https://github.com/vllm-project/vllm/pull/18298 " # noqa: E501
"for more details.")
self.input_batch = InputBatch(
max_num_reqs=self.max_num_reqs,
max_model_len=max(self.max_model_len, self.max_encoder_len),
max_num_batched_tokens=self.max_num_tokens,
device=self.device,
pin_memory=self.pin_memory,
vocab_size=self.model_config.get_vocab_size(),
block_sizes=block_sizes,
is_spec_decode=bool(self.vllm_config.speculative_config),
logitsprocs=self.input_batch.logitsprocs,
is_pooling_model=self.is_pooling_model,
num_speculative_tokens=(
self.vllm_config.speculative_config.num_speculative_tokens
if self.vllm_config.speculative_config else 0),
)
def _allocate_kv_cache_tensors(
self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]:
"""
Initializes the KV cache buffer with the correct size. The buffer needs
to be reshaped to the desired shape before being used by the models.
Args:
kv_cache_config: The KV cache config
Returns:
dict[str, torch.Tensor]: A map between layer names to their
corresponding memory buffer for KV cache.
"""
kv_cache_raw_tensors: dict[str, torch.Tensor] = {}
for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
tensor = torch.zeros(kv_cache_tensor.size,
dtype=torch.int8,
device=self.device)
for layer_name in kv_cache_tensor.shared_by:
kv_cache_raw_tensors[layer_name] = tensor
layer_names = set()
for group in kv_cache_config.kv_cache_groups:
for layer_name in group.layer_names:
if layer_name in self.runner_only_attn_layers:
continue
layer_names.add(layer_name)
assert layer_names == set(kv_cache_raw_tensors.keys(
)), "Some layers are not correctly initialized"
return kv_cache_raw_tensors
def _attn_group_iterator(self) -> Iterator[AttentionGroup]:
return itertools.chain.from_iterable(self.attn_groups)
def _kv_cache_spec_attn_group_iterator(
self) -> Iterator[tuple[KVCacheSpec, AttentionGroup]]:
if not self.kv_cache_config.kv_cache_groups:
return
for kv_cache_spec_id, attn_groups in enumerate(self.attn_groups):
for attn_group in attn_groups:
yield self.kv_cache_config.kv_cache_groups[
kv_cache_spec_id].kv_cache_spec, attn_group
def _reshape_kv_cache_tensors(
self,
kv_cache_config: KVCacheConfig,
kv_cache_raw_tensors: dict[str, torch.Tensor],
) -> dict[str, torch.Tensor]:
"""
Reshape the KV cache tensors to the desired shape and dtype.
Args:
kv_cache_config: The KV cache config
kv_cache_raw_tensors: The KV cache buffer of each layer, with
correct size but uninitialized shape.
Returns:
Dict[str, torch.Tensor]: A map between layer names to their
corresponding memory buffer for KV cache.
"""
kv_caches: dict[str, torch.Tensor] = {}
has_attn, has_mamba = False, False
for kv_cache_spec, group in self._kv_cache_spec_attn_group_iterator():
attn_backend = group.backend
for layer_name in group.layer_names:
if layer_name in self.runner_only_attn_layers:
continue
raw_tensor = kv_cache_raw_tensors[layer_name]
assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0
num_blocks = (raw_tensor.numel() //
kv_cache_spec.page_size_bytes)
if isinstance(kv_cache_spec, AttentionSpec):
has_attn = True
kv_cache_shape = attn_backend.get_kv_cache_shape(
num_blocks, kv_cache_spec.block_size,
kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
dtype = kv_cache_spec.dtype
try:
kv_cache_stride_order = \
attn_backend.get_kv_cache_stride_order()
assert len(kv_cache_stride_order) == len(
kv_cache_shape)
except (AttributeError, NotImplementedError):
kv_cache_stride_order = tuple(
range(len(kv_cache_shape)))
# The allocation respects the backend-defined stride order
# to ensure the semantic remains consistent for each
# backend. We first obtain the generic kv cache shape and
# then permute it according to the stride order which could
# result in a non-contiguous tensor.
kv_cache_shape = tuple(kv_cache_shape[i]
for i in kv_cache_stride_order)
# Maintain original KV shape view.
inv_order = [
kv_cache_stride_order.index(i)
for i in range(len(kv_cache_stride_order))
]
kv_caches[layer_name] = kv_cache_raw_tensors[
layer_name].view(dtype).view(kv_cache_shape).permute(
*inv_order)
elif isinstance(kv_cache_spec, MambaSpec):
has_mamba = True
raw_tensor = kv_cache_raw_tensors[layer_name]
state_tensors = []
storage_offset_bytes = 0
for (shape, dtype) in zip(kv_cache_spec.shapes,
kv_cache_spec.dtypes):
dtype_size = get_dtype_size(dtype)
num_element_per_page = (
kv_cache_spec.page_size_bytes // dtype_size)
target_shape = (num_blocks, *shape)
stride = torch.empty(target_shape).stride()
target_stride = (num_element_per_page, *stride[1:])
assert storage_offset_bytes % dtype_size == 0
tensor = torch.as_strided(
raw_tensor.view(dtype),
size=target_shape,
stride=target_stride,
storage_offset=storage_offset_bytes // dtype_size,
)
state_tensors.append(tensor)
storage_offset_bytes += stride[0] * dtype_size
kv_caches[layer_name] = state_tensors
else:
raise NotImplementedError
if has_attn and has_mamba:
self._update_hybrid_attention_mamba_layout(kv_caches)
return kv_caches
def _update_hybrid_attention_mamba_layout(
self, kv_caches: dict[str, torch.Tensor]) -> None:
"""
Update the layout of attention layers from (2, num_blocks, ...) to
(num_blocks, 2, ...).
Args:
kv_caches: The KV cache buffer of each layer.
"""
for kv_cache_spec, group in self._kv_cache_spec_attn_group_iterator():
for layer_name in group.layer_names:
kv_cache = kv_caches[layer_name]
if (isinstance(kv_cache_spec, AttentionSpec)
and kv_cache.shape[0] == 2):
assert kv_cache.shape[1] != 2, \
"Fail to determine whether the layout is " \
"(2, num_blocks, ...) or (num_blocks, 2, ...) for " \
f"a tensor of shape {kv_cache.shape}"
hidden_size = kv_cache.shape[2:].numel()
kv_cache.as_strided_(size=kv_cache.shape,
stride=(hidden_size, 2 * hidden_size,
*kv_cache.stride()[2:]))
def initialize_kv_cache_tensors(
self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]:
"""
Initialize the memory buffer for KV cache.
Args:
kv_cache_config: The KV cache config
Returns:
Dict[str, torch.Tensor]: A map between layer names to their
corresponding memory buffer for KV cache.
"""
# Initialize the memory buffer for KV cache
kv_cache_raw_tensors = self._allocate_kv_cache_tensors(kv_cache_config)
# Change the memory buffer to the desired shape
kv_caches = self._reshape_kv_cache_tensors(kv_cache_config,
kv_cache_raw_tensors)
# Set up cross-layer KV cache sharing
for layer_name, target_layer_name in self.shared_kv_cache_layers.items(
):
logger.debug("%s reuses KV cache of %s", layer_name,
target_layer_name)
kv_caches[layer_name] = kv_caches[target_layer_name]
bind_kv_cache(kv_caches,
self.compilation_config.static_forward_context,
self.kv_caches)
return kv_caches
def maybe_add_kv_sharing_layers_to_kv_cache_groups(
self, kv_cache_config: KVCacheConfig) -> None:
"""
Add layers that re-use KV cache to KV cache group of its target layer.
Mapping of KV cache tensors happens in `initialize_kv_cache_tensors()`
"""
if not self.shared_kv_cache_layers:
# No cross-layer KV sharing, return
return
add_kv_sharing_layers_to_kv_cache_groups(
self.shared_kv_cache_layers,
kv_cache_config.kv_cache_groups,
self.runner_only_attn_layers,
)
if self.cache_config.kv_sharing_fast_prefill:
# In You Only Cache Once (https://arxiv.org/abs/2405.05254) or other
# similar KV sharing setups, only the layers that generate KV caches
# are involved in the prefill phase, enabling prefill to early exit.
attn_layers = get_layers_from_vllm_config(self.vllm_config,
Attention)
for layer_name in reversed(attn_layers):
if layer_name in self.shared_kv_cache_layers:
self.kv_sharing_fast_prefill_eligible_layers.add(
layer_name)
else:
break
def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
"""
Initialize KV cache based on `kv_cache_config`.
Args:
kv_cache_config: Configuration for the KV cache, including the KV
cache size of each layer
"""
kv_cache_config = deepcopy(kv_cache_config)
self.kv_cache_config = kv_cache_config
self.may_reinitialize_input_batch(kv_cache_config)
self.may_add_encoder_only_layers_to_kv_cache_config()
self.maybe_add_kv_sharing_layers_to_kv_cache_groups(kv_cache_config)
self.initialize_attn_backend(kv_cache_config)
kv_caches = self.initialize_kv_cache_tensors(kv_cache_config)
if self.speculative_config and self.speculative_config.use_eagle():
assert isinstance(self.drafter, EagleProposer)
# validate all draft model layers belong to the same kv cache
# group
self.drafter.validate_same_kv_cache_group(kv_cache_config)
if has_kv_transfer_group():
get_kv_transfer_group().register_kv_caches(kv_caches)
if self.device.type == 'xpu':
get_kv_transfer_group().set_host_xfer_buffer_ops(
copy_kv_blocks)
if self.dcp_world_size > 1:
layer_names = self.attn_groups[0][0].layer_names
layers = get_layers_from_vllm_config(self.vllm_config,
AttentionLayerBase,
layer_names)
for layer in layers.values():
assert layer.impl.need_to_return_lse_for_decode, (
"DCP requires attention impls to return"
" the softmax lse for decode, but the impl "
f"{layer.impl.__class__.__name__} "
"does not return the softmax lse for decode.")
def may_add_encoder_only_layers_to_kv_cache_config(self) -> None:
"""
Add encoder-only layers to the KV cache config.
"""
block_size = self.vllm_config.cache_config.block_size
use_mla = self.vllm_config.model_config.use_mla
encoder_only_attn_specs: dict[AttentionSpec,
list[str]] = defaultdict(list)
attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention)
for layer_name, attn_module in attn_layers.items():
if attn_module.attn_type == AttentionType.ENCODER_ONLY:
attn_spec: AttentionSpec = EncoderOnlyAttentionSpec(
block_size=block_size,
num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size,
dtype=self.kv_cache_dtype,
use_mla=use_mla)
encoder_only_attn_specs[attn_spec].append(layer_name)
self.runner_only_attn_layers.add(layer_name)
if len(encoder_only_attn_specs) > 0:
assert len(
encoder_only_attn_specs
) == 1, "Only support one encoder-only attention spec now"
spec, layer_names = encoder_only_attn_specs.popitem()
self.kv_cache_config.kv_cache_groups.append(
KVCacheGroupSpec(layer_names=layer_names, kv_cache_spec=spec))
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
"""
Generates the KVCacheSpec by parsing the kv cache format from each

View File

@ -0,0 +1,375 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections import defaultdict
from collections.abc import Iterator
from copy import deepcopy
from typing import Any, Protocol, cast
import torch
from vllm.attention import Attention, AttentionType
from vllm.config import get_layers_from_vllm_config
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
has_kv_transfer_group)
from vllm.distributed.kv_transfer.kv_connector.utils import copy_kv_blocks
from vllm.logger import init_logger
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.utils import get_dtype_size
from vllm.v1.kv_cache_interface import (AttentionSpec,
EncoderOnlyAttentionSpec,
KVCacheConfig, KVCacheGroupSpec,
KVCacheSpec, MambaSpec)
from vllm.v1.spec_decode.eagle import EagleProposer
from vllm.v1.worker.gpu_input_batch import InputBatch
from .utils import (AttentionGroup, add_kv_sharing_layers_to_kv_cache_groups,
bind_kv_cache)
class _KVCacheInitializerSelf(Protocol):
cache_config: Any
max_num_reqs: int
max_model_len: int
max_encoder_len: int
max_num_tokens: int
device: Any
pin_memory: bool
model_config: Any
vllm_config: Any
input_batch: InputBatch
is_pooling_model: bool
shared_kv_cache_layers: dict[str, str]
kv_sharing_fast_prefill_eligible_layers: set[str]
runner_only_attn_layers: set[str]
kv_cache_dtype: torch.dtype
kv_cache_config: KVCacheConfig
compilation_config: Any
kv_caches: Any
speculative_config: Any
drafter: Any
dcp_world_size: int
attn_groups: list[list[AttentionGroup]]
def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:
...
logger = init_logger(__name__)
# Defined as a mixin for GPUModelRunner
class KVCacheInitializerMixin:
def _runner(self) -> _KVCacheInitializerSelf:
return cast(_KVCacheInitializerSelf, self)
def may_reinitialize_input_batch(self,
kv_cache_config: KVCacheConfig) -> None:
"""
Re-initialize the input batch if the block sizes are different from
`[self.cache_config.block_size]`. This usually happens when there
are multiple KV cache groups.
Args:
kv_cache_config: The KV cache configuration.
"""
runner = self._runner()
block_sizes = [
kv_cache_group.kv_cache_spec.block_size
for kv_cache_group in kv_cache_config.kv_cache_groups
]
if block_sizes != [runner.cache_config.block_size]:
assert runner.cache_config.cpu_offload_gb == 0, (
"Cannot re-initialize the input batch when CPU weight "
"offloading is enabled. See https://github.com/vllm-project/vllm/pull/18298 " # noqa: E501
"for more details.")
runner.input_batch = InputBatch(
max_num_reqs=runner.max_num_reqs,
max_model_len=max(runner.max_model_len,
runner.max_encoder_len),
max_num_batched_tokens=runner.max_num_tokens,
device=runner.device,
pin_memory=runner.pin_memory,
vocab_size=runner.model_config.get_vocab_size(),
block_sizes=block_sizes,
is_spec_decode=bool(runner.vllm_config.speculative_config),
logitsprocs=runner.input_batch.logitsprocs,
is_pooling_model=runner.is_pooling_model,
num_speculative_tokens=(runner.vllm_config.speculative_config.
num_speculative_tokens if
runner.vllm_config.speculative_config
else 0),
)
def _allocate_kv_cache_tensors(
self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]:
"""
Initializes the KV cache buffer with the correct size. The buffer needs
to be reshaped to the desired shape before being used by the models.
Args:
kv_cache_config: The KV cache config
Returns:
dict[str, torch.Tensor]: A map between layer names to their
corresponding memory buffer for KV cache.
"""
runner = self._runner()
kv_cache_raw_tensors: dict[str, torch.Tensor] = {}
for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
tensor = torch.zeros(kv_cache_tensor.size,
dtype=torch.int8,
device=runner.device)
for layer_name in kv_cache_tensor.shared_by:
kv_cache_raw_tensors[layer_name] = tensor
layer_names = set()
for group in kv_cache_config.kv_cache_groups:
for layer_name in group.layer_names:
if layer_name in runner.runner_only_attn_layers:
continue
layer_names.add(layer_name)
assert layer_names == set(kv_cache_raw_tensors.keys(
)), "Some layers are not correctly initialized"
return kv_cache_raw_tensors
def _kv_cache_spec_attn_group_iterator(
self) -> Iterator[tuple[KVCacheSpec, AttentionGroup]]:
runner = self._runner()
if not runner.kv_cache_config.kv_cache_groups:
return
for kv_cache_spec_id, attn_groups in enumerate(runner.attn_groups):
for attn_group in attn_groups:
yield runner.kv_cache_config.kv_cache_groups[
kv_cache_spec_id].kv_cache_spec, attn_group
def _reshape_kv_cache_tensors(
self,
kv_cache_config: KVCacheConfig,
kv_cache_raw_tensors: dict[str, torch.Tensor],
) -> dict[str, torch.Tensor]:
"""
Reshape the KV cache tensors to the desired shape and dtype.
Args:
kv_cache_config: The KV cache config
kv_cache_raw_tensors: The KV cache buffer of each layer, with
correct size but uninitialized shape.
Returns:
Dict[str, torch.Tensor]: A map between layer names to their
corresponding memory buffer for KV cache.
"""
runner = self._runner()
kv_caches: dict[str, torch.Tensor] = {}
has_attn, has_mamba = False, False
for kv_cache_spec, group in self._kv_cache_spec_attn_group_iterator():
attn_backend = group.backend
for layer_name in group.layer_names:
if layer_name in runner.runner_only_attn_layers:
continue
raw_tensor = kv_cache_raw_tensors[layer_name]
assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0
num_blocks = (raw_tensor.numel() //
kv_cache_spec.page_size_bytes)
if isinstance(kv_cache_spec, AttentionSpec):
has_attn = True
kv_cache_shape = attn_backend.get_kv_cache_shape(
num_blocks, kv_cache_spec.block_size,
kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
dtype = kv_cache_spec.dtype
try:
kv_cache_stride_order = \
attn_backend.get_kv_cache_stride_order()
assert len(kv_cache_stride_order) == len(
kv_cache_shape)
except (AttributeError, NotImplementedError):
kv_cache_stride_order = tuple(
range(len(kv_cache_shape)))
kv_cache_shape = tuple(kv_cache_shape[i]
for i in kv_cache_stride_order)
inv_order = [
kv_cache_stride_order.index(i)
for i in range(len(kv_cache_stride_order))
]
kv_caches[layer_name] = kv_cache_raw_tensors[
layer_name].view(dtype).view(kv_cache_shape).permute(
*inv_order)
elif isinstance(kv_cache_spec, MambaSpec):
has_mamba = True
raw_tensor = kv_cache_raw_tensors[layer_name]
state_tensors = []
storage_offset_bytes = 0
for (shape, dtype) in zip(kv_cache_spec.shapes,
kv_cache_spec.dtypes):
dtype_size = get_dtype_size(dtype)
num_element_per_page = (
kv_cache_spec.page_size_bytes // dtype_size)
target_shape = (num_blocks, *shape)
stride = torch.empty(target_shape).stride()
target_stride = (num_element_per_page, *stride[1:])
assert storage_offset_bytes % dtype_size == 0
tensor = torch.as_strided(
raw_tensor.view(dtype),
size=target_shape,
stride=target_stride,
storage_offset=storage_offset_bytes // dtype_size,
)
state_tensors.append(tensor)
storage_offset_bytes += stride[0] * dtype_size
kv_caches[layer_name] = state_tensors
else:
raise NotImplementedError
if has_attn and has_mamba:
self._update_hybrid_attention_mamba_layout(kv_caches)
return kv_caches
def _update_hybrid_attention_mamba_layout(
self, kv_caches: dict[str, torch.Tensor]) -> None:
"""
Update the layout of attention layers from (2, num_blocks, ...) to
(num_blocks, 2, ...).
Args:
kv_caches: The KV cache buffer of each layer.
"""
for kv_cache_spec, group in self._kv_cache_spec_attn_group_iterator():
for layer_name in group.layer_names:
kv_cache = kv_caches[layer_name]
if (isinstance(kv_cache_spec, AttentionSpec)
and kv_cache.shape[0] == 2):
assert kv_cache.shape[1] != 2, \
"Fail to determine whether the layout is " \
"(2, num_blocks, ...) or (num_blocks, 2, ...) for " \
f"a tensor of shape {kv_cache.shape}"
hidden_size = kv_cache.shape[2:].numel()
kv_cache.as_strided_(size=kv_cache.shape,
stride=(hidden_size, 2 * hidden_size,
*kv_cache.stride()[2:]))
def initialize_kv_cache_tensors(
self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]:
"""
Initialize the memory buffer for KV cache.
Args:
kv_cache_config: The KV cache config
Returns:
Dict[str, torch.Tensor]: A map between layer names to their
corresponding memory buffer for KV cache.
"""
runner = self._runner()
kv_cache_raw_tensors = self._allocate_kv_cache_tensors(kv_cache_config)
kv_caches = self._reshape_kv_cache_tensors(kv_cache_config,
kv_cache_raw_tensors)
for layer_name, target_layer_name in (
runner.shared_kv_cache_layers.items()):
logger.debug("%s reuses KV cache of %s", layer_name,
target_layer_name)
kv_caches[layer_name] = kv_caches[target_layer_name]
bind_kv_cache(kv_caches,
runner.compilation_config.static_forward_context,
runner.kv_caches)
return kv_caches
def maybe_add_kv_sharing_layers_to_kv_cache_groups(
self, kv_cache_config: KVCacheConfig) -> None:
"""
Add layers that re-use KV cache to KV cache group of its target layer.
Mapping of KV cache tensors happens in `initialize_kv_cache_tensors()`
"""
runner = self._runner()
if not runner.shared_kv_cache_layers:
return
add_kv_sharing_layers_to_kv_cache_groups(
runner.shared_kv_cache_layers,
kv_cache_config.kv_cache_groups,
runner.runner_only_attn_layers,
)
if runner.cache_config.kv_sharing_fast_prefill:
attn_layers = get_layers_from_vllm_config(runner.vllm_config,
Attention)
for layer_name in reversed(attn_layers):
if layer_name in runner.shared_kv_cache_layers:
runner.kv_sharing_fast_prefill_eligible_layers.add(
layer_name)
else:
break
def may_add_encoder_only_layers_to_kv_cache_config(self) -> None:
"""
Add encoder-only layers to the KV cache config.
"""
runner = self._runner()
block_size = runner.vllm_config.cache_config.block_size
use_mla = runner.vllm_config.model_config.use_mla
encoder_only_attn_specs: dict[AttentionSpec,
list[str]] = defaultdict(list)
attn_layers = get_layers_from_vllm_config(runner.vllm_config,
Attention)
for layer_name, attn_module in attn_layers.items():
if attn_module.attn_type == AttentionType.ENCODER_ONLY:
attn_spec: AttentionSpec = EncoderOnlyAttentionSpec(
block_size=block_size,
num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size,
dtype=runner.kv_cache_dtype,
use_mla=use_mla)
encoder_only_attn_specs[attn_spec].append(layer_name)
runner.runner_only_attn_layers.add(layer_name)
if len(encoder_only_attn_specs) > 0:
assert len(
encoder_only_attn_specs
) == 1, "Only support one encoder-only attention spec now"
spec, layer_names = encoder_only_attn_specs.popitem()
runner.kv_cache_config.kv_cache_groups.append(
KVCacheGroupSpec(layer_names=layer_names, kv_cache_spec=spec))
def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
"""
Initialize KV cache based on `kv_cache_config`.
Args:
kv_cache_config: Configuration for the KV cache, including the KV
cache size of each layer
"""
runner = self._runner()
kv_cache_config = deepcopy(kv_cache_config)
runner.kv_cache_config = kv_cache_config
self.may_reinitialize_input_batch(kv_cache_config)
self.may_add_encoder_only_layers_to_kv_cache_config()
self.maybe_add_kv_sharing_layers_to_kv_cache_groups(kv_cache_config)
runner.initialize_attn_backend(kv_cache_config)
kv_caches = self.initialize_kv_cache_tensors(kv_cache_config)
if runner.speculative_config and runner.speculative_config.use_eagle():
assert isinstance(runner.drafter, EagleProposer)
runner.drafter.validate_same_kv_cache_group(kv_cache_config)
if has_kv_transfer_group():
get_kv_transfer_group().register_kv_caches(kv_caches)
if runner.device.type == 'xpu':
get_kv_transfer_group().set_host_xfer_buffer_ops(
copy_kv_blocks)
if runner.dcp_world_size > 1:
layer_names = runner.attn_groups[0][0].layer_names
layers = get_layers_from_vllm_config(
runner.vllm_config,
AttentionLayerBase, # type: ignore[type-abstract]
layer_names,
)
for layer in layers.values():
layer_impl = cast(Any, layer).impl
assert layer_impl.need_to_return_lse_for_decode, (
"DCP requires attention impls to return"
" the softmax lse for decode, but the impl "
f"{layer_impl.__class__.__name__} "
"does not return the softmax lse for decode.")