Allocate kv_cache with stride order (#16605)

Signed-off-by: shuw <shuw@nvidia.com>
Author: Shu Wang, 2025-04-26 00:03:31 -05:00, committed by GitHub
parent b278911229
commit 9e96f56efb
6 changed files with 119 additions and 50 deletions


@@ -270,9 +270,10 @@ __global__ void reshape_and_cache_flash_kernel(
cache_t* __restrict__ value_cache, // [num_blocks, block_size, num_heads,
// head_size]
const int64_t* __restrict__ slot_mapping, // [num_tokens]
const int block_stride, const int key_stride, const int value_stride,
const int num_heads, const int head_size, const int block_size,
const float* k_scale, const float* v_scale) {
const int64_t block_stride, const int64_t page_stride,
const int64_t head_stride, const int64_t key_stride,
const int64_t value_stride, const int num_heads, const int head_size,
const int block_size, const float* k_scale, const float* v_scale) {
const int64_t token_idx = blockIdx.x;
const int64_t slot_idx = slot_mapping[token_idx];
// NOTE: slot_idx can be -1 if the token is padded
@@ -288,8 +289,8 @@ __global__ void reshape_and_cache_flash_kernel(
const int head_idx = i / head_size;
const int head_offset = i % head_size;
const int64_t tgt_key_value_idx = block_idx * block_stride +
block_offset * num_heads * head_size +
head_idx * head_size + head_offset;
block_offset * page_stride +
head_idx * head_stride + head_offset;
scalar_t tgt_key = key[src_key_idx];
scalar_t tgt_value = value[src_value_idx];
if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
@@ -396,16 +397,16 @@ void reshape_and_cache(
// KV_T is the data type of key and value tensors.
// CACHE_T is the stored data type of kv-cache.
// KV_DTYPE is the real data type of kv-cache.
#define CALL_RESHAPE_AND_CACHE_FLASH(KV_T, CACHE_T, KV_DTYPE) \
vllm::reshape_and_cache_flash_kernel<KV_T, CACHE_T, KV_DTYPE> \
<<<grid, block, 0, stream>>>( \
reinterpret_cast<KV_T*>(key.data_ptr()), \
reinterpret_cast<KV_T*>(value.data_ptr()), \
reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
slot_mapping.data_ptr<int64_t>(), block_stride, key_stride, \
value_stride, num_heads, head_size, block_size, \
reinterpret_cast<const float*>(k_scale.data_ptr()), \
#define CALL_RESHAPE_AND_CACHE_FLASH(KV_T, CACHE_T, KV_DTYPE) \
vllm::reshape_and_cache_flash_kernel<KV_T, CACHE_T, KV_DTYPE> \
<<<grid, block, 0, stream>>>( \
reinterpret_cast<KV_T*>(key.data_ptr()), \
reinterpret_cast<KV_T*>(value.data_ptr()), \
reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
slot_mapping.data_ptr<int64_t>(), block_stride, page_stride, \
head_stride, key_stride, value_stride, num_heads, head_size, \
block_size, reinterpret_cast<const float*>(k_scale.data_ptr()), \
reinterpret_cast<const float*>(v_scale.data_ptr()));
void reshape_and_cache_flash(
@@ -432,9 +433,11 @@ void reshape_and_cache_flash(
int head_size = key.size(2);
int block_size = key_cache.size(1);
int key_stride = key.stride(0);
int value_stride = value.stride(0);
int block_stride = key_cache.stride(0);
int64_t key_stride = key.stride(0);
int64_t value_stride = value.stride(0);
int64_t block_stride = key_cache.stride(0);
int64_t page_stride = key_cache.stride(1);
int64_t head_stride = key_cache.stride(2);
TORCH_CHECK(key_cache.stride(0) == value_cache.stride(0));
dim3 grid(num_tokens);
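
Note on the kernel change above: the target index is now computed from strides read off the cache tensor (block_stride, page_stride, head_stride) instead of from num_heads * head_size, which is what lets one kernel serve both the NHD and HND layouts. Below is a minimal PyTorch sketch, not part of the diff and using illustrative shapes, showing why the arithmetic holds for both layouts.

import torch

num_blocks, block_size, num_heads, head_size = 4, 8, 2, 16
numel = num_blocks * block_size * num_heads * head_size

for layout in ("NHD", "HND"):
    alloc_shape = ((num_blocks, block_size, num_heads, head_size)
                   if layout == "NHD" else
                   (num_blocks, num_heads, block_size, head_size))
    alloc = torch.arange(numel).reshape(alloc_shape)
    # Logical [num_blocks, block_size, num_heads, head_size] view in both cases.
    cache = alloc if layout == "NHD" else alloc.permute(0, 2, 1, 3)

    block_stride, page_stride, head_stride, _ = cache.stride()
    flat = alloc.reshape(-1)   # the raw buffer the kernel indexes into
    b, o, h, d = 3, 5, 1, 7    # arbitrary (block_idx, block_offset, head_idx, dim)
    idx = b * block_stride + o * page_stride + h * head_stride + d
    assert flat[idx] == cache[b, o, h, d], layout

For NHD the strides collapse to the old num_heads * head_size arithmetic; for HND the head stride is larger than the page stride, which the previous shape-based computation could not express.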


@@ -16,6 +16,7 @@ NUM_LAYERS = [1] # Arbitrary values for testing
NUM_HEADS = [8] # Arbitrary values for testing
HEAD_SIZES = [64, 80, 120, 256]
BLOCK_SIZES = [8, 16, 32]
CACHE_LAYOUTS = ["NHD", "HND"]
# Parameters for MLA tests.
KV_LORA_RANKS = [512]
@@ -220,6 +221,7 @@ def test_reshape_and_cache(
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@pytest.mark.parametrize("kv_cache_layout", CACHE_LAYOUTS)
@torch.inference_mode()
def test_reshape_and_cache_flash(
kv_cache_factory_flashinfer,
@@ -232,17 +234,21 @@ def test_reshape_and_cache_flash(
seed: int,
device: str,
kv_cache_dtype: str,
kv_cache_layout: str,
) -> None:
current_platform.seed_everything(seed)
torch.set_default_device(device)
# fp8 conversion requires a contiguous memory buffer. Reduce the number of
# blocks and tokens to consume less memory.
num_tokens = num_tokens // 2
num_blocks = num_blocks // 2
# Create a random slot mapping.
num_slots = block_size * num_blocks
slot_mapping_lst = random.sample(range(num_slots), num_tokens)
slot_mapping = torch.tensor(slot_mapping_lst,
dtype=torch.long,
device=device)
qkv = torch.randn(num_tokens,
3,
num_heads,
@@ -261,27 +267,35 @@ def test_reshape_and_cache_flash(
kv_cache_dtype,
dtype,
device=device,
cache_layout=kv_cache_layout,
)
key_cache, value_cache = key_caches[0].contiguous(
), value_caches[0].contiguous()
key_cache, value_cache = key_caches[0], value_caches[0]
del key_caches
del value_caches
k_scale = (key.amax() / 64.0).to(torch.float32)
v_scale = (value.amax() / 64.0).to(torch.float32)
def permute_and_compact(x):
y = x if kv_cache_layout == "NHD" else x.permute(0, 2, 1, 3)
return y.contiguous()
key_cache_compact = permute_and_compact(key_cache)
value_cache_compact = permute_and_compact(value_cache)
# Clone the KV caches.
if kv_cache_dtype == "fp8":
cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
ops.convert_fp8(cloned_key_cache, key_cache, k_scale.item(),
kv_cache_dtype)
cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
ops.convert_fp8(cloned_value_cache, value_cache, v_scale.item(),
cloned_key_cache = torch.empty_like(key_cache_compact,
dtype=torch.float16)
ops.convert_fp8(cloned_key_cache, key_cache_compact, k_scale.item(),
kv_cache_dtype)
cloned_value_cache = torch.empty_like(value_cache_compact,
dtype=torch.float16)
ops.convert_fp8(cloned_value_cache, value_cache_compact,
v_scale.item(), kv_cache_dtype)
else:
cloned_key_cache = key_cache.clone()
cloned_value_cache = value_cache.clone()
cloned_key_cache = key_cache_compact.clone()
cloned_value_cache = value_cache_compact.clone()
# Call the reshape_and_cache kernel.
opcheck(torch.ops._C_cache_ops.reshape_and_cache_flash,
(key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype,
@@ -289,16 +303,20 @@ def test_reshape_and_cache_flash(
cond=(head_size == HEAD_SIZES[0]))
ops.reshape_and_cache_flash(key, value, key_cache, value_cache,
slot_mapping, kv_cache_dtype, k_scale, v_scale)
key_cache_compact = permute_and_compact(key_cache)
value_cache_compact = permute_and_compact(value_cache)
if kv_cache_dtype == "fp8":
result_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
result_key_cache = torch.empty_like(key_cache_compact,
dtype=torch.float16)
ops.convert_fp8(result_key_cache,
key_cache,
key_cache_compact,
k_scale.item(),
kv_dtype=kv_cache_dtype)
result_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
result_value_cache = torch.empty_like(value_cache_compact,
dtype=torch.float16)
ops.convert_fp8(result_value_cache,
value_cache,
value_cache_compact,
v_scale.item(),
kv_dtype=kv_cache_dtype)
@@ -310,8 +328,12 @@ def test_reshape_and_cache_flash(
for i in range(num_tokens):
block_idx = block_indicies_lst[i]
block_offset = block_offsets_lst[i]
cloned_key_cache[block_idx, block_offset, :, :] = key[i]
cloned_value_cache[block_idx, block_offset, :, :] = value[i]
if kv_cache_layout == "NHD":
cloned_key_cache[block_idx, block_offset, :, :] = key[i]
cloned_value_cache[block_idx, block_offset, :, :] = value[i]
else:
cloned_key_cache[block_idx, :, block_offset, :] = key[i]
cloned_value_cache[block_idx, :, block_offset, :] = value[i]
if kv_cache_dtype == "fp8":
torch.testing.assert_close(result_key_cache,
@@ -323,8 +345,8 @@ def test_reshape_and_cache_flash(
atol=0.001,
rtol=0.1)
else:
torch.testing.assert_close(key_cache, cloned_key_cache)
torch.testing.assert_close(value_cache, cloned_value_cache)
torch.testing.assert_close(key_cache_compact, cloned_key_cache)
torch.testing.assert_close(value_cache_compact, cloned_value_cache)
@pytest.mark.parametrize("direction", COPYING_DIRECTION)
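
Note on the test changes above: the caches are now compared through permute_and_compact, which brings an HND cache back to NHD before checking, and the reference write uses the indexing order that matches the layout. The standalone sketch below is illustrative only (tensor names are local to the sketch) and shows that the two reference writes land on the same logical entries.

import torch

num_blocks, block_size, num_heads, head_size = 2, 4, 3, 8
key = torch.randn(num_heads, head_size)     # one token's key

nhd = torch.zeros(num_blocks, block_size, num_heads, head_size)
hnd = torch.zeros(num_blocks, num_heads, block_size, head_size)

block_idx, block_offset = 1, 2
nhd[block_idx, block_offset, :, :] = key    # NHD reference write, as in the test
hnd[block_idx, :, block_offset, :] = key    # HND reference write, as in the test

def permute_and_compact(x, layout):
    # Mirrors the helper added in the test above.
    y = x if layout == "NHD" else x.permute(0, 2, 1, 3)
    return y.contiguous()

assert torch.equal(permute_and_compact(nhd, "NHD"),
                   permute_and_compact(hnd, "HND"))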


@@ -77,6 +77,10 @@ class AttentionBackend(ABC):
) -> Tuple[int, ...]:
raise NotImplementedError
@staticmethod
def get_kv_cache_stride_order() -> Tuple[int, ...]:
raise NotImplementedError
@staticmethod
@abstractmethod
def swap_blocks(
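
Note: get_kv_cache_stride_order is added as a non-abstract static method that raises NotImplementedError, so existing backends keep working without any change. As an illustration only (not code from this diff, with the other backend methods omitted), a backend that wants a heads-first cache would just add the hook next to its existing get_kv_cache_shape:

from typing import Tuple

class HeadsFirstBackend:   # hypothetical backend, for illustration
    @staticmethod
    def get_kv_cache_shape(num_blocks: int, block_size: int,
                           num_kv_heads: int, head_size: int) -> Tuple[int, ...]:
        # The logical (generic) shape is the same for every layout.
        return (num_blocks, 2, block_size, num_kv_heads, head_size)

    @staticmethod
    def get_kv_cache_stride_order() -> Tuple[int, ...]:
        # Allocate with the head dimension ahead of the token dimension (HND).
        return (0, 1, 3, 2, 4)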


@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
import dataclasses
import os
from collections import defaultdict
from contextlib import contextmanager
from dataclasses import dataclass
@@ -48,6 +49,9 @@ if TYPE_CHECKING:
from vllm.worker.model_runner import (ModelInputForGPUBuilder,
ModelInputForGPUWithSamplingMetadata)
FLASHINFER_KV_CACHE_LAYOUT: str = os.getenv("FLASHINFER_KV_CACHE_LAYOUT",
"NHD").upper()
class FlashInferBackend(AttentionBackend):
@@ -80,6 +84,14 @@ class FlashInferBackend(AttentionBackend):
) -> Tuple[int, ...]:
return (num_blocks, 2, block_size, num_kv_heads, head_size)
@staticmethod
def get_kv_cache_stride_order() -> Tuple[int, ...]:
cache_layout = FLASHINFER_KV_CACHE_LAYOUT
assert (cache_layout in ("NHD", "HND"))
stride_order = (0, 1, 2, 3, 4) if cache_layout == "NHD" else (0, 1, 3,
2, 4)
return stride_order
@staticmethod
def swap_blocks(
src_kv_cache: torch.Tensor,
@@ -188,6 +200,7 @@ class FlashInferState(AttentionState):
self.global_hyperparameters: Optional[PerLayerParameters] = None
self.vllm_config = self.runner.vllm_config
self._kv_cache_layout = None
def _get_workspace_buffer(self):
if self._workspace_buffer is None:
@@ -197,10 +210,15 @@
device=self.runner.device)
return self._workspace_buffer
def get_kv_cache_layout(self):
if self._kv_cache_layout is None:
self._kv_cache_layout = FLASHINFER_KV_CACHE_LAYOUT
return self._kv_cache_layout
def _get_prefill_wrapper(self):
if self._prefill_wrapper is None:
self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
self._get_workspace_buffer(), "NHD")
self._get_workspace_buffer(), self.get_kv_cache_layout())
return self._prefill_wrapper
def _get_decode_wrapper(self):
@@ -213,7 +231,7 @@
num_qo_heads // num_kv_heads > 4)
self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
self._get_workspace_buffer(),
"NHD",
self.get_kv_cache_layout(),
use_tensor_cores=use_tensor_cores)
return self._decode_wrapper
@@ -274,7 +292,8 @@ class FlashInferState(AttentionState):
self._graph_decode_wrapper = \
CUDAGraphBatchDecodeWithPagedKVCacheWrapper(
self._graph_decode_workspace_buffer, _indptr_buffer,
self._graph_indices_buffer, _last_page_len_buffer, "NHD",
self._graph_indices_buffer, _last_page_len_buffer,
self.get_kv_cache_layout(),
use_tensor_cores)
if self.runner.kv_cache_dtype.startswith("fp8"):
kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
@@ -1005,6 +1024,7 @@ class FlashInferImpl(AttentionImpl):
prefill_output: Optional[torch.Tensor] = None
decode_output: Optional[torch.Tensor] = None
stride_order = FlashInferBackend.get_kv_cache_stride_order()
if prefill_meta := attn_metadata.prefill_metadata:
# We will use flash attention for prefill
# when kv_cache is not provided.
@@ -1036,7 +1056,7 @@
prefill_output = prefill_meta.prefill_wrapper.run(
query,
kv_cache,
kv_cache.permute(*stride_order),
k_scale=layer._k_scale_float,
v_scale=layer._v_scale_float,
)
@@ -1051,7 +1071,7 @@
decode_output = decode_meta.decode_wrapper.run(
decode_query,
kv_cache,
kv_cache.permute(*stride_order),
k_scale=layer._k_scale_float,
v_scale=layer._v_scale_float,
)
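
Note on the FlashInfer changes above: the prefill/decode wrappers are now constructed with get_kv_cache_layout(), and the forward pass hands them kv_cache.permute(*stride_order). Because the cache tensor keeps the generic NHD-logical shape after allocation, the permute only changes the view: it is the identity for NHD and exposes the heads-first shape for HND. A small shape-only sketch, illustrative and not part of the diff (the zero tensor stands in for a real cache):

import os
import torch

layout = os.getenv("FLASHINFER_KV_CACHE_LAYOUT", "NHD").upper()
stride_order = (0, 1, 2, 3, 4) if layout == "NHD" else (0, 1, 3, 2, 4)

num_blocks, block_size, num_kv_heads, head_size = 8, 16, 4, 64
# The cache always carries the generic shape from get_kv_cache_shape().
kv_cache = torch.zeros(num_blocks, 2, block_size, num_kv_heads, head_size)

wrapper_view = kv_cache.permute(*stride_order)   # what wrapper.run() receives
expected = ((num_blocks, 2, block_size, num_kv_heads, head_size)
            if layout == "NHD" else
            (num_blocks, 2, num_kv_heads, block_size, head_size))
assert tuple(wrapper_view.shape) == expected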


@@ -765,21 +765,28 @@ def create_kv_caches_with_random_flash(
model_dtype: Optional[Union[str, torch.dtype]] = None,
seed: Optional[int] = None,
device: Optional[str] = "cuda",
cache_layout: Optional[str] = "NHD",
) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
from vllm.platforms import current_platform
current_platform.seed_everything(seed)
torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)
key_value_cache_shape = (num_blocks, 2, block_size, num_heads, head_size)
generic_kv_cache_shape = (num_blocks, 2, block_size, num_heads, head_size)
assert cache_layout in ("NHD", "HND")
stride_order = (0, 1, 2, 3, 4) if cache_layout == "NHD" else (0, 1, 3, 2,
4)
kv_cache_allocation_shape = tuple(generic_kv_cache_shape[i]
for i in stride_order)
scale = head_size**-0.5
key_caches: list[torch.Tensor] = []
value_caches: list[torch.Tensor] = []
for _ in range(num_layers):
key_value_cache = torch.empty(size=key_value_cache_shape,
key_value_cache = torch.empty(size=kv_cache_allocation_shape,
dtype=torch_dtype,
device=device)
device=device).permute(*stride_order)
if cache_dtype in ["auto", "half", "bfloat16", "float"]:
key_value_cache.uniform_(-scale, scale)
elif cache_dtype == 'fp8':
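
Note: the allocate-then-permute pattern above (also used by CacheEngine below) yields a tensor whose logical shape is the generic one while its memory follows the requested layout. The snippet below uses illustrative sizes to make the resulting shape and strides visible; since (0, 1, 3, 2, 4) is its own inverse, permuting the allocation by the stride order restores the generic shape.

import torch

num_blocks, block_size, num_heads, head_size = 2, 4, 3, 8
generic_shape = (num_blocks, 2, block_size, num_heads, head_size)
stride_order = (0, 1, 3, 2, 4)                         # HND
allocation_shape = tuple(generic_shape[i] for i in stride_order)

cache = torch.empty(allocation_shape).permute(*stride_order)
print(tuple(cache.shape))     # (2, 2, 4, 3, 8)    -> generic shape restored
print(cache.stride())         # (192, 96, 8, 32, 1) -> heads-first memory layout
print(cache.is_contiguous())  # False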


@@ -71,19 +71,32 @@ class CacheEngine:
device: str,
) -> List[torch.Tensor]:
"""Allocates KV cache on the specified device."""
kv_cache_shape = self.attn_backend.get_kv_cache_shape(
kv_cache_generic_shape = self.attn_backend.get_kv_cache_shape(
num_blocks, self.block_size, self.num_kv_heads, self.head_size)
pin_memory = is_pin_memory_available() if device == "cpu" else False
kv_cache: List[torch.Tensor] = []
try:
kv_cache_stride_order = self.attn_backend.get_kv_cache_stride_order(
)
except (AttributeError, NotImplementedError):
kv_cache_stride_order = tuple(range(len(kv_cache_generic_shape)))
# The allocation respects the backend-defined stride order to ensure
# the semantics remain consistent for each backend. We first obtain the
# generic kv cache shape and then permute it according to the stride
# order, which could result in a non-contiguous tensor.
kv_cache_allocation_shape = tuple(kv_cache_generic_shape[i]
for i in kv_cache_stride_order)
for _ in range(self.num_attention_layers):
# null block in CpuGpuBlockAllocator requires at least that
# block to be zeroed-out.
# We zero-out everything for simplicity.
layer_kv_cache = torch.zeros(kv_cache_shape,
dtype=self.dtype,
pin_memory=pin_memory,
device=device)
layer_kv_cache = torch.zeros(
kv_cache_allocation_shape,
dtype=self.dtype,
pin_memory=pin_memory,
device=device).permute(*kv_cache_stride_order)
# view back to (TOTAL_PAGES, PAGE_SIZE, entry_shape...) for cases
# when entry_shape is higher than 1D
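
Note on the CacheEngine change above: backends that never override get_kv_cache_stride_order fall through the try/except to the identity order, i.e. a plain contiguous allocation, so their behavior is unchanged. A minimal sketch of that fallback, with a hypothetical LegacyBackend standing in for any backend that does not implement the hook:

kv_cache_generic_shape = (128, 2, 16, 8, 64)    # illustrative sizes

class LegacyBackend:                             # hypothetical backend
    @staticmethod
    def get_kv_cache_stride_order():
        raise NotImplementedError

try:
    kv_cache_stride_order = LegacyBackend.get_kv_cache_stride_order()
except (AttributeError, NotImplementedError):
    kv_cache_stride_order = tuple(range(len(kv_cache_generic_shape)))

assert kv_cache_stride_order == (0, 1, 2, 3, 4)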