Allocate kv_cache with stride order (#16605)

Signed-off-by: shuw <shuw@nvidia.com>

parent b278911229
commit 9e96f56efb
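The change below allocates each KV-cache tensor in a backend-chosen physical layout while keeping its logical shape unchanged. A minimal PyTorch sketch of the idea (sizes and names here are illustrative, not part of the patch): allocate in the permuted shape, then permute back, so callers still see the generic (num_blocks, 2, block_size, num_heads, head_size) view while the memory layout follows the backend's stride order.

import torch

# Generic (logical) KV-cache shape used throughout the diff below.
num_blocks, block_size, num_heads, head_size = 4, 16, 8, 64
generic_shape = (num_blocks, 2, block_size, num_heads, head_size)

# Backend-defined stride order: identity for NHD, heads ahead of tokens for HND.
stride_order = (0, 1, 3, 2, 4)  # HND

# Allocate in the permuted (physical) shape, then permute back. This particular
# order only swaps two axes, so permuting by it again restores the generic
# dimension order, which is exactly what the patch does with .permute(*stride_order).
allocation_shape = tuple(generic_shape[i] for i in stride_order)
kv_cache = torch.zeros(allocation_shape).permute(*stride_order)

assert kv_cache.shape == generic_shape  # logical shape is unchanged
assert not kv_cache.is_contiguous()     # physical layout is now HND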
@@ -270,9 +270,10 @@ __global__ void reshape_and_cache_flash_kernel(
     cache_t* __restrict__ value_cache,  // [num_blocks, block_size, num_heads,
                                         //  head_size]
     const int64_t* __restrict__ slot_mapping,  // [num_tokens]
-    const int block_stride, const int key_stride, const int value_stride,
-    const int num_heads, const int head_size, const int block_size,
-    const float* k_scale, const float* v_scale) {
+    const int64_t block_stride, const int64_t page_stride,
+    const int64_t head_stride, const int64_t key_stride,
+    const int64_t value_stride, const int num_heads, const int head_size,
+    const int block_size, const float* k_scale, const float* v_scale) {
   const int64_t token_idx = blockIdx.x;
   const int64_t slot_idx = slot_mapping[token_idx];
   // NOTE: slot_idx can be -1 if the token is padded
@@ -288,8 +289,8 @@ __global__ void reshape_and_cache_flash_kernel(
     const int head_idx = i / head_size;
     const int head_offset = i % head_size;
     const int64_t tgt_key_value_idx = block_idx * block_stride +
-                                      block_offset * num_heads * head_size +
-                                      head_idx * head_size + head_offset;
+                                      block_offset * page_stride +
+                                      head_idx * head_stride + head_offset;
     scalar_t tgt_key = key[src_key_idx];
     scalar_t tgt_value = value[src_value_idx];
     if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
@@ -396,16 +397,16 @@ void reshape_and_cache(
 // KV_T is the data type of key and value tensors.
 // CACHE_T is the stored data type of kv-cache.
 // KV_DTYPE is the real data type of kv-cache.
-#define CALL_RESHAPE_AND_CACHE_FLASH(KV_T, CACHE_T, KV_DTYPE)          \
-  vllm::reshape_and_cache_flash_kernel<KV_T, CACHE_T, KV_DTYPE>        \
-      <<<grid, block, 0, stream>>>(                                    \
-          reinterpret_cast<KV_T*>(key.data_ptr()),                     \
-          reinterpret_cast<KV_T*>(value.data_ptr()),                   \
-          reinterpret_cast<CACHE_T*>(key_cache.data_ptr()),            \
-          reinterpret_cast<CACHE_T*>(value_cache.data_ptr()),          \
-          slot_mapping.data_ptr<int64_t>(), block_stride, key_stride,  \
-          value_stride, num_heads, head_size, block_size,              \
-          reinterpret_cast<const float*>(k_scale.data_ptr()),          \
+#define CALL_RESHAPE_AND_CACHE_FLASH(KV_T, CACHE_T, KV_DTYPE)              \
+  vllm::reshape_and_cache_flash_kernel<KV_T, CACHE_T, KV_DTYPE>            \
+      <<<grid, block, 0, stream>>>(                                        \
+          reinterpret_cast<KV_T*>(key.data_ptr()),                         \
+          reinterpret_cast<KV_T*>(value.data_ptr()),                       \
+          reinterpret_cast<CACHE_T*>(key_cache.data_ptr()),                \
+          reinterpret_cast<CACHE_T*>(value_cache.data_ptr()),              \
+          slot_mapping.data_ptr<int64_t>(), block_stride, page_stride,     \
+          head_stride, key_stride, value_stride, num_heads, head_size,     \
+          block_size, reinterpret_cast<const float*>(k_scale.data_ptr()),  \
           reinterpret_cast<const float*>(v_scale.data_ptr()));
 
 void reshape_and_cache_flash(
@@ -432,9 +433,11 @@ void reshape_and_cache_flash(
   int head_size = key.size(2);
   int block_size = key_cache.size(1);
 
-  int key_stride = key.stride(0);
-  int value_stride = value.stride(0);
-  int block_stride = key_cache.stride(0);
+  int64_t key_stride = key.stride(0);
+  int64_t value_stride = value.stride(0);
+  int64_t block_stride = key_cache.stride(0);
+  int64_t page_stride = key_cache.stride(1);
+  int64_t head_stride = key_cache.stride(2);
   TORCH_CHECK(key_cache.stride(0) == value_cache.stride(0));
 
   dim3 grid(num_tokens);
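The host-side change above reads page_stride and head_stride directly from key_cache instead of assuming a contiguous num_heads * head_size layout, which is what lets the same kernel indexing serve both NHD and HND caches. A small sketch of that invariant (sizes and values are illustrative, not part of the patch):

import math

import torch

# The kernel's target index
#     block_idx * block_stride + block_offset * page_stride
#         + head_idx * head_stride + head_offset
# lands on the same element for either layout, because the strides are read
# off the (possibly permuted) key_cache tensor rather than recomputed.
num_blocks, block_size, num_heads, head_size = 4, 16, 8, 64
logical_shape = (num_blocks, block_size, num_heads, head_size)

for order in ((0, 1, 2, 3), (0, 2, 1, 3)):  # NHD, HND
    alloc_shape = tuple(logical_shape[i] for i in order)
    # arange fills the backing buffer with 0..N-1 in memory order, so every
    # element's value equals its linear offset into that buffer.
    key_cache = torch.arange(math.prod(alloc_shape)).reshape(
        alloc_shape).permute(*order)

    block_stride = key_cache.stride(0)
    page_stride = key_cache.stride(1)
    head_stride = key_cache.stride(2)

    b, p, h, d = 2, 5, 3, 7
    flat_idx = b * block_stride + p * page_stride + h * head_stride + d
    assert key_cache[b, p, h, d].item() == flat_idx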
@@ -16,6 +16,7 @@ NUM_LAYERS = [1]  # Arbitrary values for testing
 NUM_HEADS = [8]  # Arbitrary values for testing
 HEAD_SIZES = [64, 80, 120, 256]
 BLOCK_SIZES = [8, 16, 32]
+CACHE_LAYOUTS = ["NHD", "HND"]
 
 # Parameters for MLA tests.
 KV_LORA_RANKS = [512]
@@ -220,6 +221,7 @@ def test_reshape_and_cache(
 @pytest.mark.parametrize("seed", SEEDS)
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 @pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
+@pytest.mark.parametrize("kv_cache_layout", CACHE_LAYOUTS)
 @torch.inference_mode()
 def test_reshape_and_cache_flash(
     kv_cache_factory_flashinfer,
@@ -232,17 +234,21 @@ def test_reshape_and_cache_flash(
     seed: int,
     device: str,
     kv_cache_dtype: str,
+    kv_cache_layout: str,
 ) -> None:
     current_platform.seed_everything(seed)
     torch.set_default_device(device)
-
+    # fp8 conversion requires a contiguous memory buffer. Reduce the number of
+    # blocks and tokens to consume less memory.
+    num_tokens = num_tokens // 2
+    num_blocks = num_blocks // 2
     # Create a random slot mapping.
     num_slots = block_size * num_blocks
     slot_mapping_lst = random.sample(range(num_slots), num_tokens)
     slot_mapping = torch.tensor(slot_mapping_lst,
                                 dtype=torch.long,
                                 device=device)
 
     qkv = torch.randn(num_tokens,
                       3,
                       num_heads,
@@ -261,27 +267,35 @@ def test_reshape_and_cache_flash(
         kv_cache_dtype,
         dtype,
         device=device,
+        cache_layout=kv_cache_layout,
     )
-    key_cache, value_cache = key_caches[0].contiguous(
-    ), value_caches[0].contiguous()
+    key_cache, value_cache = key_caches[0], value_caches[0]
     del key_caches
     del value_caches
 
     k_scale = (key.amax() / 64.0).to(torch.float32)
     v_scale = (value.amax() / 64.0).to(torch.float32)
 
+    def permute_and_compact(x):
+        y = x if kv_cache_layout == "NHD" else x.permute(0, 2, 1, 3)
+        return y.contiguous()
+
+    key_cache_compact = permute_and_compact(key_cache)
+    value_cache_compact = permute_and_compact(value_cache)
+
     # Clone the KV caches.
     if kv_cache_dtype == "fp8":
-        cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
-        ops.convert_fp8(cloned_key_cache, key_cache, k_scale.item(),
-                        kv_cache_dtype)
-        cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
-        ops.convert_fp8(cloned_value_cache, value_cache, v_scale.item(),
-                        kv_cache_dtype)
+        cloned_key_cache = torch.empty_like(key_cache_compact,
+                                            dtype=torch.float16)
+        ops.convert_fp8(cloned_key_cache, key_cache_compact, k_scale.item(),
+                        kv_cache_dtype)
+        cloned_value_cache = torch.empty_like(value_cache_compact,
+                                              dtype=torch.float16)
+        ops.convert_fp8(cloned_value_cache, value_cache_compact,
+                        v_scale.item(), kv_cache_dtype)
     else:
-        cloned_key_cache = key_cache.clone()
-        cloned_value_cache = value_cache.clone()
-
+        cloned_key_cache = key_cache_compact.clone()
+        cloned_value_cache = value_cache_compact.clone()
     # Call the reshape_and_cache kernel.
     opcheck(torch.ops._C_cache_ops.reshape_and_cache_flash,
             (key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype,
@@ -289,16 +303,20 @@ def test_reshape_and_cache_flash(
             cond=(head_size == HEAD_SIZES[0]))
     ops.reshape_and_cache_flash(key, value, key_cache, value_cache,
                                 slot_mapping, kv_cache_dtype, k_scale, v_scale)
+    key_cache_compact = permute_and_compact(key_cache)
+    value_cache_compact = permute_and_compact(value_cache)
 
     if kv_cache_dtype == "fp8":
-        result_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
+        result_key_cache = torch.empty_like(key_cache_compact,
+                                            dtype=torch.float16)
         ops.convert_fp8(result_key_cache,
-                        key_cache,
+                        key_cache_compact,
                         k_scale.item(),
                         kv_dtype=kv_cache_dtype)
-        result_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
+        result_value_cache = torch.empty_like(value_cache_compact,
+                                              dtype=torch.float16)
         ops.convert_fp8(result_value_cache,
-                        value_cache,
+                        value_cache_compact,
                         v_scale.item(),
                         kv_dtype=kv_cache_dtype)
 
@@ -310,8 +328,12 @@ def test_reshape_and_cache_flash(
     for i in range(num_tokens):
         block_idx = block_indicies_lst[i]
         block_offset = block_offsets_lst[i]
-        cloned_key_cache[block_idx, block_offset, :, :] = key[i]
-        cloned_value_cache[block_idx, block_offset, :, :] = value[i]
+        if kv_cache_layout == "NHD":
+            cloned_key_cache[block_idx, block_offset, :, :] = key[i]
+            cloned_value_cache[block_idx, block_offset, :, :] = value[i]
+        else:
+            cloned_key_cache[block_idx, :, block_offset, :] = key[i]
+            cloned_value_cache[block_idx, :, block_offset, :] = value[i]
 
     if kv_cache_dtype == "fp8":
         torch.testing.assert_close(result_key_cache,
@@ -323,8 +345,8 @@ def test_reshape_and_cache_flash(
                                    atol=0.001,
                                    rtol=0.1)
     else:
-        torch.testing.assert_close(key_cache, cloned_key_cache)
-        torch.testing.assert_close(value_cache, cloned_value_cache)
+        torch.testing.assert_close(key_cache_compact, cloned_key_cache)
+        torch.testing.assert_close(value_cache_compact, cloned_value_cache)
 
 
 @pytest.mark.parametrize("direction", COPYING_DIRECTION)
@@ -77,6 +77,10 @@ class AttentionBackend(ABC):
     ) -> Tuple[int, ...]:
         raise NotImplementedError
 
+    @staticmethod
+    def get_kv_cache_stride_order() -> Tuple[int, ...]:
+        raise NotImplementedError
+
     @staticmethod
     @abstractmethod
     def swap_blocks(
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import dataclasses
+import os
 from collections import defaultdict
 from contextlib import contextmanager
 from dataclasses import dataclass
@@ -48,6 +49,9 @@ if TYPE_CHECKING:
     from vllm.worker.model_runner import (ModelInputForGPUBuilder,
                                           ModelInputForGPUWithSamplingMetadata)
 
+FLASHINFER_KV_CACHE_LAYOUT: str = os.getenv("FLASHINFER_KV_CACHE_LAYOUT",
+                                            "NHD").upper()
+
 
 class FlashInferBackend(AttentionBackend):
 
@@ -80,6 +84,14 @@ class FlashInferBackend(AttentionBackend):
     ) -> Tuple[int, ...]:
         return (num_blocks, 2, block_size, num_kv_heads, head_size)
 
+    @staticmethod
+    def get_kv_cache_stride_order() -> Tuple[int, ...]:
+        cache_layout = FLASHINFER_KV_CACHE_LAYOUT
+        assert (cache_layout in ("NHD", "HND"))
+        stride_order = (0, 1, 2, 3, 4) if cache_layout == "NHD" else (0, 1, 3,
+                                                                      2, 4)
+        return stride_order
+
     @staticmethod
     def swap_blocks(
         src_kv_cache: torch.Tensor,
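A quick illustration of what the stride order above does to the 5-D FlashInfer cache shape (sizes are illustrative). With "HND" the heads dimension is allocated ahead of the in-block token dimension, matching the layout the FlashInfer wrappers are constructed with and the kv_cache.permute(*stride_order) calls later in this diff:

num_blocks, block_size, num_kv_heads, head_size = 4, 16, 8, 64
generic_shape = (num_blocks, 2, block_size, num_kv_heads, head_size)

stride_orders = {"NHD": (0, 1, 2, 3, 4), "HND": (0, 1, 3, 2, 4)}

for layout, order in stride_orders.items():
    allocation_shape = tuple(generic_shape[i] for i in order)
    print(layout, allocation_shape)
# NHD (4, 2, 16, 8, 64)  -> blocks, K/V, tokens, heads, head_size
# HND (4, 2, 8, 16, 64)  -> blocks, K/V, heads, tokens, head_size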
@@ -188,6 +200,7 @@ class FlashInferState(AttentionState):
         self.global_hyperparameters: Optional[PerLayerParameters] = None
 
         self.vllm_config = self.runner.vllm_config
+        self._kv_cache_layout = None
 
     def _get_workspace_buffer(self):
         if self._workspace_buffer is None:
@@ -197,10 +210,15 @@
                 device=self.runner.device)
         return self._workspace_buffer
 
+    def get_kv_cache_layout(self):
+        if self._kv_cache_layout is None:
+            self._kv_cache_layout = FLASHINFER_KV_CACHE_LAYOUT
+        return self._kv_cache_layout
+
     def _get_prefill_wrapper(self):
         if self._prefill_wrapper is None:
             self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
-                self._get_workspace_buffer(), "NHD")
+                self._get_workspace_buffer(), self.get_kv_cache_layout())
         return self._prefill_wrapper
 
     def _get_decode_wrapper(self):
@@ -213,7 +231,7 @@
                                 num_qo_heads // num_kv_heads > 4)
             self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
                 self._get_workspace_buffer(),
-                "NHD",
+                self.get_kv_cache_layout(),
                 use_tensor_cores=use_tensor_cores)
         return self._decode_wrapper
 
@@ -274,7 +292,8 @@
             self._graph_decode_wrapper = \
                 CUDAGraphBatchDecodeWithPagedKVCacheWrapper(
                 self._graph_decode_workspace_buffer, _indptr_buffer,
-                self._graph_indices_buffer, _last_page_len_buffer, "NHD",
+                self._graph_indices_buffer, _last_page_len_buffer,
+                self.get_kv_cache_layout(),
                 use_tensor_cores)
         if self.runner.kv_cache_dtype.startswith("fp8"):
             kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
@@ -1005,6 +1024,7 @@ class FlashInferImpl(AttentionImpl):
 
         prefill_output: Optional[torch.Tensor] = None
         decode_output: Optional[torch.Tensor] = None
+        stride_order = FlashInferBackend.get_kv_cache_stride_order()
         if prefill_meta := attn_metadata.prefill_metadata:
             # We will use flash attention for prefill
             # when kv_cache is not provided.
@@ -1036,7 +1056,7 @@
 
                 prefill_output = prefill_meta.prefill_wrapper.run(
                     query,
-                    kv_cache,
+                    kv_cache.permute(*stride_order),
                     k_scale=layer._k_scale_float,
                     v_scale=layer._v_scale_float,
                 )
@@ -1051,7 +1071,7 @@
 
             decode_output = decode_meta.decode_wrapper.run(
                 decode_query,
-                kv_cache,
+                kv_cache.permute(*stride_order),
                 k_scale=layer._k_scale_float,
                 v_scale=layer._v_scale_float,
             )
@@ -765,21 +765,28 @@ def create_kv_caches_with_random_flash(
     model_dtype: Optional[Union[str, torch.dtype]] = None,
     seed: Optional[int] = None,
     device: Optional[str] = "cuda",
+    cache_layout: Optional[str] = "NHD",
 ) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
     from vllm.platforms import current_platform
     current_platform.seed_everything(seed)
 
     torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)
-    key_value_cache_shape = (num_blocks, 2, block_size, num_heads, head_size)
+    generic_kv_cache_shape = (num_blocks, 2, block_size, num_heads, head_size)
+    assert cache_layout in ("NHD", "HND")
+    stride_order = (0, 1, 2, 3, 4) if cache_layout == "NHD" else (0, 1, 3, 2,
+                                                                  4)
+
+    kv_cache_allocation_shape = tuple(generic_kv_cache_shape[i]
+                                      for i in stride_order)
     scale = head_size**-0.5
 
     key_caches: list[torch.Tensor] = []
     value_caches: list[torch.Tensor] = []
 
     for _ in range(num_layers):
-        key_value_cache = torch.empty(size=key_value_cache_shape,
+        key_value_cache = torch.empty(size=kv_cache_allocation_shape,
                                       dtype=torch_dtype,
-                                      device=device)
+                                      device=device).permute(*stride_order)
         if cache_dtype in ["auto", "half", "bfloat16", "float"]:
             key_value_cache.uniform_(-scale, scale)
         elif cache_dtype == 'fp8':
@@ -71,19 +71,32 @@ class CacheEngine:
         device: str,
     ) -> List[torch.Tensor]:
         """Allocates KV cache on the specified device."""
-        kv_cache_shape = self.attn_backend.get_kv_cache_shape(
+        kv_cache_generic_shape = self.attn_backend.get_kv_cache_shape(
             num_blocks, self.block_size, self.num_kv_heads, self.head_size)
         pin_memory = is_pin_memory_available() if device == "cpu" else False
         kv_cache: List[torch.Tensor] = []
+        try:
+            kv_cache_stride_order = self.attn_backend.get_kv_cache_stride_order(
+            )
+        except (AttributeError, NotImplementedError):
+            kv_cache_stride_order = tuple(range(len(kv_cache_generic_shape)))
+
+        # The allocation respects the backend-defined stride order to ensure
+        # the semantic remains consistent for each backend. We first obtain the
+        # generic kv cache shape and then permute it according to the stride
+        # order which could result in a non-contiguous tensor.
+        kv_cache_allocation_shape = tuple(kv_cache_generic_shape[i]
+                                          for i in kv_cache_stride_order)
+
         for _ in range(self.num_attention_layers):
             # null block in CpuGpuBlockAllocator requires at least that
             # block to be zeroed-out.
             # We zero-out everything for simplicity.
-            layer_kv_cache = torch.zeros(kv_cache_shape,
-                                         dtype=self.dtype,
-                                         pin_memory=pin_memory,
-                                         device=device)
+            layer_kv_cache = torch.zeros(
+                kv_cache_allocation_shape,
+                dtype=self.dtype,
+                pin_memory=pin_memory,
+                device=device).permute(*kv_cache_stride_order)
 
             # view back to (TOTAL_PAGES, PAGE_SIZE, entry_shape...) for cases
             # when entry_shape is higher than 1D
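Condensing the CacheEngine hunk above into a standalone sketch (the backend class and helper function below are illustrative stand-ins, not vLLM APIs): backends that do not define get_kv_cache_stride_order fall back to the identity order, so their allocation stays contiguous and behaves exactly as before.

import torch


class _NoStrideOrderBackend:
    # Stand-in for an attention backend that only defines the generic shape.
    @staticmethod
    def get_kv_cache_shape(num_blocks, block_size, num_kv_heads, head_size):
        return (num_blocks, block_size, num_kv_heads, head_size)


def allocate(backend, num_blocks=4, block_size=16, heads=8, head_size=64):
    generic_shape = backend.get_kv_cache_shape(num_blocks, block_size, heads,
                                               head_size)
    try:
        order = backend.get_kv_cache_stride_order()
    except (AttributeError, NotImplementedError):
        order = tuple(range(len(generic_shape)))  # identity: keep old layout
    allocation_shape = tuple(generic_shape[i] for i in order)
    return torch.zeros(allocation_shape).permute(*order)


cache = allocate(_NoStrideOrderBackend())
assert cache.is_contiguous()  # identity order: nothing actually changes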