[KVConnector][Core] Support cross-layer KV blocks (#27743)

Signed-off-by: Or Ozeri <oro@il.ibm.com>
Or Ozeri 2025-11-20 20:09:59 +02:00 committed by GitHub
parent e5bfcb6a88
commit 647464719b
15 changed files with 453 additions and 90 deletions

View File

@ -19,6 +19,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.offloading_connector import (
)
from vllm.forward_context import ForwardContext
from vllm.utils.hashing import sha256
from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
from vllm.v1.core.kv_cache_utils import (
BlockHash,
get_request_block_hasher,
@ -92,7 +93,7 @@ class MockOffloadingSpec(OffloadingSpec):
return self.manager
def get_handlers(
self, _
self, _, __
) -> Iterator[tuple[type[LoadStoreSpec], type[LoadStoreSpec], OffloadingHandler]]:
yield GPULoadStoreSpec, MockLoadStoreSpec, self.handler
yield MockLoadStoreSpec, GPULoadStoreSpec, self.handler
@ -138,7 +139,10 @@ class RequestRunner:
self.worker_connector = OffloadingConnector(vllm_config, KVConnectorRole.WORKER)
# register worker kv_caches to enable OffloadingWorker creations
self.worker_connector.register_kv_caches(kv_caches={"a": torch.empty(0)})
self.worker_connector.register_cross_layers_kv_cache(
kv_cache=torch.empty(0),
attn_backend=FlashAttentionBackend,
)
# extract connector of scheduler
scheduler_connector = self.scheduler.connector

View File

@ -12,8 +12,10 @@ from tqdm import tqdm
from vllm import LLM, SamplingParams, TokensPrompt
from vllm.config import KVEventsConfig, KVTransferConfig
from vllm.distributed.kv_events import BlockStored, KVEventBatch
from vllm.utils.system_utils import set_env_var
CPU_BLOCK_SIZES = [16, 48]
CPU_BLOCK_SIZES = [48]
ATTN_BACKENDS = ["FLASH_ATTN", "FLASHINFER"]
class MockSubscriber:
@ -63,8 +65,88 @@ class MockSubscriber:
self.sub.close()
def _latency_test(llm: LLM, subscriber: MockSubscriber):
sampling_params = SamplingParams(max_tokens=1)
num_times_cpu_better_than_cold = 0
num_tests = 10
total_cold_time = 0.0
total_gpu_hit_time = 0.0
total_cpu_hit_time = 0.0
prompt_token_ids = [0] * 10001
for i in tqdm(range(num_tests), desc="Running tests"):
prompt_token_ids[0] = i
prompts = [TokensPrompt(prompt_token_ids=prompt_token_ids)]
# run generation - this should trigger saving KV cache
start_time = time.time()
llm.generate(prompts, sampling_params, use_tqdm=False)
cold_time = time.time() - start_time
total_cold_time += cold_time
# run generation again - should hit the GPU prefix cache
start_time = time.time()
llm.generate(prompts, sampling_params, use_tqdm=False)
gpu_hit_time = time.time() - start_time
total_gpu_hit_time += gpu_hit_time
# reset prefix cache to avoid GPU hit.
llm.reset_prefix_cache()
assert subscriber.get_new_cpu_stored_events()
# run generation again - this should trigger loading from CPU
start_time = time.time()
llm.generate(prompts, sampling_params, use_tqdm=False)
cpu_hit_time = time.time() - start_time
total_cpu_hit_time += cpu_hit_time
if cpu_hit_time < cold_time:
num_times_cpu_better_than_cold += 1
print("Average times:")
print(f" Cold: {total_cold_time * 1000 / num_tests:.2f}ms")
print(f" GPU hit: {total_gpu_hit_time * 1000 / num_tests:.2f}ms")
print(f" CPU hit: {total_cpu_hit_time * 1000 / num_tests:.2f}ms")
assert num_times_cpu_better_than_cold >= 0.8 * num_tests
def _accuracy_test(llm: LLM, subscriber: MockSubscriber):
sampling_params = SamplingParams(max_tokens=1)
cpu_block_size = (
llm.llm_engine.vllm_config.kv_transfer_config.kv_connector_extra_config[
"block_size"
]
)
subscriber.get_new_cpu_stored_events()
# prepend to the prompt until its token count is CPU-block aligned
prompt = "Let's count to 10. One, two, three, four,"
while (
len(llm.generate(prompt, use_tqdm=False)[0].prompt_token_ids) % cpu_block_size
!= 0
):
prompt = ". " + prompt
assert subscriber.get_new_cpu_stored_events()
test_count = 100
success_count = 0
for i in range(test_count):
if (
llm.generate(prompt, sampling_params, use_tqdm=False)[0].outputs[0].text
== " five"
):
success_count += 1
assert success_count >= 0.5 * test_count
@pytest.mark.parametrize("cpu_block_size", CPU_BLOCK_SIZES)
def test_cpu_offloading(cpu_block_size: int) -> None:
@pytest.mark.parametrize("attn_backend", ATTN_BACKENDS)
def test_cpu_offloading(cpu_block_size: int, attn_backend: str) -> None:
"""
Tests OffloadingConnector with CPUOffloadingSpec.
"""
@ -92,61 +174,20 @@ def test_cpu_offloading(cpu_block_size: int) -> None:
topic="test",
)
llm = LLM(
model="meta-llama/Llama-3.2-1B-Instruct",
gpu_memory_utilization=0.5,
kv_events_config=kv_events_config,
kv_transfer_config=kv_transfer_config,
)
sampling_params = SamplingParams(temperature=0, max_tokens=1)
with set_env_var("VLLM_ATTENTION_BACKEND", attn_backend):
llm = LLM(
model="meta-llama/Llama-3.2-1B-Instruct",
gpu_memory_utilization=0.5,
kv_events_config=kv_events_config,
kv_transfer_config=kv_transfer_config,
)
events_endpoint = events_endpoint.replace("*", "127.0.0.1")
subscriber = MockSubscriber(events_endpoint, topic=kv_events_config.topic)
try:
num_times_cpu_better_than_cold = 0
num_tests = 10
total_cold_time = 0.0
total_gpu_hit_time = 0.0
total_cpu_hit_time = 0.0
prompt_token_ids = [0] * 10001
for i in tqdm(range(num_tests), desc="Running tests"):
prompt_token_ids[0] = i
prompts = [TokensPrompt(prompt_token_ids=prompt_token_ids)]
# run generation - this should trigger saving KV cache
start_time = time.time()
llm.generate(prompts, sampling_params, use_tqdm=False)
cold_time = time.time() - start_time
total_cold_time += cold_time
# run generation again - should hit the GPU prefix cache
start_time = time.time()
llm.generate(prompts, sampling_params, use_tqdm=False)
gpu_hit_time = time.time() - start_time
total_gpu_hit_time += gpu_hit_time
# reset prefix cache to avoid GPU hit.
llm.reset_prefix_cache()
assert subscriber.get_new_cpu_stored_events()
# run generation again - this should trigger loading from CPU
start_time = time.time()
llm.generate(prompts, sampling_params, use_tqdm=False)
cpu_hit_time = time.time() - start_time
total_cpu_hit_time += cpu_hit_time
if cpu_hit_time < cold_time:
num_times_cpu_better_than_cold += 1
print("Average times:")
print(f" Cold: {total_cold_time * 1000 / num_tests:.2f}ms")
print(f" GPU hit: {total_gpu_hit_time * 1000 / num_tests:.2f}ms")
print(f" CPU hit: {total_cpu_hit_time * 1000 / num_tests:.2f}ms")
assert num_times_cpu_better_than_cold >= 0.8 * num_tests
_latency_test(llm, subscriber)
_accuracy_test(llm, subscriber)
finally:
subscriber.close()
del llm

View File

@ -483,7 +483,10 @@ def test_kv_cache_stride_order(monkeypatch, model_runner):
# Permutation that gets you back to expected kv shape
for test_stride in ((1, 4, 0, 2, 3), (0, 1, 2, 3, 4)):
def rnd_stride_order(test_stride=test_stride):
def rnd_stride_order(
include_num_layers_dimension: bool = False, test_stride=test_stride
):
assert not include_num_layers_dimension
return test_stride
# Patch the attention backend class and re-trigger the KV cache creation

View File

@ -76,7 +76,34 @@ class AttentionBackend(ABC):
raise NotImplementedError
@staticmethod
def get_kv_cache_stride_order() -> tuple[int, ...]:
def get_kv_cache_stride_order(
include_num_layers_dimension: bool = False,
) -> tuple[int, ...]:
"""
Get the physical (memory layout) ordering of the kv cache dimensions.
e.g. if the KV cache shape is
[2, num_blocks, block_size, num_heads, head_size],
and get_kv_cache_stride_order returns (1, 3, 0, 2, 4) then the physical
ordering of dimensions is
[num_blocks, num_heads, 2, block_size, head_size].
If this function is unimplemented / raises NotImplementedError,
the physical layout of the KV cache will match the logical shape.
Args:
include_num_layers_dimension: if True, includes an additional
num_layers dimension, which is assumed to be prepended
to the logical KV cache shape.
With the above example, a return value (2, 4, 0, 1, 3, 5)
corresponds to
[num_blocks, num_heads, num_layers, 2, block_size, head_size].
If an additional dimension is NOT included in the returned
tuple, the physical layout will not include a layers dimension.
Returns:
A tuple of ints which is a permutation of range(len(shape)).
"""
raise NotImplementedError
@classmethod

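To make the permutation semantics above concrete, here is a minimal, self-contained sketch (hypothetical sizes, plain PyTorch; not part of this diff) showing how a stride order maps the logical KV cache shape to a physical allocation and back:

import torch

# Logical KV cache shape: [2, num_blocks, block_size, num_heads, head_size]
logical_shape = (2, 32, 16, 8, 64)   # hypothetical sizes
stride_order = (1, 3, 0, 2, 4)       # physical: [num_blocks, num_heads, 2, block_size, head_size]

# Allocate the buffer in the physical ordering...
physical_shape = tuple(logical_shape[i] for i in stride_order)
buf = torch.empty(physical_shape)

# ...then permute back so callers still see the logical shape.
inv_order = [stride_order.index(i) for i in range(len(stride_order))]
logical_view = buf.permute(*inv_order)
assert logical_view.shape == logical_shape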
View File

@ -38,7 +38,7 @@ The class provides the following primitives:
import enum
from abc import ABC, abstractmethod
from collections.abc import Callable, Iterable
from typing import TYPE_CHECKING, Any, Literal, Optional
from typing import TYPE_CHECKING, Any, ClassVar, Literal, Optional
import torch
@ -47,7 +47,7 @@ from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.outputs import KVConnectorOutput
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
from vllm.config import VllmConfig
from vllm.distributed.kv_events import KVCacheEvent
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
@ -142,6 +142,18 @@ class KVConnectorMetadata(ABC): # noqa: B024
class KVConnectorBase_V1(ABC):
"""
Base class for KV connectors.
Attributes:
prefer_cross_layer_blocks (bool): Indicates whether this connector
prefers KV blocks that hold KV data for all layers (for speeding
up KV data transfers).
Defaults to False.
"""
prefer_cross_layer_blocks: ClassVar[bool] = False
def __init__(
self,
vllm_config: "VllmConfig",
@ -226,6 +238,23 @@ class KVConnectorBase_V1(ABC):
"""
return
def register_cross_layers_kv_cache(
self, kv_cache: torch.Tensor, attn_backend: type["AttentionBackend"]
):
"""
Initialize with a single KV cache tensor used by all layers.
The first dimension should be num_layers.
This function will only be called for models with uniform layers,
and only if prefer_cross_layer_blocks is set to True.
Only one of the functions
{register_kv_caches, register_cross_layers_kv_cache} will be called.
Args:
kv_cache: a cross-layer KV cache tensor
attn_backend: The attention backend that corresponds to all layers
"""
return
def set_host_xfer_buffer_ops(self, copy_operation: CopyBlocksOp):
"""
Set the xPU-specific ops for copying KV between host and device.

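For illustration, a hedged sketch of a connector opting into the new hook; the class name and the stored attributes are hypothetical, and the remaining abstract methods of KVConnectorBase_V1 are omitted:

from typing import ClassVar

import torch

from vllm.attention import AttentionBackend
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1


class MyCrossLayerConnector(KVConnectorBase_V1):
    # Opt in: ask the model runner for one cross-layer KV tensor when the
    # model and attention backend allow it.
    prefer_cross_layer_blocks: ClassVar[bool] = True

    def register_cross_layers_kv_cache(
        self, kv_cache: torch.Tensor, attn_backend: type[AttentionBackend]
    ):
        # Single tensor covering all layers; its physical layout follows
        # attn_backend.get_kv_cache_stride_order(include_num_layers_dimension=True).
        self._cross_layer_kv_cache = kv_cache
        self._attn_backend = attn_backend

    def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
        # Fallback path: called instead when a uniform cross-layer layout is
        # not used (exactly one of the two register_* methods is called).
        self._per_layer_kv_caches = kv_caches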
View File

@ -4,12 +4,12 @@ from collections import defaultdict
from collections.abc import Iterable, Iterator
from dataclasses import dataclass
from itertools import islice
from typing import Any
from typing import Any, ClassVar
import torch
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.attention import Attention, AttentionBackend, AttentionMetadata
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.distributed.kv_events import BlockRemoved, BlockStored, KVCacheEvent
from vllm.distributed.kv_transfer.kv_connector.v1 import (
KVConnectorBase_V1,
@ -42,6 +42,8 @@ class OffloadingConnectorMetadata(KVConnectorMetadata):
class OffloadingConnector(KVConnectorBase_V1):
prefer_cross_layer_blocks: ClassVar[bool] = True
def __init__(
self,
vllm_config: VllmConfig,
@ -63,6 +65,12 @@ class OffloadingConnector(KVConnectorBase_V1):
assert self.connector_worker is not None
self.connector_worker.register_kv_caches(kv_caches)
def register_cross_layers_kv_cache(
self, kv_cache: torch.Tensor, attn_backend: type[AttentionBackend]
):
assert self.connector_worker is not None
self.connector_worker.register_cross_layers_kv_cache(kv_cache, attn_backend)
def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
assert self.connector_worker is not None
assert isinstance(self._connector_metadata, OffloadingConnectorMetadata)
@ -422,10 +430,35 @@ class OffloadingConnectorWorker:
self._job_counter = job_id + 1
return job_id
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
for src_cls, dst_cls, handler in self.spec.get_handlers(kv_caches):
def _register_handlers(
self,
kv_caches: dict[str, torch.Tensor],
attn_backends: dict[str, type[AttentionBackend]],
):
for src_cls, dst_cls, handler in self.spec.get_handlers(
kv_caches, attn_backends
):
self.worker.register_handler(src_cls, dst_cls, handler)
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
layer_names = list(kv_caches.keys())
layers = get_layers_from_vllm_config(
self.spec.vllm_config, Attention, layer_names
)
attn_backends = {
layer_name: layers[layer_name].get_attn_backend()
for layer_name in layer_names
}
self._register_handlers(kv_caches, attn_backends)
def register_cross_layers_kv_cache(
self, kv_cache: torch.Tensor, attn_backend: type[AttentionBackend]
):
cross_layer_name = "ALL_LAYERS"
kv_caches = {cross_layer_name: kv_cache}
attn_backends = {cross_layer_name: attn_backend}
self._register_handlers(kv_caches, attn_backends)
def start_load_kv(self, metadata: OffloadingConnectorMetadata):
for req_id, transfer_spec in metadata.reqs_to_load.items():
job_id = self._generate_job_id()

View File

@ -99,12 +99,20 @@ class FlashAttentionBackend(AttentionBackend):
return (2, num_blocks, block_size, num_kv_heads, head_size)
@staticmethod
def get_kv_cache_stride_order() -> tuple[int, ...]:
def get_kv_cache_stride_order(
include_num_layers_dimension: bool = False,
) -> tuple[int, ...]:
# `stride_order` indicates the permutation that gets
# us from `get_kv_cache_shape` to the actual memory layout we want.
cache_layout = get_kv_cache_layout()
if cache_layout == "NHD":
if cache_layout == "NHD" and include_num_layers_dimension:
# (num_blocks, num_layers, 2, block_size, num_kv_heads, head_size)
return (2, 0, 1, 3, 4, 5)
elif cache_layout == "NHD":
stride_order = (0, 1, 2, 3, 4)
elif cache_layout == "HND" and include_num_layers_dimension:
# (num_blocks, num_kv_heads, num_layers, 2, block_size, head_size)
return (2, 4, 0, 1, 3, 5)
elif cache_layout == "HND":
stride_order = (0, 1, 3, 2, 4)
else:

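A small sanity-check sketch (hypothetical sizes, not part of this diff) of why the ordering above matters: with num_blocks leading in the NHD-with-layers layout, one block's KV data for all layers is a single contiguous slab:

import torch

# Physical shape for FlashAttention NHD with a layers dimension:
# (num_blocks, num_layers, 2, block_size, num_kv_heads, head_size)
num_blocks, num_layers, block_size, num_kv_heads, head_size = 4, 3, 16, 8, 64
cache = torch.zeros(num_blocks, num_layers, 2, block_size, num_kv_heads, head_size)

# One block's KV data for *all* layers is one contiguous slab, which is
# what makes per-block cross-layer transfers cheap.
block0 = cache[0]
assert block0.is_contiguous()
assert block0.numel() == num_layers * 2 * block_size * num_kv_heads * head_size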
View File

@ -309,12 +309,20 @@ class FlashInferBackend(AttentionBackend):
return (num_blocks, 2, block_size, num_kv_heads, head_size)
@staticmethod
def get_kv_cache_stride_order() -> tuple[int, ...]:
def get_kv_cache_stride_order(
include_num_layers_dimension: bool = False,
) -> tuple[int, ...]:
# `stride_order` indicates the permutation that gets us from
# `get_kv_cache_shape` to the actual memory layout we want.
cache_layout = get_kv_cache_layout()
if cache_layout == "NHD":
if cache_layout == "NHD" and include_num_layers_dimension:
# (num_blocks, num_layers, 2, block_size, num_kv_heads, head_size)
return (1, 0, 2, 3, 4, 5)
elif cache_layout == "NHD":
stride_order = (0, 1, 2, 3, 4)
elif cache_layout == "HND" and include_num_layers_dimension:
# (num_blocks, 2, num_kv_heads, num_layers, block_size, head_size)
return (1, 2, 4, 0, 3, 5)
elif cache_layout == "HND":
stride_order = (0, 1, 3, 2, 4)
else:

View File

@ -308,6 +308,15 @@ class MLACommonBackend(AttentionBackend):
) -> tuple[int, ...]:
return (num_blocks, block_size, head_size)
@staticmethod
def get_kv_cache_stride_order(
include_num_layers_dimension: bool = False,
) -> tuple[int, ...]:
# `stride_order` indicates the permutation that gets
# us from `get_kv_cache_shape` to the actual memory layout we want.
# (num_blocks, num_layers, block_size, head_size)
return (1, 0, 2, 3) if include_num_layers_dimension else (0, 1, 2)
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
return [576]

View File

@ -48,7 +48,11 @@ class DeepseekV32IndexerBackend(AttentionBackend):
return (num_blocks, block_size, head_size)
@staticmethod
def get_kv_cache_stride_order() -> tuple[int, ...]:
def get_kv_cache_stride_order(
include_num_layers_dimension: bool = False,
) -> tuple[int, ...]:
if include_num_layers_dimension:
return (0, 1, 2, 3)
return (0, 1, 2)

View File

@ -4,8 +4,8 @@ from collections.abc import Iterator
import torch
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.attention import AttentionBackend
from vllm.config import VllmConfig
from vllm.platforms import current_platform
from vllm.v1.kv_offload.abstract import LoadStoreSpec, OffloadingManager
from vllm.v1.kv_offload.arc_manager import ARCOffloadingManager
@ -63,7 +63,9 @@ class CPUOffloadingSpec(OffloadingSpec):
return self._manager
def get_handlers(
self, kv_caches: dict[str, torch.Tensor]
self,
kv_caches: dict[str, torch.Tensor],
attn_backends: dict[str, type[AttentionBackend]],
) -> Iterator[tuple[type[LoadStoreSpec], type[LoadStoreSpec], OffloadingHandler]]:
if not self._handler:
if not current_platform.is_cuda_alike():
@ -71,15 +73,6 @@ class CPUOffloadingSpec(OffloadingSpec):
"CPU Offloading is currently only supported on CUDA-alike GPUs"
)
layer_names = list(kv_caches.keys())
layers = get_layers_from_vllm_config(
self.vllm_config, AttentionLayerBase, layer_names
)
attn_backends = {
layer_name: layers[layer_name].get_attn_backend()
for layer_name in layer_names
}
self._handler = CpuGpuOffloadingHandler(
attn_backends=attn_backends,
gpu_block_size=self.gpu_block_size,

View File

@ -11,6 +11,7 @@ from vllm.v1.kv_offload.abstract import LoadStoreSpec, OffloadingManager
from vllm.v1.kv_offload.worker.worker import OffloadingHandler
if TYPE_CHECKING:
from vllm.attention import AttentionBackend
from vllm.config import VllmConfig
logger = init_logger(__name__)
@ -48,13 +49,16 @@ class OffloadingSpec(ABC):
@abstractmethod
def get_handlers(
self, kv_caches: dict[str, torch.Tensor]
self,
kv_caches: dict[str, torch.Tensor],
attn_backends: dict[str, type["AttentionBackend"]],
) -> Iterator[tuple[type[LoadStoreSpec], type[LoadStoreSpec], OffloadingHandler]]:
"""
Get offloading handlers along with their respective src and dst types.
Args:
kv_caches: A dictionary of layer_name -> gpu_kv_cache tensor.
attn_backends: A dictionary of layer_name -> AttentionBackend.
Yields:
Tuples of (src_type, dst_type, offloading_handler).

View File

@ -83,10 +83,18 @@ class CpuGpuOffloadingHandler(OffloadingHandler):
self.gpu_tensors.append(gpu_tensor)
gpu_shape = gpu_tensor.shape
test_shape = attn_backends[layer_name].get_kv_cache_shape(
attn_backend = attn_backends[layer_name]
test_shape = attn_backend.get_kv_cache_shape(
num_blocks=1234, block_size=16, num_kv_heads=8, head_size=256
)
if test_shape[0] == 1234:
if len(gpu_shape) != len(test_shape):
# cross-layers tensor
# shape is (num_blocks, ...)
assert len(gpu_shape) == len(test_shape) + 1
num_blocks_idx = 0
self.kv_dim_before_num_blocks.append(False)
elif test_shape[0] == 1234:
# shape is (num_blocks, ...)
num_blocks_idx = 0
self.kv_dim_before_num_blocks.append(False)

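Illustrative values (hypothetical, FlashAttention-style shapes) for the rank-based detection above: the reference shape from the backend never has a layers dimension, so the single cross-layer tensor shows up with exactly one extra rank and num_blocks as its leading dimension:

# Hypothetical shapes for the rank check above.
test_shape = (2, 1234, 16, 8, 256)                # reference per-layer shape from the backend
per_layer_gpu_shape = (2, 8192, 16, 8, 256)       # same rank -> regular per-layer tensor
cross_layer_gpu_shape = (8192, 4, 2, 16, 8, 256)  # one extra rank -> cross-layer tensor,
                                                  # num_blocks is the leading dimension
assert len(per_layer_gpu_shape) == len(test_shape)
assert len(cross_layer_gpu_shape) == len(test_shape) + 1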
View File

@ -349,6 +349,9 @@ class GPUModelRunner(
# self.model: nn.Module # Set after load_model
# Initialize in initialize_kv_cache
self.kv_caches: list[torch.Tensor] = []
# Initialize in initialize_kv_cache_tensors
self.cross_layers_kv_cache: torch.Tensor | None = None
self.cross_layers_attn_backend: type[AttentionBackend] | None = None
# indexes: [kv_cache_group_id][attn_group]
self.attn_groups: list[list[AttentionGroup]] = []
# self.kv_cache_config: KVCacheConfig
@ -4930,12 +4933,30 @@ class GPUModelRunner(
Dict[str, torch.Tensor]: A map between layer names to their
corresponding memory buffer for KV cache.
"""
# Initialize the memory buffer for KV cache
kv_cache_raw_tensors = self._allocate_kv_cache_tensors(kv_cache_config)
# Change the memory buffer to the desired shape
kv_caches = self._reshape_kv_cache_tensors(
kv_cache_config, kv_cache_raw_tensors, kernel_block_sizes
)
# Try creating KV caches optimized for kv-connector transfers
cache_dtype = self.cache_config.cache_dtype
if self.use_uniform_kv_cache(self.attn_groups, cache_dtype):
kv_caches, cross_layers_kv_cache, attn_backend = (
self.allocate_uniform_kv_caches(
kv_cache_config,
self.attn_groups,
cache_dtype,
self.device,
kernel_block_sizes,
)
)
self.cross_layers_kv_cache = cross_layers_kv_cache
self.cross_layers_attn_backend = attn_backend
else:
# Fallback to the general case
# Initialize the memory buffer for KV cache
kv_cache_raw_tensors = self._allocate_kv_cache_tensors(kv_cache_config)
# Change the memory buffer to the desired shape
kv_caches = self._reshape_kv_cache_tensors(
kv_cache_config, kv_cache_raw_tensors, kernel_block_sizes
)
# Set up cross-layer KV cache sharing
for layer_name, target_layer_name in self.shared_kv_cache_layers.items():
@ -5017,7 +5038,13 @@ class GPUModelRunner(
if has_kv_transfer_group():
kv_transfer_group = get_kv_transfer_group()
kv_transfer_group.register_kv_caches(kv_caches)
if self.cross_layers_kv_cache is not None:
assert self.cross_layers_attn_backend is not None
kv_transfer_group.register_cross_layers_kv_cache(
self.cross_layers_kv_cache, self.cross_layers_attn_backend
)
else:
kv_transfer_group.register_kv_caches(kv_caches)
kv_transfer_group.set_host_xfer_buffer_ops(copy_kv_blocks)
if self.dcp_world_size > 1:

View File

@ -11,7 +11,11 @@ from typing import (
TYPE_CHECKING, # noqa: UP035
)
import torch
from vllm.attention import AttentionBackend
from vllm.config import VllmConfig
from vllm.config.cache import CacheDType
from vllm.distributed.kv_transfer import (
ensure_kv_transfer_shutdown,
get_kv_transfer_group,
@ -21,11 +25,13 @@ from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
from vllm.forward_context import get_forward_context, set_forward_context
from vllm.logger import init_logger
from vllm.v1.kv_cache_interface import AttentionSpec, KVCacheConfig
from vllm.v1.outputs import (
EMPTY_MODEL_RUNNER_OUTPUT,
KVConnectorOutput,
ModelRunnerOutput,
)
from vllm.v1.worker.utils import AttentionGroup
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
@ -142,3 +148,162 @@ class KVConnectorModelRunnerMixin:
if has_kv_transfer_group():
return get_kv_transfer_group().get_kv_connector_stats()
return None
@staticmethod
def use_uniform_kv_cache(
attn_groups: list[list[AttentionGroup]],
cache_dtype: CacheDType,
) -> bool:
"""
Determines whether a uniform KV layout should be used.
A uniform layout means all layers' KV caches share the same
underlying tensor, where for a given block number, the respective
KV data of all layers is contiguous.
This allows efficient KV transfer of per-block KV data for all
layers at once.
Note this layout will only be applied when all 3 conditions hold:
1. The KV cache config contains just a single group where all layers
have the same page size.
2. A KV connector is configured, and the KV connector instance prefers
to use this layout (prefer_cross_layer_blocks is True).
3. The attention backend supports this layout
(get_kv_cache_stride_order(include_num_layers_dimension=True)
includes a placement for a num_layers dimension).
Note that the actual placement of the num_layers dimension
in the unified cross-layer tensor is determined by the attention
backend.
Thus, the per-block KV data of all layers may still not be
contiguous if the attention backend does not support it.
Args:
attn_groups: The list of attention groups for this model
cache_dtype: The KV cache dtype
Returns:
True if we should use a uniform KV cache layout.
"""
if not has_kv_transfer_group():
return False
if not get_kv_transfer_group().prefer_cross_layer_blocks:
return False
if len(attn_groups) != 1 or len(attn_groups[0]) != 1:
return False
attn_group = attn_groups[0][0]
kv_cache_spec = attn_group.kv_cache_spec
if not isinstance(kv_cache_spec, AttentionSpec):
return False
attn_backend = attn_group.backend
kv_cache_shape = attn_backend.get_kv_cache_shape(
1234,
kv_cache_spec.block_size,
kv_cache_spec.num_kv_heads,
kv_cache_spec.head_size,
cache_dtype_str=cache_dtype,
)
try:
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order(
include_num_layers_dimension=True
)
except (AttributeError, NotImplementedError):
return False
# check that the attention backend includes a layers dimension
return len(kv_cache_stride_order) == len(kv_cache_shape) + 1
@staticmethod
def allocate_uniform_kv_caches(
kv_cache_config: KVCacheConfig,
attn_groups: list[list[AttentionGroup]],
cache_dtype: CacheDType,
device: torch.device,
kernel_block_sizes: list[int],
) -> tuple[dict[str, torch.Tensor], torch.Tensor, type[AttentionBackend]]:
"""
Initializes and reshapes KV caches for the simple case where all
layers have the same layout.
This function assumes use_uniform_kv_cache() returned True.
Args:
kv_cache_config: The KV cache config
attn_groups: The list of attention groups for this model
cache_dtype: The KV cache dtype
device: The torch device to allocate on.
kernel_block_sizes: The kernel block sizes for each KV cache group.
Returns:
A tuple (kv_caches, cross_layers_kv_cache, attn_backend) where:
kv_caches is a dict mapping between layer names to their
corresponding memory buffer for KV cache.
cross_layers_kv_cache is the cross-layer KV cache tensor
attn_backend is the attention backend matching this tensor
"""
attn_group = attn_groups[0][0]
kv_cache_spec = attn_group.kv_cache_spec
assert isinstance(kv_cache_spec, AttentionSpec)
tensor_sizes = set(
kv_cache_tensor.size for kv_cache_tensor in kv_cache_config.kv_cache_tensors
)
assert len(tensor_sizes) == 1
tensor_size = tensor_sizes.pop()
page_size = kv_cache_spec.page_size_bytes
assert tensor_size % page_size == 0
num_blocks = tensor_size // page_size
num_layers = len(kv_cache_config.kv_cache_tensors)
total_size = tensor_size * num_layers
assert len(kernel_block_sizes) == 1
kernel_block_size = kernel_block_sizes[0]
num_blocks_per_kv_block = kv_cache_spec.block_size // kernel_block_size
kernel_num_blocks = num_blocks * num_blocks_per_kv_block
attn_backend = attn_group.backend
kv_cache_shape = attn_backend.get_kv_cache_shape(
kernel_num_blocks,
kernel_block_size,
kv_cache_spec.num_kv_heads,
kv_cache_spec.head_size,
cache_dtype_str=cache_dtype,
)
# prepend a num_layers dimension into the shape
kv_cache_shape = (num_layers,) + kv_cache_shape
try:
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order(
include_num_layers_dimension=True
)
assert len(kv_cache_stride_order) == len(kv_cache_shape)
except (AttributeError, NotImplementedError):
kv_cache_stride_order = tuple(range(len(kv_cache_shape)))
kv_cache_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order)
logger.info("Allocating a cross layer KV cache of shape %s", kv_cache_shape)
# allocate one contiguous buffer for all layers
cross_layers_kv_cache = (
torch.zeros(total_size, dtype=torch.int8, device=device)
.view(kv_cache_spec.dtype)
.view(kv_cache_shape)
)
# Maintain original KV shape view.
inv_order = [
kv_cache_stride_order.index(i) for i in range(len(kv_cache_stride_order))
]
permuted_kv_cache = cross_layers_kv_cache.permute(*inv_order)
kv_caches = {}
for i, kv_cache_tensor in enumerate(kv_cache_config.kv_cache_tensors):
tensor = permuted_kv_cache[i]
for layer_name in kv_cache_tensor.shared_by:
kv_caches[layer_name] = tensor
return kv_caches, cross_layers_kv_cache, attn_backend
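
To tie the sizing and permutation logic above together, a worked sketch with hypothetical numbers (FlashAttention-style NHD stride order with a layers dimension); it mirrors the allocation, the stride-order permute, and the inverse permutation that restores per-layer views:

import torch

# Hypothetical configuration
num_layers, block_size, num_kv_heads, head_size = 4, 16, 8, 64
dtype, elem_size = torch.float16, 2                  # 2 bytes per element
page_size = 2 * block_size * num_kv_heads * head_size * elem_size
tensor_size = 32 * page_size                         # per-layer buffer size from the KV cache config
num_blocks = tensor_size // page_size                # -> 32
total_size = tensor_size * num_layers                # one buffer shared by all layers

# Logical shape with a prepended num_layers dimension, then the physical
# layout given by a stride order such as FlashAttention NHD's (2, 0, 1, 3, 4, 5).
logical_shape = (num_layers, 2, num_blocks, block_size, num_kv_heads, head_size)
stride_order = (2, 0, 1, 3, 4, 5)
physical_shape = tuple(logical_shape[i] for i in stride_order)

cross_layers_kv_cache = (
    torch.zeros(total_size, dtype=torch.int8).view(dtype).view(physical_shape)
)

# The inverse permutation restores the num_layers-first view, from which
# each layer gets its usual per-layer KV cache view.
inv_order = [stride_order.index(i) for i in range(len(stride_order))]
per_layer_views = cross_layers_kv_cache.permute(*inv_order)
assert per_layer_views[0].shape == (2, num_blocks, block_size, num_kv_heads, head_size)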