From 75e94309e8d8919e0daea041f6cd81a4b8c09060 Mon Sep 17 00:00:00 2001
From: Lucas Wilkinson
Date: Tue, 4 Feb 2025 21:22:24 -0500
Subject: [PATCH] [Perf] Mem align KV caches for CUDA devices (MLA perf improvement) (#12676)

Signed-off-by: simon-mo
Signed-off-by: Lucas Wilkinson
Signed-off-by: Lucas Wilkinson
Signed-off-by: Lucas Wilkinson
Co-authored-by: simon-mo
---
 csrc/cache.h                                  |   3 +
 csrc/cache_kernels.cu                         |  82 +++++-
 csrc/torch_bindings.cpp                       |   4 +
 tests/kernels/test_cache.py                   | 262 ++++++++++++++++++
 vllm/_custom_ops.py                           |   5 +
 vllm/attention/backends/triton_mla.py         |   5 +-
 vllm/attention/ops/triton_decode_attention.py |  16 +-
 vllm/envs.py                                  |  10 +
 vllm/utils.py                                 |  10 +
 vllm/worker/cache_engine.py                   |  66 ++++-
 10 files changed, 429 insertions(+), 34 deletions(-)

diff --git a/csrc/cache.h b/csrc/cache.h
index 55ed30bd8ce4..cf4a65c29055 100644
--- a/csrc/cache.h
+++ b/csrc/cache.h
@@ -15,6 +15,9 @@ void copy_blocks(std::vector<torch::Tensor> const& key_caches,
                  std::vector<torch::Tensor> const& value_caches,
                  const torch::Tensor& block_mapping);
 
+void copy_blocks_mla(std::vector<torch::Tensor> const& kv_caches,
+                     const torch::Tensor& block_mapping);
+
 void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
                        torch::Tensor& key_cache, torch::Tensor& value_cache,
                        torch::Tensor& slot_mapping,
diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu
index 23a46b6ed8ad..0960888d1f75 100644
--- a/csrc/cache_kernels.cu
+++ b/csrc/cache_kernels.cu
@@ -46,7 +46,10 @@ void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
   char* src_ptr = static_cast<char*>(src.data_ptr());
   char* dst_ptr = static_cast<char*>(dst.data_ptr());
 
-  const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
+  // We use the stride instead of numel in case the cache is padded for memory
+  // alignment reasons; we assume the block data (inclusive of any padding)
+  // is contiguous in memory.
+  const int64_t block_size_in_bytes = src.element_size() * src.stride(0);
   const at::cuda::OptionalCUDAGuard device_guard(
       src_device.is_cuda() ? src_device : dst_device);
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
@@ -93,6 +96,24 @@ __global__ void copy_blocks_kernel(int64_t* key_cache_ptrs,
   }
 }
 
+// Kernel for MLA, which works on a single joint kv_cache
+// Grid: (num_layers, num_pairs)
+template <typename scalar_t>
+__global__ void copy_blocks_mla_kernel(
+    int64_t* cache_ptrs, const int64_t* __restrict__ block_mapping,
+    const int mem_footprint_per_block) {
+  const int layer_idx = blockIdx.x;
+  const int pair_idx = blockIdx.y;
+  scalar_t* cache = reinterpret_cast<scalar_t*>(cache_ptrs[layer_idx]);
+  int64_t src_block = block_mapping[2 * pair_idx];
+  int64_t dst_block = block_mapping[2 * pair_idx + 1];
+  int64_t src_offset = src_block * mem_footprint_per_block;
+  int64_t dst_offset = dst_block * mem_footprint_per_block;
+  for (int i = threadIdx.x; i < mem_footprint_per_block; i += blockDim.x) {
+    cache[dst_offset + i] = cache[src_offset + i];
+  }
+}
+
 }  // namespace vllm
 
 // Note: the key_caches and value_caches vectors are constant but
@@ -147,6 +168,42 @@ void copy_blocks(std::vector<torch::Tensor> const& key_caches,
       }));
 }
 
+// Copy blocks for MLA (assumes a joint KV cache).
+void copy_blocks_mla(std::vector<torch::Tensor> const& kv_caches,
+                     const torch::Tensor& block_mapping) {
+  int num_layers = kv_caches.size();
+  if (num_layers == 0) {
+    return;
+  }
+  torch::Device cache_device = kv_caches[0].device();
+  TORCH_CHECK(cache_device.is_cuda(), "kv_cache must be on CUDA");
+
+  std::vector<int64_t> cache_ptrs(num_layers);
+  for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) {
+    cache_ptrs[layer_idx] =
+        reinterpret_cast<int64_t>(kv_caches[layer_idx].data_ptr());
+  }
+  torch::Tensor cache_ptrs_tensor =
+      torch::from_blob(cache_ptrs.data(), {num_layers}, torch::kInt64)
+          .to(cache_device);
+
+  int num_pairs = block_mapping.size(0);
+  // We use the stride instead of numel in case the cache is padded for memory
+  // alignment reasons; we assume the block data (inclusive of any padding)
+  // is contiguous in memory.
+  int mem_footprint_per_block = kv_caches[0].stride(0);
+  dim3 grid(num_layers, num_pairs);
+  dim3 block(std::min(1024, mem_footprint_per_block));
+  const at::cuda::OptionalCUDAGuard device_guard(cache_device);
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(
+      kv_caches[0].scalar_type(), "copy_blocks_mla_kernel", ([&] {
+        vllm::copy_blocks_mla_kernel<scalar_t><<<grid, block, 0, stream>>>(
+            cache_ptrs_tensor.data_ptr<int64_t>(),
+            block_mapping.data_ptr<int64_t>(), mem_footprint_per_block);
+      }));
+}
+
 namespace vllm {
 
 template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
@@ -254,6 +311,7 @@ __global__ void concat_and_cache_mla_kernel(
                                      // + pe_dim)]
     const int64_t* __restrict__ slot_mapping,  // [num_tokens]
     const int block_stride,                    //
+    const int entry_stride,                    //
     const int kv_c_stride,                     //
     const int k_pe_stride,                     //
     const int kv_lora_rank,                    //
@@ -274,9 +332,8 @@
                   int src_stride, int dst_stride, int size, int offset) {
     for (int i = threadIdx.x; i < size; i += blockDim.x) {
       const int64_t src_idx = token_idx * src_stride + i;
-      const int64_t dst_idx = block_idx * block_stride +
-                              block_offset * (kv_lora_rank + pe_dim) + i +
-                              offset;
+      const int64_t dst_idx =
+          block_idx * block_stride + block_offset * entry_stride + i + offset;
       if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
         dst[dst_idx] = src[src_idx];
       } else {
@@ -391,14 +448,14 @@ void reshape_and_cache_flash(
 // KV_T is the stored data type of kv-cache.
 // CACHE_T is the data type of key and value tensors.
 // KV_DTYPE is the real data type of kv-cache.
-#define CALL_CONCAT_AND_CACHE_MLA(KV_T, CACHE_T, KV_DTYPE)               \
-  vllm::concat_and_cache_mla_kernel<KV_T, CACHE_T, KV_DTYPE>             \
-      <<<grid, block, 0, stream>>>(                                      \
-          reinterpret_cast<KV_T*>(kv_c.data_ptr()),                      \
-          reinterpret_cast<KV_T*>(k_pe.data_ptr()),                      \
-          reinterpret_cast<CACHE_T*>(kv_cache.data_ptr()),               \
-          slot_mapping.data_ptr<int64_t>(), block_stride, kv_c_stride,   \
-          k_pe_stride, kv_lora_rank, pe_dim, block_size,                 \
+#define CALL_CONCAT_AND_CACHE_MLA(KV_T, CACHE_T, KV_DTYPE)                \
+  vllm::concat_and_cache_mla_kernel<KV_T, CACHE_T, KV_DTYPE>              \
+      <<<grid, block, 0, stream>>>(                                       \
+          reinterpret_cast<KV_T*>(kv_c.data_ptr()),                       \
+          reinterpret_cast<KV_T*>(k_pe.data_ptr()),                       \
+          reinterpret_cast<CACHE_T*>(kv_cache.data_ptr()),                \
+          slot_mapping.data_ptr<int64_t>(), block_stride, entry_stride,   \
+          kv_c_stride, k_pe_stride, kv_lora_rank, pe_dim, block_size,     \
           reinterpret_cast<const float*>(scale.data_ptr()));
 
 void concat_and_cache_mla(
@@ -428,6 +485,7 @@ void concat_and_cache_mla(
   int kv_c_stride = kv_c.stride(0);
   int k_pe_stride = k_pe.stride(0);
   int block_stride = kv_cache.stride(0);
+  int entry_stride = kv_cache.stride(1);
 
   dim3 grid(num_tokens);
   dim3 block(std::min(kv_lora_rank, 512));
diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp
index 186e9c0e81b7..c03806f430a7 100644
--- a/csrc/torch_bindings.cpp
+++ b/csrc/torch_bindings.cpp
@@ -450,6 +450,10 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
       "Tensor block_mapping) -> ()");
   cache_ops.impl("copy_blocks", torch::kCUDA, &copy_blocks);
 
+  cache_ops.def(
+      "copy_blocks_mla(Tensor(a!)[] kv_caches, Tensor block_mapping) -> ()");
+  cache_ops.impl("copy_blocks_mla", torch::kCUDA, &copy_blocks_mla);
+
   // Reshape the key and value tensors and cache them.
   cache_ops.def(
       "reshape_and_cache(Tensor key, Tensor value,"
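With the binding above registered, the new op can be driven from Python. A
minimal sketch, assuming the C extension is built with this change and a CUDA
device is available; the sizes below are made-up example values, and
ops.copy_blocks_mla is the thin wrapper added to vllm/_custom_ops.py later in
this patch:

    import torch
    from vllm import _custom_ops as ops

    # Example values only: 2 layers, 8 blocks of 16 slots, 576-element entries.
    num_layers, num_blocks, block_size, entry_size = 2, 8, 16, 576
    kv_caches = [
        torch.zeros(num_blocks, block_size, entry_size,
                    dtype=torch.bfloat16, device="cuda")
        for _ in range(num_layers)
    ]
    # Each row is a (src_block, dst_block) pair, matching the kernel's
    # flat block_mapping layout.
    block_mapping = torch.tensor([[0, 3], [1, 5]],
                                 dtype=torch.int64, device="cuda")
    ops.copy_blocks_mla(kv_caches, block_mapping)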
diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py
index 6f909b6803d3..21c02c5de35c 100644
--- a/tests/kernels/test_cache.py
+++ b/tests/kernels/test_cache.py
@@ -9,6 +9,7 @@ import torch
 from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck
 from vllm import _custom_ops as ops
 from vllm.platforms import current_platform
+from vllm.utils import align_to_256bytes
 
 COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')]
 DTYPES = [torch.half, torch.bfloat16, torch.float]
@@ -18,6 +19,13 @@ NUM_HEADS = [8]  # Arbitrary values for testing
 HEAD_SIZES = [64, 80, 120, 256]
 BLOCK_SIZES = [8, 16, 32]
 
+# Parameters for MLA tests.
+KV_LORA_RANKS = [512]
+QK_ROPE_HEAD_DIMS = [64]
+NUM_TOKENS_MLA = [42]
+BLOCK_SIZES_MLA = [16]
+NUM_BLOCKS_MLA = [8]
+
 # Arbitrary values for testing
 # don't make it too large. e.g. [1024, 36000] will OOM
 NUM_BLOCKS = [1024, 10000]
@@ -432,3 +440,257 @@ def test_fp8_e4m3_conversion(
     ops.convert_fp8(converted_cache, cache_fp8)
 
     torch.testing.assert_close(cache, converted_cache, atol=0.001, rtol=0.1)
+
+
+def _create_mla_cache(
+    num_blocks: int,
+    block_size: int,
+    entry_size: int,
+    dtype: torch.dtype,
+    kv_cache_dtype: str,
+    device: str,
+    align_cache: bool,
+) -> torch.Tensor:
+    cache_dtype = torch.uint8 if kv_cache_dtype == "fp8" else dtype
+
+    if align_cache:
+        alloc_entry_size = align_to_256bytes(entry_size, cache_dtype)
+        alloc_shape = (num_blocks, block_size, alloc_entry_size)
+        cache_full = torch.zeros(alloc_shape, dtype=cache_dtype, device=device)
+        cache = cache_full[..., :entry_size]
+    else:
+        cache = torch.zeros(num_blocks,
+                            block_size,
+                            entry_size,
+                            dtype=cache_dtype,
+                            device=device)
+    return cache
+
+
+def _fill_mla_cache(cache: torch.Tensor, kv_cache_dtype: str):
+    rand_dtype = torch.float16 if kv_cache_dtype == "fp8" else cache.dtype
+
+    vals = torch.randn(*cache.shape, device=cache.device, dtype=rand_dtype)
+    if kv_cache_dtype == "fp8":
+        temp = torch.zeros_like(cache)
+        ops.convert_fp8(temp, vals, 1.0, kv_dtype=kv_cache_dtype)
+        vals = temp
+    cache.copy_(vals)
+
+
+@pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS)
+@pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS)
+@pytest.mark.parametrize("num_tokens", NUM_TOKENS_MLA)
+@pytest.mark.parametrize("block_size", BLOCK_SIZES_MLA)
+@pytest.mark.parametrize("num_blocks", NUM_BLOCKS_MLA)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
+@pytest.mark.parametrize("align_cache", [False])
+@torch.inference_mode()
+def test_concat_and_cache_mla(
+    kv_lora_rank: int,
+    qk_rope_head_dim: int,
+    num_tokens: int,
+    block_size: int,
+    num_blocks: int,
+    dtype: torch.dtype,
+    seed: int,
+    device: str,
+    kv_cache_dtype: str,
+    align_cache: bool,
+) -> None:
+    current_platform.seed_everything(seed)
+    torch.set_default_device(device)
+
+    total_slots = num_blocks * block_size
+    slot_mapping_lst = random.sample(range(total_slots), num_tokens)
+    slot_mapping = torch.tensor(slot_mapping_lst,
+                                dtype=torch.long,
+                                device=device)
+
+    kv_c = torch.randn(num_tokens, kv_lora_rank, dtype=dtype, device=device)
+    k_pe = torch.randn(num_tokens,
+                       qk_rope_head_dim,
+                       dtype=dtype,
+                       device=device)
+    entry_size = kv_lora_rank + qk_rope_head_dim
+
+    scale = torch.tensor(0.1, dtype=torch.float32, device=device)
+    kv_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
+                                 kv_cache_dtype, device, align_cache)
+    ref_temp = torch.zeros(*kv_cache.shape, dtype=dtype, device=device)
+
+    for i in range(num_tokens):
+        slot = slot_mapping[i].item()
+        block_idx = slot // block_size
+        block_offset = slot % block_size
+        ref_temp[block_idx, block_offset, :kv_lora_rank] = kv_c[i]
+        ref_temp[block_idx, block_offset, kv_lora_rank:] = k_pe[i]
+
+    if kv_cache_dtype == "fp8":
+        ref_kv_cache = torch.empty_like(ref_temp, dtype=kv_cache.dtype)
+        ops.convert_fp8(ref_kv_cache,
+                        ref_temp,
+                        scale.item(),
+                        kv_dtype=kv_cache_dtype)
+    else:
+        ref_kv_cache = ref_temp
+
+    opcheck(
+        torch.ops._C_cache_ops.concat_and_cache_mla,
+        (kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale),
+        test_utils=DEFAULT_OPCHECK_TEST_UTILS,
+    )
+
+    ops.concat_and_cache_mla(kv_c, k_pe, kv_cache, slot_mapping,
+                             kv_cache_dtype, scale)
+
+    if kv_cache_dtype == "fp8":
+        result_temp = torch.empty_like(kv_cache, dtype=torch.float16)
+        ops.convert_fp8(result_temp,
+                        kv_cache.contiguous(),
+                        scale.item(),
+                        kv_dtype=kv_cache_dtype)
+        expected_temp = torch.empty_like(ref_kv_cache, dtype=torch.float16)
+        ops.convert_fp8(expected_temp,
+                        ref_kv_cache,
+                        scale.item(),
+                        kv_dtype=kv_cache_dtype)
+        torch.testing.assert_close(result_temp,
+                                   expected_temp,
+                                   atol=0.001,
+                                   rtol=0.1)
+    else:
+        torch.testing.assert_close(kv_cache, ref_kv_cache)
+
+
+@pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS)
+@pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS)
+@pytest.mark.parametrize("block_size", BLOCK_SIZES_MLA)
+@pytest.mark.parametrize("num_blocks", NUM_BLOCKS_MLA)
+@pytest.mark.parametrize("num_layers", NUM_LAYERS)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
+@pytest.mark.parametrize("align_cache", [False, True])
+@torch.inference_mode()
+def test_copy_blocks_mla(
+    kv_lora_rank: int,
+    qk_rope_head_dim: int,
+    block_size: int,
+    num_blocks: int,
+    num_layers: int,
+    dtype: torch.dtype,
+    seed: int,
+    device: str,
+    kv_cache_dtype: str,
+    align_cache: bool,
+) -> None:
+    current_platform.seed_everything(seed)
+    torch.set_default_device(device)
+
+    entry_size = kv_lora_rank + qk_rope_head_dim
+
+    kv_caches = []
+    for _ in range(num_layers):
+        kv_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
+                                     kv_cache_dtype, device, align_cache)
+        _fill_mla_cache(kv_cache, kv_cache_dtype=kv_cache_dtype)
+        kv_caches.append(kv_cache)
+
+    ref_caches = [kv_cache.clone() for kv_cache in kv_caches]
+
+    num_mappings = min(2, num_blocks // 2)
+    src_blocks = random.sample(range(num_blocks), num_mappings)
+    remaining = list(set(range(num_blocks)) - set(src_blocks))
+    dst_blocks = random.sample(remaining, 2 * num_mappings)
+    block_mapping = []
+    for i in range(num_mappings):
+        src = src_blocks[i]
+        dst1 = dst_blocks[2 * i]
+        dst2 = dst_blocks[2 * i + 1]
+        block_mapping.append((src, dst1))
+        block_mapping.append((src, dst2))
+    block_mapping_tensor = torch.tensor(block_mapping,
+                                        dtype=torch.int64,
+                                        device=device).view(-1, 2)
+
+    for src, dst in block_mapping:
+        for ref_cache in ref_caches:
+            ref_cache[dst].copy_(ref_cache[src])
+
+    opcheck(
+        torch.ops._C_cache_ops.copy_blocks_mla,
+        (kv_caches, block_mapping_tensor),
+        test_utils=DEFAULT_OPCHECK_TEST_UTILS,
+    )
+    ops.copy_blocks_mla(kv_caches, block_mapping_tensor)
+
+    for kv_cache, ref_cache in zip(kv_caches, ref_caches):
+        torch.testing.assert_close(kv_cache, ref_cache)
+
+
+@pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS)
+@pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS)
+@pytest.mark.parametrize("block_size", BLOCK_SIZES_MLA)
+@pytest.mark.parametrize("num_blocks", NUM_BLOCKS_MLA)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
+@pytest.mark.parametrize("align_cache", [False, True])
+@torch.inference_mode()
+def test_swap_blocks_mla(
+    kv_lora_rank: int,
+    qk_rope_head_dim: int,
+    block_size: int,
+    num_blocks: int,
+    dtype: torch.dtype,
+    seed: int,
+    device: str,
+    kv_cache_dtype: str,
+    align_cache: bool,
+) -> None:
+    current_platform.seed_everything(seed)
+    torch.set_default_device(device)
+
+    entry_size = kv_lora_rank + qk_rope_head_dim
+
+    src_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
+                                  kv_cache_dtype, device, align_cache)
+    dst_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
+                                  kv_cache_dtype, device, align_cache)
+
+    _fill_mla_cache(src_cache, kv_cache_dtype)
+    _fill_mla_cache(dst_cache, kv_cache_dtype)
+
+    src_cache_clone = src_cache.clone()
+
+    num_mappings = min(2, num_blocks // 2)
+    src_blocks = random.sample(range(num_blocks), num_mappings)
+    remaining_blocks = list(set(range(num_blocks)) - set(src_blocks))
+    dst_blocks = random.sample(remaining_blocks, num_mappings)
+    block_mapping = list(zip(src_blocks, dst_blocks))
+    block_mapping_tensor = torch.tensor(block_mapping,
+                                        dtype=torch.int64,
+                                        device="cpu").view(-1, 2)
+
+    opcheck(
+        torch.ops._C_cache_ops.swap_blocks,
+        (src_cache, dst_cache, block_mapping_tensor),
+        test_utils=DEFAULT_OPCHECK_TEST_UTILS,
+        cond=(kv_lora_rank == KV_LORA_RANKS[0]
+              and qk_rope_head_dim == QK_ROPE_HEAD_DIMS[0]),
+    )
+
+    ops.swap_blocks(src_cache, dst_cache, block_mapping_tensor)
+
+    for src, dst in block_mapping:
+        torch.testing.assert_close(
+            src_cache_clone[src].cpu(),
+            dst_cache[dst].cpu(),
+            msg=f"Block {src} from src should have been swapped to block "
+            f"{dst} in dst_cache.")
diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index bdc9a6a33df0..a68235016767 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -1037,6 +1037,11 @@ def copy_blocks(key_caches: List[torch.Tensor],
     torch.ops._C_cache_ops.copy_blocks(key_caches, value_caches,
                                        block_mapping)
 
 
+def copy_blocks_mla(kv_caches: List[torch.Tensor],
+                    block_mapping: torch.Tensor) -> None:
+    torch.ops._C_cache_ops.copy_blocks_mla(kv_caches, block_mapping)
+
+
 def swap_blocks(src: torch.Tensor, dst: torch.Tensor,
                 block_mapping: torch.Tensor) -> None:
     torch.ops._C_cache_ops.swap_blocks(src, dst, block_mapping)
diff --git a/vllm/attention/backends/triton_mla.py b/vllm/attention/backends/triton_mla.py
index 20d7ef0fa88e..9a1984a931b5 100644
--- a/vllm/attention/backends/triton_mla.py
+++ b/vllm/attention/backends/triton_mla.py
@@ -26,7 +26,6 @@ from vllm.attention.backends.mla.utils import MLACommonImpl, MLACommonMetadata
 from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
                                            compute_slot_mapping_start_idx,
                                            is_block_tables_empty)
-from vllm.attention.ops.paged_attn import PagedAttention
 from vllm.attention.ops.triton_decode_attention import decode_attention_fwd
 from vllm.utils import async_tensor_h2d, make_tensor_with_pad
 
@@ -72,14 +71,14 @@ class TritonMLABackend(AttentionBackend):
         dst_kv_cache: torch.Tensor,
         src_to_dst: torch.Tensor,
     ) -> None:
-        PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
+        ops.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
 
     @staticmethod
     def copy_blocks(
         kv_caches: List[torch.Tensor],
         src_to_dists: torch.Tensor,
     ) -> None:
-        PagedAttention.copy_blocks(kv_caches, src_to_dists)
+        ops.copy_blocks_mla(kv_caches, src_to_dists)
 
     @staticmethod
     def get_supported_head_sizes() -> List[int]:
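The stride changes below exist for the same reason the CUDA kernels above
switched from numel() to stride(0): an aligned cache is allocated with a padded
entry dimension and then sliced back, so the logical tensor is no longer
contiguous and per-block and per-token offsets must come from the real strides.
A small illustration of the resulting layout, using assumed example sizes (a
DeepSeek-style MLA entry of 512 + 64 bf16 elements padded to 640):

    import torch

    entry_size, padded_entry_size = 576, 640  # assumed example sizes
    cache = torch.zeros(8, 16, padded_entry_size,
                        dtype=torch.bfloat16)[..., :entry_size]
    print(cache.shape)            # torch.Size([8, 16, 576]) -- logical shape
    print(cache.stride())         # (10240, 640, 1) -- strides keep the padding
    print(cache.is_contiguous())  # False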
diff --git a/vllm/attention/ops/triton_decode_attention.py b/vllm/attention/ops/triton_decode_attention.py
index ec5ec4ce6e6b..057fccb5e598 100644
--- a/vllm/attention/ops/triton_decode_attention.py
+++ b/vllm/attention/ops/triton_decode_attention.py
@@ -204,10 +204,10 @@ def _decode_att_m_fwd(
         Req_to_tokens.stride(0),
         q.stride(0),
         q.stride(1),
-        k_buffer.stride(-2),
-        k_buffer.stride(-1),
-        v_buffer.stride(-2),
-        v_buffer.stride(-1),
+        k_buffer.stride(-3),  # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
+        k_buffer.stride(-2),  # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
+        v_buffer.stride(-3),  # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
+        v_buffer.stride(-2),  # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
         att_out.stride(0),
         att_out.stride(1),
         att_out.stride(2),
@@ -438,10 +438,10 @@ def _decode_grouped_att_m_fwd(
         Req_to_tokens.stride(0),
         q.stride(0),
         q.stride(1),
-        k_buffer.stride(-2),
-        k_buffer.stride(-1),
-        v_buffer.stride(-2),
-        v_buffer.stride(-1),
+        k_buffer.stride(-3),  # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
+        k_buffer.stride(-2),  # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
+        v_buffer.stride(-3),  # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
+        v_buffer.stride(-2),  # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
         att_out.stride(0),
         att_out.stride(1),
         att_out.stride(2),
diff --git a/vllm/envs.py b/vllm/envs.py
index 5018f6deb7f4..2c731eda7836 100644
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -82,6 +82,7 @@ if TYPE_CHECKING:
     VLLM_MLA_DISABLE: bool = False
     VLLM_MLA_PERFORM_MATRIX_ABSORPTION: bool = True
     VLLM_MLA_DISABLE_REQUANTIZATION: bool = False
+    VLLM_CUDA_MEM_ALIGN_KV_CACHE: bool = True
     VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON: bool = False
 
@@ -539,6 +540,15 @@ environment_variables: Dict[str, Callable[[], Any]] = {
     "VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON":
     lambda: bool(int(os.getenv("VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON", "0"))
                  ),
+
+    # When on an Nvidia GPU, pad individual KV cache entries (within a page)
+    # so they are 256-byte aligned, matching the alignment the CUDA runtime
+    # uses for its allocations. This improves performance at the cost of some
+    # extra cache memory. Currently this only matters for MLA, which produces
+    # entries that are not already 256-byte aligned; for most other models
+    # entries are naturally 256-byte aligned.
+    "VLLM_CUDA_MEM_ALIGN_KV_CACHE":
+    lambda: bool(int(os.getenv("VLLM_CUDA_MEM_ALIGN_KV_CACHE", "1"))),
 }
 
 # end-env-vars-definition
diff --git a/vllm/utils.py b/vllm/utils.py
index a2b53fcf252d..8b9269598757 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -563,6 +563,10 @@ def cdiv(a: int, b: int) -> int:
     return -(a // -b)
 
 
+def round_up(x: int, y: int) -> int:
+    return ((x + y - 1) // y) * y
+
+
 def _generate_random_fp8(
     tensor: torch.Tensor,
     low: float,
@@ -794,6 +798,12 @@ def get_dtype_size(dtype: torch.dtype) -> int:
     return torch.tensor([], dtype=dtype).element_size()
 
 
+def align_to_256bytes(extent: int, dtype: torch.dtype) -> int:
+    dtype_size = get_dtype_size(dtype)
+    eles_per_256bytes = 256 // dtype_size
+    return round_up(extent, eles_per_256bytes)
+
+
 # `collections` helpers
 def is_list_of(
     value: object,
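To make the helper's arithmetic concrete, a short worked example, using the MLA
sizes the new tests exercise (kv_lora_rank=512 and qk_rope_head_dim=64):

    import torch
    from vllm.utils import align_to_256bytes

    entry_size = 512 + 64  # 576 elements per MLA cache entry
    # bf16 is 2 bytes, so 256 bytes == 128 elements and
    # round_up(576, 128) == 640 elements, i.e. 1152 -> 1280 bytes per entry.
    assert align_to_256bytes(entry_size, torch.bfloat16) == 640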
diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py
index 252fe06600da..3960392cf74e 100644
--- a/vllm/worker/cache_engine.py
+++ b/vllm/worker/cache_engine.py
@@ -2,13 +2,17 @@
 """CacheEngine class for managing the KV cache."""
 from typing import List
 
+import numpy as np
 import torch
 
+from vllm import envs
 from vllm.attention import get_attn_backend
 from vllm.config import CacheConfig, DeviceConfig, ModelConfig, ParallelConfig
 from vllm.logger import init_logger
+from vllm.platforms import current_platform
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType,
-                        get_dtype_size, is_pin_memory_available)
+                        align_to_256bytes, get_dtype_size,
+                        is_pin_memory_available)
 
 logger = init_logger(__name__)
 
@@ -38,6 +42,7 @@ class CacheEngine:
         self.num_attention_layers = model_config.get_num_layers_by_block_type(
             parallel_config, LayerBlockType.attention)
         self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)
+        self.align_cache = self._align_cache(model_config)
 
         self.block_size = cache_config.block_size
         self.num_gpu_blocks = cache_config.num_gpu_blocks
@@ -75,15 +80,39 @@ class CacheEngine:
             num_blocks, self.block_size, self.num_kv_heads, self.head_size)
         pin_memory = is_pin_memory_available() if device == "cpu" else False
         kv_cache: List[torch.Tensor] = []
+
+        # Align entries so they are 256-byte aligned for better performance.
+        # Primarily targets MLA, whose entries typically end up only 128-byte
+        # aligned.
+        if self.align_cache:
+            # We assume the cache shape is:
+            # (TOTAL_PAGES, PAGE_SIZE, entry_shape...)
+            # NOTE: this assumption currently only holds for MLA, so we only
+            # apply this optimization when `use_mla` is true.
+            entry_shape = kv_cache_shape[2:]
+            entry_size = np.prod(entry_shape)
+            alloc_entry_size = align_to_256bytes(entry_size, self.dtype)
+            alloc_shape = (*kv_cache_shape[:2], alloc_entry_size)
+        else:
+            alloc_shape = kv_cache_shape
+
         for _ in range(self.num_attention_layers):
             # null block in CpuGpuBlockAllocator requires at least that
             # block to be zeroed-out.
             # We zero-out everything for simplicity.
-            kv_cache.append(
-                torch.zeros(kv_cache_shape,
-                            dtype=self.dtype,
-                            pin_memory=pin_memory,
-                            device=device))
+            layer_kv_cache = torch.zeros(alloc_shape,
+                                         dtype=self.dtype,
+                                         pin_memory=pin_memory,
+                                         device=device)
+
+            # If we allocated with padding for alignment reasons, truncate the
+            # entry dimension back to its logical size while preserving the
+            # aligned stride.
+            if self.align_cache:
+                layer_kv_cache = layer_kv_cache[..., :entry_size]
+
+            # View back to (TOTAL_PAGES, PAGE_SIZE, entry_shape...) for cases
+            # where entry_shape has more than one dimension.
+            kv_cache.append(layer_kv_cache.view(kv_cache_shape))
         return kv_cache
 
     def swap_in(self, src_to_dst: torch.Tensor) -> None:
@@ -99,6 +128,14 @@
     def copy(self, src_to_dsts: torch.Tensor) -> None:
         self.attn_backend.copy_blocks(self.gpu_cache, src_to_dsts)
 
+    @staticmethod
+    def _align_cache(model_config: ModelConfig):
+        # Currently align_cache only applies to MLA models, since the other
+        # cache kernels have not yet been updated to support non-contiguous
+        # tensors.
+        return model_config.use_mla and current_platform.is_cuda() \
+            and envs.VLLM_CUDA_MEM_ALIGN_KV_CACHE
+
     @staticmethod
     def get_cache_block_size(
         cache_config: CacheConfig,
@@ -110,14 +147,21 @@
         num_attention_layers = model_config.get_num_layers_by_block_type(
             parallel_config, LayerBlockType.attention)
 
-        key_cache_block = cache_config.block_size * num_heads * head_size
-        # For MLA there is no value cache, since the latent vector
-        # is joint keys and values.
-        value_cache_block = key_cache_block if not model_config.use_mla else 0
-        total = num_attention_layers * (key_cache_block + value_cache_block)
         if cache_config.cache_dtype == "auto":
             dtype = model_config.dtype
         else:
             dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
+
+        key_cache_entry = num_heads * head_size
+        if CacheEngine._align_cache(model_config):
+            key_cache_entry = align_to_256bytes(key_cache_entry,
+                                                model_config.dtype)
+
+        # For MLA there is no value cache, since the latent vector
+        # is joint keys and values.
+        value_cache_entry = key_cache_entry if not model_config.use_mla else 0
+        total = num_attention_layers * cache_config.block_size * \
+            (key_cache_entry + value_cache_entry)
+
         dtype_size = get_dtype_size(dtype)
         return dtype_size * total
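As a rough sense of the memory the alignment trades for performance
(illustrative numbers only, mirroring the per-block accounting in
get_cache_block_size above for an assumed bf16 MLA cache with kv_lora_rank=512,
qk_rope_head_dim=64 and block_size=16):

    block_size = 16
    entry_elems = 512 + 64   # 576 elements per entry
    aligned_elems = 640      # align_to_256bytes(576, bf16)
    dtype_size = 2           # bytes per bf16 element

    unaligned = block_size * entry_elems * dtype_size   # 18432 bytes per block
    aligned = block_size * aligned_elems * dtype_size   # 20480 bytes per block
    print(aligned / unaligned - 1)  # ~0.11, i.e. about 11% more cache memory

This is the memory-for-performance trade-off that VLLM_CUDA_MEM_ALIGN_KV_CACHE
(default "1") controls.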