[KVConnector][Core] Support cross-layer KV blocks (#27743)
Signed-off-by: Or Ozeri <oro@il.ibm.com>
parent e5bfcb6a88
commit 647464719b
@@ -19,6 +19,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.offloading_connector import (
)
from vllm.forward_context import ForwardContext
from vllm.utils.hashing import sha256
from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
from vllm.v1.core.kv_cache_utils import (
    BlockHash,
    get_request_block_hasher,
@@ -92,7 +93,7 @@ class MockOffloadingSpec(OffloadingSpec):
        return self.manager

    def get_handlers(
        self, _
        self, _, __
    ) -> Iterator[tuple[type[LoadStoreSpec], type[LoadStoreSpec], OffloadingHandler]]:
        yield GPULoadStoreSpec, MockLoadStoreSpec, self.handler
        yield MockLoadStoreSpec, GPULoadStoreSpec, self.handler
@@ -138,7 +139,10 @@ class RequestRunner:
        self.worker_connector = OffloadingConnector(vllm_config, KVConnectorRole.WORKER)

        # register worker kv_caches to enable OffloadingWorker creation
        self.worker_connector.register_kv_caches(kv_caches={"a": torch.empty(0)})
        self.worker_connector.register_cross_layers_kv_cache(
            kv_cache=torch.empty(0),
            attn_backend=FlashAttentionBackend,
        )

        # extract the scheduler's connector
        scheduler_connector = self.scheduler.connector
@@ -12,8 +12,10 @@ from tqdm import tqdm
from vllm import LLM, SamplingParams, TokensPrompt
from vllm.config import KVEventsConfig, KVTransferConfig
from vllm.distributed.kv_events import BlockStored, KVEventBatch
from vllm.utils.system_utils import set_env_var

CPU_BLOCK_SIZES = [16, 48]
CPU_BLOCK_SIZES = [48]
ATTN_BACKENDS = ["FLASH_ATTN", "FLASHINFER"]


class MockSubscriber:
@@ -63,8 +65,88 @@ class MockSubscriber:
        self.sub.close()


def _latency_test(llm: LLM, subscriber: MockSubscriber):
    sampling_params = SamplingParams(max_tokens=1)

    num_times_cpu_better_than_cold = 0
    num_tests = 10
    total_cold_time = 0.0
    total_gpu_hit_time = 0.0
    total_cpu_hit_time = 0.0
    prompt_token_ids = [0] * 10001
    for i in tqdm(range(num_tests), desc="Running tests"):
        prompt_token_ids[0] = i
        prompts = [TokensPrompt(prompt_token_ids=prompt_token_ids)]

        # run generation - this should trigger saving KV cache
        start_time = time.time()
        llm.generate(prompts, sampling_params, use_tqdm=False)
        cold_time = time.time() - start_time
        total_cold_time += cold_time

        # run generation again - should hit the GPU prefix cache
        start_time = time.time()
        llm.generate(prompts, sampling_params, use_tqdm=False)
        gpu_hit_time = time.time() - start_time
        total_gpu_hit_time += gpu_hit_time

        # reset prefix cache to avoid GPU hit.
        llm.reset_prefix_cache()

        assert subscriber.get_new_cpu_stored_events()

        # run generation again - this should trigger loading from CPU
        start_time = time.time()
        llm.generate(prompts, sampling_params, use_tqdm=False)
        cpu_hit_time = time.time() - start_time
        total_cpu_hit_time += cpu_hit_time

        if cpu_hit_time < cold_time:
            num_times_cpu_better_than_cold += 1

    print("Average times:")
    print(f" Cold: {total_cold_time * 1000 / num_tests:.2f}ms")
    print(f" GPU hit: {total_gpu_hit_time * 1000 / num_tests:.2f}ms")
    print(f" CPU hit: {total_cpu_hit_time * 1000 / num_tests:.2f}ms")

    assert num_times_cpu_better_than_cold >= 0.8 * num_tests


def _accuracy_test(llm: LLM, subscriber: MockSubscriber):
    sampling_params = SamplingParams(max_tokens=1)
    cpu_block_size = (
        llm.llm_engine.vllm_config.kv_transfer_config.kv_connector_extra_config[
            "block_size"
        ]
    )

    subscriber.get_new_cpu_stored_events()

    # prepend prompt to be cpu block aligned
    prompt = "Let's count to 10. One, two, three, four,"
    while (
        len(llm.generate(prompt, use_tqdm=False)[0].prompt_token_ids) % cpu_block_size
        != 0
    ):
        prompt = ". " + prompt

    assert subscriber.get_new_cpu_stored_events()

    test_count = 100
    success_count = 0
    for i in range(test_count):
        if (
            llm.generate(prompt, sampling_params, use_tqdm=False)[0].outputs[0].text
            == " five"
        ):
            success_count += 1

    assert success_count >= 0.5 * test_count


@pytest.mark.parametrize("cpu_block_size", CPU_BLOCK_SIZES)
def test_cpu_offloading(cpu_block_size: int) -> None:
@pytest.mark.parametrize("attn_backend", ATTN_BACKENDS)
def test_cpu_offloading(cpu_block_size: int, attn_backend: str) -> None:
    """
    Tests OffloadingConnector with CPUOffloadingSpec.
    """
@@ -92,61 +174,20 @@ def test_cpu_offloading(cpu_block_size: int) -> None:
        topic="test",
    )

    llm = LLM(
        model="meta-llama/Llama-3.2-1B-Instruct",
        gpu_memory_utilization=0.5,
        kv_events_config=kv_events_config,
        kv_transfer_config=kv_transfer_config,
    )

    sampling_params = SamplingParams(temperature=0, max_tokens=1)
    with set_env_var("VLLM_ATTENTION_BACKEND", attn_backend):
        llm = LLM(
            model="meta-llama/Llama-3.2-1B-Instruct",
            gpu_memory_utilization=0.5,
            kv_events_config=kv_events_config,
            kv_transfer_config=kv_transfer_config,
        )

    events_endpoint = events_endpoint.replace("*", "127.0.0.1")
    subscriber = MockSubscriber(events_endpoint, topic=kv_events_config.topic)

    try:
        num_times_cpu_better_than_cold = 0
        num_tests = 10
        total_cold_time = 0.0
        total_gpu_hit_time = 0.0
        total_cpu_hit_time = 0.0
        prompt_token_ids = [0] * 10001
        for i in tqdm(range(num_tests), desc="Running tests"):
            prompt_token_ids[0] = i
            prompts = [TokensPrompt(prompt_token_ids=prompt_token_ids)]

            # run generation - this should trigger saving KV cache
            start_time = time.time()
            llm.generate(prompts, sampling_params, use_tqdm=False)
            cold_time = time.time() - start_time
            total_cold_time += cold_time

            # run generation again - should hit the GPU prefix cache
            start_time = time.time()
            llm.generate(prompts, sampling_params, use_tqdm=False)
            gpu_hit_time = time.time() - start_time
            total_gpu_hit_time += gpu_hit_time

            # reset prefix cache to avoid GPU hit.
            llm.reset_prefix_cache()

            assert subscriber.get_new_cpu_stored_events()

            # run generation again - this should trigger loading from CPU
            start_time = time.time()
            llm.generate(prompts, sampling_params, use_tqdm=False)
            cpu_hit_time = time.time() - start_time
            total_cpu_hit_time += cpu_hit_time

            if cpu_hit_time < cold_time:
                num_times_cpu_better_than_cold += 1

        print("Average times:")
        print(f" Cold: {total_cold_time * 1000 / num_tests:.2f}ms")
        print(f" GPU hit: {total_gpu_hit_time * 1000 / num_tests:.2f}ms")
        print(f" CPU hit: {total_cpu_hit_time * 1000 / num_tests:.2f}ms")

        assert num_times_cpu_better_than_cold >= 0.8 * num_tests
        _latency_test(llm, subscriber)
        _accuracy_test(llm, subscriber)
    finally:
        subscriber.close()
        del llm
@@ -483,7 +483,10 @@ def test_kv_cache_stride_order(monkeypatch, model_runner):
    # Permutation that gets you back to expected kv shape
    for test_stride in ((1, 4, 0, 2, 3), (0, 1, 2, 3, 4)):

        def rnd_stride_order(test_stride=test_stride):
        def rnd_stride_order(
            include_num_layers_dimension: bool = False, test_stride=test_stride
        ):
            assert not include_num_layers_dimension
            return test_stride

        # Patch the attention backend class and re-trigger the KV cache creation
@@ -76,7 +76,34 @@ class AttentionBackend(ABC):
        raise NotImplementedError

    @staticmethod
    def get_kv_cache_stride_order() -> tuple[int, ...]:
    def get_kv_cache_stride_order(
        include_num_layers_dimension: bool = False,
    ) -> tuple[int, ...]:
        """
        Get the physical (memory layout) ordering of the kv cache dimensions.
        e.g. if the KV cache shape is
        [2, num_blocks, block_size, num_heads, head_size],
        and get_kv_cache_stride_order returns (1, 3, 0, 2, 4) then the physical
        ordering of dimensions is
        [num_blocks, num_heads, 2, block_size, head_size].

        If this function is unimplemented / raises NotImplementedError,
        the physical layout of the KV cache will match the logical shape.

        Args:
            include_num_layers_dimension: if True, includes an additional
                num_layers dimension, which is assumed to be prepended
                to the logical KV cache shape.
                With the above example, a return value (2, 4, 0, 1, 3, 5)
                corresponds to
                [num_blocks, num_heads, num_layers, 2, block_size, head_size].

                If an additional dimension is NOT included in the returned
                tuple, the physical layout will not include a layers dimension.

        Returns:
            A tuple of ints which is a permutation of range(len(shape)).
        """
        raise NotImplementedError

    @classmethod
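To make the permutation semantics above concrete, here is a small self-contained illustration (not part of the patch; the shape sizes are made up):

# Illustration only: how a get_kv_cache_stride_order() result is interpreted.
import torch

logical_shape = (2, 3, 16, 8, 64)  # [2, num_blocks, block_size, num_heads, head_size]
stride_order = (1, 3, 0, 2, 4)     # example value from the docstring above

# Physical layout: the logical dimensions permuted by stride_order.
physical_shape = tuple(logical_shape[i] for i in stride_order)
physical = torch.empty(physical_shape)  # [num_blocks, num_heads, 2, block_size, head_size]

# The logical view is recovered by permuting with the inverse order,
# which is exactly what allocate_uniform_kv_caches() does later in this patch.
inv_order = [stride_order.index(i) for i in range(len(stride_order))]
assert physical.permute(*inv_order).shape == torch.Size(logical_shape)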
@@ -38,7 +38,7 @@ The class provides the following primitives:
import enum
from abc import ABC, abstractmethod
from collections.abc import Callable, Iterable
from typing import TYPE_CHECKING, Any, Literal, Optional
from typing import TYPE_CHECKING, Any, ClassVar, Literal, Optional

import torch
@@ -47,7 +47,7 @@ from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.outputs import KVConnectorOutput

if TYPE_CHECKING:
    from vllm.attention.backends.abstract import AttentionMetadata
    from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
    from vllm.config import VllmConfig
    from vllm.distributed.kv_events import KVCacheEvent
    from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
@@ -142,6 +142,18 @@ class KVConnectorMetadata(ABC):  # noqa: B024


class KVConnectorBase_V1(ABC):
    """
    Base class for KV connectors.

    Attributes:
        prefer_cross_layer_blocks (bool): Indicates whether this connector
            prefers KV blocks that hold KV data for all layers (for speeding
            up KV data transfers).
            Defaults to False.
    """

    prefer_cross_layer_blocks: ClassVar[bool] = False

    def __init__(
        self,
        vllm_config: "VllmConfig",
@@ -226,6 +238,23 @@ class KVConnectorBase_V1(ABC):
        """
        return

    def register_cross_layers_kv_cache(
        self, kv_cache: torch.Tensor, attn_backend: type["AttentionBackend"]
    ):
        """
        Initialize with a single KV cache tensor used by all layers.
        The first dimension should be num_layers.
        This function will only be called for models with uniform layers,
        and only if prefer_cross_layer_blocks is set to True.
        Only one of the functions
        {register_kv_caches, register_cross_layers_kv_cache} will be called.

        Args:
            kv_cache: a cross-layers kv cache tensor
            attn_backend: The attention backend that corresponds to all layers
        """
        return

    def set_host_xfer_buffer_ops(self, copy_operation: CopyBlocksOp):
        """
        Set the xPU-specific ops for copying KV between host and device.
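For orientation, a hedged sketch of a connector opting into this API (illustration only; ExampleCrossLayerConnector and its stored attributes are hypothetical, and the remaining abstract connector methods are omitted):

# Hypothetical sketch: a connector opting into cross-layer KV blocks.
# Only prefer_cross_layer_blocks and register_cross_layers_kv_cache are the
# real base-class API from this patch; everything else is illustrative.
from typing import TYPE_CHECKING, ClassVar

import torch

from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1

if TYPE_CHECKING:
    from vllm.attention.backends.abstract import AttentionBackend


class ExampleCrossLayerConnector(KVConnectorBase_V1):  # other abstract methods omitted
    # Ask the model runner for a single per-block tensor covering all layers.
    prefer_cross_layer_blocks: ClassVar[bool] = True

    def register_cross_layers_kv_cache(
        self, kv_cache: torch.Tensor, attn_backend: type["AttentionBackend"]
    ):
        # Per the base-class docstring, kv_cache covers every layer and its
        # leading (logical) dimension is num_layers.
        self._cross_layer_kv_cache = kv_cache
        self._attn_backend = attn_backend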
@@ -4,12 +4,12 @@ from collections import defaultdict
from collections.abc import Iterable, Iterator
from dataclasses import dataclass
from itertools import islice
from typing import Any
from typing import Any, ClassVar

import torch

from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.attention import Attention, AttentionBackend, AttentionMetadata
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.distributed.kv_events import BlockRemoved, BlockStored, KVCacheEvent
from vllm.distributed.kv_transfer.kv_connector.v1 import (
    KVConnectorBase_V1,
@@ -42,6 +42,8 @@ class OffloadingConnectorMetadata(KVConnectorMetadata):


class OffloadingConnector(KVConnectorBase_V1):
    prefer_cross_layer_blocks: ClassVar[bool] = True

    def __init__(
        self,
        vllm_config: VllmConfig,
@@ -63,6 +65,12 @@ class OffloadingConnector(KVConnectorBase_V1):
        assert self.connector_worker is not None
        self.connector_worker.register_kv_caches(kv_caches)

    def register_cross_layers_kv_cache(
        self, kv_cache: torch.Tensor, attn_backend: type[AttentionBackend]
    ):
        assert self.connector_worker is not None
        self.connector_worker.register_cross_layers_kv_cache(kv_cache, attn_backend)

    def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
        assert self.connector_worker is not None
        assert isinstance(self._connector_metadata, OffloadingConnectorMetadata)
@@ -422,10 +430,35 @@ class OffloadingConnectorWorker:
        self._job_counter = job_id + 1
        return job_id

    def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
        for src_cls, dst_cls, handler in self.spec.get_handlers(kv_caches):
    def _register_handlers(
        self,
        kv_caches: dict[str, torch.Tensor],
        attn_backends: dict[str, type[AttentionBackend]],
    ):
        for src_cls, dst_cls, handler in self.spec.get_handlers(
            kv_caches, attn_backends
        ):
            self.worker.register_handler(src_cls, dst_cls, handler)

    def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
        layer_names = list(kv_caches.keys())
        layers = get_layers_from_vllm_config(
            self.spec.vllm_config, Attention, layer_names
        )
        attn_backends = {
            layer_name: layers[layer_name].get_attn_backend()
            for layer_name in layer_names
        }
        self._register_handlers(kv_caches, attn_backends)

    def register_cross_layers_kv_cache(
        self, kv_cache: torch.Tensor, attn_backend: type[AttentionBackend]
    ):
        cross_layer_name = "ALL_LAYERS"
        kv_caches = {cross_layer_name: kv_cache}
        attn_backends = {cross_layer_name: attn_backend}
        self._register_handlers(kv_caches, attn_backends)

    def start_load_kv(self, metadata: OffloadingConnectorMetadata):
        for req_id, transfer_spec in metadata.reqs_to_load.items():
            job_id = self._generate_job_id()
@@ -99,12 +99,20 @@ class FlashAttentionBackend(AttentionBackend):
        return (2, num_blocks, block_size, num_kv_heads, head_size)

    @staticmethod
    def get_kv_cache_stride_order() -> tuple[int, ...]:
    def get_kv_cache_stride_order(
        include_num_layers_dimension: bool = False,
    ) -> tuple[int, ...]:
        # `stride_order` indicates the permutation that gets
        # us from `get_kv_cache_shape` to the actual memory layout we want.
        cache_layout = get_kv_cache_layout()
        if cache_layout == "NHD":
        if cache_layout == "NHD" and include_num_layers_dimension:
            # (num_blocks, num_layers, 2, block_size, num_kv_heads, head_size)
            return (2, 0, 1, 3, 4, 5)
        elif cache_layout == "NHD":
            stride_order = (0, 1, 2, 3, 4)
        elif cache_layout == "HND" and include_num_layers_dimension:
            # (num_blocks, num_kv_heads, num_layers, 2, block_size, head_size)
            return (2, 4, 0, 1, 3, 5)
        elif cache_layout == "HND":
            stride_order = (0, 1, 3, 2, 4)
        else:
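As a quick sanity check of the two new return values (illustration only, not part of the patch):

# Illustration only: check the NHD/HND permutations above against the logical
# FlashAttention shape with the num_layers dimension prepended.
logical = ("num_layers", "2", "num_blocks", "block_size", "num_kv_heads", "head_size")
print(tuple(logical[i] for i in (2, 0, 1, 3, 4, 5)))
# -> ('num_blocks', 'num_layers', '2', 'block_size', 'num_kv_heads', 'head_size')
print(tuple(logical[i] for i in (2, 4, 0, 1, 3, 5)))
# -> ('num_blocks', 'num_kv_heads', 'num_layers', '2', 'block_size', 'head_size')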
@@ -309,12 +309,20 @@ class FlashInferBackend(AttentionBackend):
        return (num_blocks, 2, block_size, num_kv_heads, head_size)

    @staticmethod
    def get_kv_cache_stride_order() -> tuple[int, ...]:
    def get_kv_cache_stride_order(
        include_num_layers_dimension: bool = False,
    ) -> tuple[int, ...]:
        # `stride_order` indicates the permutation that gets us from
        # `get_kv_cache_shape` to the actual memory layout we want.
        cache_layout = get_kv_cache_layout()
        if cache_layout == "NHD":
        if cache_layout == "NHD" and include_num_layers_dimension:
            # (num_blocks, num_layers, 2, block_size, num_kv_heads, head_size)
            return (1, 0, 2, 3, 4, 5)
        elif cache_layout == "NHD":
            stride_order = (0, 1, 2, 3, 4)
        elif cache_layout == "HND" and include_num_layers_dimension:
            # (num_blocks, 2, num_kv_heads, num_layers, block_size, head_size)
            return (1, 2, 4, 0, 3, 5)
        elif cache_layout == "HND":
            stride_order = (0, 1, 3, 2, 4)
        else:
@@ -308,6 +308,15 @@ class MLACommonBackend(AttentionBackend):
    ) -> tuple[int, ...]:
        return (num_blocks, block_size, head_size)

    @staticmethod
    def get_kv_cache_stride_order(
        include_num_layers_dimension: bool = False,
    ) -> tuple[int, ...]:
        # `stride_order` indicates the permutation that gets
        # us from `get_kv_cache_shape` to the actual memory layout we want.
        # (num_blocks, num_layers, block_size, head_size)
        return (1, 0, 2, 3) if include_num_layers_dimension else (0, 1, 2)

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        return [576]
@@ -48,7 +48,11 @@ class DeepseekV32IndexerBackend(AttentionBackend):
        return (num_blocks, block_size, head_size)

    @staticmethod
    def get_kv_cache_stride_order() -> tuple[int, ...]:
    def get_kv_cache_stride_order(
        include_num_layers_dimension: bool = False,
    ) -> tuple[int, ...]:
        if include_num_layers_dimension:
            return (0, 1, 2, 3)
        return (0, 1, 2)
@@ -4,8 +4,8 @@ from collections.abc import Iterator

import torch

from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.attention import AttentionBackend
from vllm.config import VllmConfig
from vllm.platforms import current_platform
from vllm.v1.kv_offload.abstract import LoadStoreSpec, OffloadingManager
from vllm.v1.kv_offload.arc_manager import ARCOffloadingManager
@@ -63,7 +63,9 @@ class CPUOffloadingSpec(OffloadingSpec):
        return self._manager

    def get_handlers(
        self, kv_caches: dict[str, torch.Tensor]
        self,
        kv_caches: dict[str, torch.Tensor],
        attn_backends: dict[str, type[AttentionBackend]],
    ) -> Iterator[tuple[type[LoadStoreSpec], type[LoadStoreSpec], OffloadingHandler]]:
        if not self._handler:
            if not current_platform.is_cuda_alike():
@@ -71,15 +73,6 @@ class CPUOffloadingSpec(OffloadingSpec):
                    "CPU Offloading is currently only supported on CUDA-alike GPUs"
                )

            layer_names = list(kv_caches.keys())
            layers = get_layers_from_vllm_config(
                self.vllm_config, AttentionLayerBase, layer_names
            )
            attn_backends = {
                layer_name: layers[layer_name].get_attn_backend()
                for layer_name in layer_names
            }

            self._handler = CpuGpuOffloadingHandler(
                attn_backends=attn_backends,
                gpu_block_size=self.gpu_block_size,
@@ -11,6 +11,7 @@ from vllm.v1.kv_offload.abstract import LoadStoreSpec, OffloadingManager
from vllm.v1.kv_offload.worker.worker import OffloadingHandler

if TYPE_CHECKING:
    from vllm.attention import AttentionBackend
    from vllm.config import VllmConfig

logger = init_logger(__name__)
@@ -48,13 +49,16 @@ class OffloadingSpec(ABC):

    @abstractmethod
    def get_handlers(
        self, kv_caches: dict[str, torch.Tensor]
        self,
        kv_caches: dict[str, torch.Tensor],
        attn_backends: dict[str, type["AttentionBackend"]],
    ) -> Iterator[tuple[type[LoadStoreSpec], type[LoadStoreSpec], OffloadingHandler]]:
        """
        Get offloading handlers along with their respective src and dst types.

        Args:
            kv_caches: A dictionary of layer_name -> gpu_kv_cache tensor.
            attn_backends: A dictionary of layer_name -> AttentionBackend.

        Yields:
            Tuples of (src_type, dst_type, offloading_handler).
@@ -83,10 +83,18 @@ class CpuGpuOffloadingHandler(OffloadingHandler):
            self.gpu_tensors.append(gpu_tensor)

            gpu_shape = gpu_tensor.shape
            test_shape = attn_backends[layer_name].get_kv_cache_shape(
            attn_backend = attn_backends[layer_name]
            test_shape = attn_backend.get_kv_cache_shape(
                num_blocks=1234, block_size=16, num_kv_heads=8, head_size=256
            )
            if test_shape[0] == 1234:

            if len(gpu_shape) != len(test_shape):
                # cross-layers tensor
                # shape is (num_blocks, ...)
                assert len(gpu_shape) == len(test_shape) + 1
                num_blocks_idx = 0
                self.kv_dim_before_num_blocks.append(False)
            elif test_shape[0] == 1234:
                # shape is (num_blocks, ...)
                num_blocks_idx = 0
                self.kv_dim_before_num_blocks.append(False)
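For orientation (illustration only, not part of the patch): the probe shape built with num_blocks=1234 lets the handler see where the num_blocks dimension sits and whether the registered GPU tensor carries an extra cross-layers dimension. The sizes below are hypothetical:

# Illustration only: the probe-shape trick used above.
flash_attn_probe = (2, 1234, 16, 8, 256)           # num_blocks is not dim 0
flashinfer_probe = (1234, 2, 16, 8, 256)           # num_blocks is dim 0
cross_layer_gpu_shape = (1234, 4, 2, 16, 8, 256)   # one extra dim vs. the probe

assert flash_attn_probe[0] != 1234 and flashinfer_probe[0] == 1234
assert len(cross_layer_gpu_shape) == len(flash_attn_probe) + 1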
@@ -349,6 +349,9 @@ class GPUModelRunner(
        # self.model: nn.Module  # Set after load_model
        # Initialize in initialize_kv_cache
        self.kv_caches: list[torch.Tensor] = []
        # Initialize in initialize_kv_cache_tensors
        self.cross_layers_kv_cache: torch.Tensor | None = None
        self.cross_layers_attn_backend: type[AttentionBackend] | None = None
        # indexes: [kv_cache_group_id][attn_group]
        self.attn_groups: list[list[AttentionGroup]] = []
        # self.kv_cache_config: KVCacheConfig
@@ -4930,12 +4933,30 @@ class GPUModelRunner(
            Dict[str, torch.Tensor]: A map between layer names to their
                corresponding memory buffer for KV cache.
        """
        # Initialize the memory buffer for KV cache
        kv_cache_raw_tensors = self._allocate_kv_cache_tensors(kv_cache_config)
        # Change the memory buffer to the desired shape
        kv_caches = self._reshape_kv_cache_tensors(
            kv_cache_config, kv_cache_raw_tensors, kernel_block_sizes
        )

        # Try creating KV caches optimized for kv-connector transfers
        cache_dtype = self.cache_config.cache_dtype
        if self.use_uniform_kv_cache(self.attn_groups, cache_dtype):
            kv_caches, cross_layers_kv_cache, attn_backend = (
                self.allocate_uniform_kv_caches(
                    kv_cache_config,
                    self.attn_groups,
                    cache_dtype,
                    self.device,
                    kernel_block_sizes,
                )
            )
            self.cross_layers_kv_cache = cross_layers_kv_cache
            self.cross_layers_attn_backend = attn_backend
        else:
            # Fallback to the general case
            # Initialize the memory buffer for KV cache
            kv_cache_raw_tensors = self._allocate_kv_cache_tensors(kv_cache_config)

            # Change the memory buffer to the desired shape
            kv_caches = self._reshape_kv_cache_tensors(
                kv_cache_config, kv_cache_raw_tensors, kernel_block_sizes
            )

        # Set up cross-layer KV cache sharing
        for layer_name, target_layer_name in self.shared_kv_cache_layers.items():
@@ -5017,7 +5038,13 @@ class GPUModelRunner(

        if has_kv_transfer_group():
            kv_transfer_group = get_kv_transfer_group()
            kv_transfer_group.register_kv_caches(kv_caches)
            if self.cross_layers_kv_cache is not None:
                assert self.cross_layers_attn_backend is not None
                kv_transfer_group.register_cross_layers_kv_cache(
                    self.cross_layers_kv_cache, self.cross_layers_attn_backend
                )
            else:
                kv_transfer_group.register_kv_caches(kv_caches)
            kv_transfer_group.set_host_xfer_buffer_ops(copy_kv_blocks)

        if self.dcp_world_size > 1:
@@ -11,7 +11,11 @@ from typing import (
    TYPE_CHECKING,  # noqa: UP035
)

import torch

from vllm.attention import AttentionBackend
from vllm.config import VllmConfig
from vllm.config.cache import CacheDType
from vllm.distributed.kv_transfer import (
    ensure_kv_transfer_shutdown,
    get_kv_transfer_group,
@@ -21,11 +25,13 @@ from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
from vllm.forward_context import get_forward_context, set_forward_context
from vllm.logger import init_logger
from vllm.v1.kv_cache_interface import AttentionSpec, KVCacheConfig
from vllm.v1.outputs import (
    EMPTY_MODEL_RUNNER_OUTPUT,
    KVConnectorOutput,
    ModelRunnerOutput,
)
from vllm.v1.worker.utils import AttentionGroup

if TYPE_CHECKING:
    from vllm.v1.core.sched.output import SchedulerOutput
@@ -142,3 +148,162 @@ class KVConnectorModelRunnerMixin:
        if has_kv_transfer_group():
            return get_kv_transfer_group().get_kv_connector_stats()
        return None

    @staticmethod
    def use_uniform_kv_cache(
        attn_groups: list[list[AttentionGroup]],
        cache_dtype: CacheDType,
    ) -> bool:
        """
        Determines whether a uniform KV layout should be used.
        A uniform layout means all layers' KV caches will share the same
        underlying tensor, where for a given block number, the respective
        KV data for all layers will be contiguous.
        This allows efficient KV transfer of per-block KV data for all
        layers at once.
        Note this layout will only be applied given 3 conditions:
        1. The KV cache config contains just a single group where all layers
           have the same page size.
        2. A KV connector is configured, and the KV connector instance prefers
           to use this layout (prefer_cross_layer_blocks is True).
        3. The attention backend supports this layout
           (get_kv_cache_stride_order(True) includes a placement for a
           num_layers dimension).

        Note that the actual placement of the num_layers dimension
        in the unified layers tensor will be determined by the attention
        backend.
        Thus, the layers' KV data may still not be contiguous per block
        if the attention backend does not support it.

        Args:
            attn_groups: The list of attention groups for this model
            cache_dtype: The KV cache dtype
        Returns:
            True if we should use a uniform KV cache layout.
        """

        if not has_kv_transfer_group():
            return False
        if not get_kv_transfer_group().prefer_cross_layer_blocks:
            return False

        if len(attn_groups) != 1 or len(attn_groups[0]) != 1:
            return False

        attn_group = attn_groups[0][0]
        kv_cache_spec = attn_group.kv_cache_spec
        if not isinstance(kv_cache_spec, AttentionSpec):
            return False

        attn_backend = attn_group.backend
        kv_cache_shape = attn_backend.get_kv_cache_shape(
            1234,
            kv_cache_spec.block_size,
            kv_cache_spec.num_kv_heads,
            kv_cache_spec.head_size,
            cache_dtype_str=cache_dtype,
        )

        try:
            kv_cache_stride_order = attn_backend.get_kv_cache_stride_order(
                include_num_layers_dimension=True
            )
        except (AttributeError, NotImplementedError):
            return False

        # check that the attention backend includes a layers dimension
        return len(kv_cache_stride_order) == len(kv_cache_shape) + 1

    @staticmethod
    def allocate_uniform_kv_caches(
        kv_cache_config: KVCacheConfig,
        attn_groups: list[list[AttentionGroup]],
        cache_dtype: CacheDType,
        device: torch.device,
        kernel_block_sizes: list[int],
    ) -> tuple[dict[str, torch.Tensor], torch.Tensor, type[AttentionBackend]]:
        """
        Initializes and reshapes KV caches for the simple case where all
        layers have the same layout.

        This function assumes use_uniform_kv_cache() returned True.

        Args:
            kv_cache_config: The KV cache config
            attn_groups: The list of attention groups for this model
            cache_dtype: The KV cache dtype
            device: The torch device to allocate on.
            kernel_block_sizes: The kernel block sizes for each KV cache group.
        Returns:
            A tuple (kv_caches, cross_layers_kv_cache, attn_backend) where:
            kv_caches is a dict mapping between layer names to their
            corresponding memory buffer for KV cache.
            cross_layers_kv_cache is the cross-layer KV cache tensor.
            attn_backend is the attention backend matching this tensor.
        """
        attn_group = attn_groups[0][0]
        kv_cache_spec = attn_group.kv_cache_spec
        assert isinstance(kv_cache_spec, AttentionSpec)

        tensor_sizes = set(
            kv_cache_tensor.size for kv_cache_tensor in kv_cache_config.kv_cache_tensors
        )
        assert len(tensor_sizes) == 1
        tensor_size = tensor_sizes.pop()

        page_size = kv_cache_spec.page_size_bytes
        assert tensor_size % page_size == 0
        num_blocks = tensor_size // page_size
        num_layers = len(kv_cache_config.kv_cache_tensors)
        total_size = tensor_size * num_layers

        assert len(kernel_block_sizes) == 1
        kernel_block_size = kernel_block_sizes[0]
        num_blocks_per_kv_block = kv_cache_spec.block_size // kernel_block_size
        kernel_num_blocks = num_blocks * num_blocks_per_kv_block

        attn_backend = attn_group.backend
        kv_cache_shape = attn_backend.get_kv_cache_shape(
            kernel_num_blocks,
            kernel_block_size,
            kv_cache_spec.num_kv_heads,
            kv_cache_spec.head_size,
            cache_dtype_str=cache_dtype,
        )

        # prepend a num_layers dimension to the shape
        kv_cache_shape = (num_layers,) + kv_cache_shape

        try:
            kv_cache_stride_order = attn_backend.get_kv_cache_stride_order(
                include_num_layers_dimension=True
            )
            assert len(kv_cache_stride_order) == len(kv_cache_shape)
        except (AttributeError, NotImplementedError):
            kv_cache_stride_order = tuple(range(len(kv_cache_shape)))

        kv_cache_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order)

        logger.info("Allocating a cross layer KV cache of shape %s", kv_cache_shape)

        # allocate one contiguous buffer for all layers
        cross_layers_kv_cache = (
            torch.zeros(total_size, dtype=torch.int8, device=device)
            .view(kv_cache_spec.dtype)
            .view(kv_cache_shape)
        )

        # Maintain original KV shape view.
        inv_order = [
            kv_cache_stride_order.index(i) for i in range(len(kv_cache_stride_order))
        ]
        permuted_kv_cache = cross_layers_kv_cache.permute(*inv_order)

        kv_caches = {}
        for i, kv_cache_tensor in enumerate(kv_cache_config.kv_cache_tensors):
            tensor = permuted_kv_cache[i]
            for layer_name in kv_cache_tensor.shared_by:
                kv_caches[layer_name] = tensor

        return kv_caches, cross_layers_kv_cache, attn_backend
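To see what the cross-layer layout buys the connector, here is a small illustration (not part of the patch; sizes are made up, and the FlashAttention NHD physical order with the layers dimension from earlier in this diff is assumed):

# Illustration only: one slice along num_blocks covers every layer's KV data.
import torch

num_layers, num_blocks, block_size, num_kv_heads, head_size = 4, 8, 16, 2, 64
# Assumed physical layout (FlashAttention NHD + include_num_layers_dimension):
# (num_blocks, num_layers, 2, block_size, num_kv_heads, head_size)
cross_layer = torch.zeros(
    num_blocks, num_layers, 2, block_size, num_kv_heads, head_size
)

# The KV data of block 3 for all layers is a single contiguous slice ...
block_view = cross_layer[3]
assert block_view.is_contiguous()

# ... while the per-layer logical view is recovered by permuting back and
# indexing the layers dimension, mirroring allocate_uniform_kv_caches().
logical = cross_layer.permute(1, 2, 0, 3, 4, 5)  # (num_layers, 2, num_blocks, ...)
layer0 = logical[0]
assert layer0.shape == (2, num_blocks, block_size, num_kv_heads, head_size)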