[Perf] Mem align KV caches for CUDA devices (MLA perf improvement) (#12676)

Signed-off-by: simon-mo <xmo@berkeley.edu>
Signed-off-by: Lucas Wilkinson <lcwilkins@redhat.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Co-authored-by: simon-mo <xmo@berkeley.edu>
Authored by Lucas Wilkinson on 2025-02-04 21:22:24 -05:00; committed by GitHub
parent 233df6f5c4
commit 75e94309e8
10 changed files with 429 additions and 34 deletions

View File

@@ -15,6 +15,9 @@ void copy_blocks(std::vector<torch::Tensor> const& key_caches,
                 std::vector<torch::Tensor> const& value_caches,
                 const torch::Tensor& block_mapping);

+void copy_blocks_mla(std::vector<torch::Tensor> const& kv_caches,
+                     const torch::Tensor& block_mapping);
+
void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
                       torch::Tensor& key_cache, torch::Tensor& value_cache,
                       torch::Tensor& slot_mapping,

View File

@@ -46,7 +46,10 @@ void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
  char* src_ptr = static_cast<char*>(src.data_ptr());
  char* dst_ptr = static_cast<char*>(dst.data_ptr());
-  const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
+  // We use the stride instead of numel in case the cache is padded for memory
+  // alignment reasons; we assume the block's data (inclusive of any padding)
+  // is contiguous in memory.
+  const int64_t block_size_in_bytes = src.element_size() * src.stride(0);
  const at::cuda::OptionalCUDAGuard device_guard(
      src_device.is_cuda() ? src_device : dst_device);
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
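A quick way to see the difference (illustrative shapes, not from the commit): for a cache whose entry dimension was padded and then sliced, `numel()` reports only the logical data while `stride(0)` reports the true per-block footprint including padding:

import torch

full = torch.zeros(4, 16, 640)      # allocated with a padded entry size
cache = full[..., :576]             # logical view with entry size 576
print(cache[0].numel())             # 9216  -> logical elements per block
print(cache.stride(0))              # 10240 -> elements per block, incl. padding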
@@ -93,6 +96,24 @@ __global__ void copy_blocks_kernel(int64_t* key_cache_ptrs,
    }
  }
// Kernel for MLA, which works on a single joint kv_cache
// Grid: (num_layers, num_pairs)
template <typename scalar_t>
__global__ void copy_blocks_mla_kernel(
int64_t* cache_ptrs, const int64_t* __restrict__ block_mapping,
const int mem_footprint_per_block) {
const int layer_idx = blockIdx.x;
const int pair_idx = blockIdx.y;
scalar_t* cache = reinterpret_cast<scalar_t*>(cache_ptrs[layer_idx]);
int64_t src_block = block_mapping[2 * pair_idx];
int64_t dst_block = block_mapping[2 * pair_idx + 1];
int64_t src_offset = src_block * mem_footprint_per_block;
int64_t dst_offset = dst_block * mem_footprint_per_block;
for (int i = threadIdx.x; i < mem_footprint_per_block; i += blockDim.x) {
cache[dst_offset + i] = cache[src_offset + i];
}
}
}  // namespace vllm

// Note: the key_caches and value_caches vectors are constant but
@@ -147,6 +168,42 @@ void copy_blocks(std::vector<torch::Tensor> const& key_caches,
      }));
}
// copy blocks kernel for MLA (assumes a joint KV-cache)
void copy_blocks_mla(std::vector<torch::Tensor> const& kv_caches,
const torch::Tensor& block_mapping) {
int num_layers = kv_caches.size();
if (num_layers == 0) {
return;
}
torch::Device cache_device = kv_caches[0].device();
TORCH_CHECK(cache_device.is_cuda(), "kv_cache must be on CUDA");
std::vector<int64_t> cache_ptrs(num_layers);
for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) {
cache_ptrs[layer_idx] =
reinterpret_cast<int64_t>(kv_caches[layer_idx].data_ptr());
}
torch::Tensor cache_ptrs_tensor =
torch::from_blob(cache_ptrs.data(), {num_layers}, torch::kInt64)
.to(cache_device);
int num_pairs = block_mapping.size(0);
// We use the stride instead of numel in case the cache is padded for memory
// alignment reasons, we assume the blocks data (inclusive of any padding)
// is contiguous in memory
int mem_footprint_per_block = kv_caches[0].stride(0);
dim3 grid(num_layers, num_pairs);
dim3 block(std::min(1024, mem_footprint_per_block));
const at::cuda::OptionalCUDAGuard device_guard(cache_device);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(
kv_caches[0].scalar_type(), "copy_blocks_mla_kernel", ([&] {
vllm::copy_blocks_mla_kernel<scalar_t><<<grid, block, 0, stream>>>(
cache_ptrs_tensor.data_ptr<int64_t>(),
block_mapping.data_ptr<int64_t>(), mem_footprint_per_block);
}));
}
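As a reference for what the kernel pair above computes, here is an illustrative PyTorch equivalent (a hypothetical helper, not part of the commit); each row of `block_mapping` is a `(src_block, dst_block)` pair applied to every layer's joint KV cache:

import torch
from typing import List

def copy_blocks_mla_ref(kv_caches: List[torch.Tensor],
                        block_mapping: torch.Tensor) -> None:
    # kv_caches: one (num_blocks, block_size, entry) tensor per layer.
    for cache in kv_caches:
        for src, dst in block_mapping.tolist():
            cache[dst].copy_(cache[src])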
namespace vllm {

template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
@@ -254,6 +311,7 @@ __global__ void concat_and_cache_mla_kernel(
    //            + pe_dim)]
    const int64_t* __restrict__ slot_mapping,  // [num_tokens]
    const int block_stride,                    //
+   const int entry_stride,                    //
    const int kv_c_stride,                     //
    const int k_pe_stride,                     //
    const int kv_lora_rank,                    //
@@ -274,9 +332,8 @@ __global__ void concat_and_cache_mla_kernel(
      int src_stride, int dst_stride, int size, int offset) {
    for (int i = threadIdx.x; i < size; i += blockDim.x) {
      const int64_t src_idx = token_idx * src_stride + i;
-     const int64_t dst_idx = block_idx * block_stride +
-                             block_offset * (kv_lora_rank + pe_dim) + i +
-                             offset;
+     const int64_t dst_idx =
+         block_idx * block_stride + block_offset * entry_stride + i + offset;
      if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
        dst[dst_idx] = src[src_idx];
      } else {
@@ -397,8 +454,8 @@ void reshape_and_cache_flash(
      reinterpret_cast<KV_T*>(kv_c.data_ptr()),                     \
      reinterpret_cast<KV_T*>(k_pe.data_ptr()),                     \
      reinterpret_cast<CACHE_T*>(kv_cache.data_ptr()),              \
-     slot_mapping.data_ptr<int64_t>(), block_stride, kv_c_stride,  \
-     k_pe_stride, kv_lora_rank, pe_dim, block_size,                \
+     slot_mapping.data_ptr<int64_t>(), block_stride, entry_stride, \
+     kv_c_stride, k_pe_stride, kv_lora_rank, pe_dim, block_size,   \
      reinterpret_cast<const float*>(scale.data_ptr()));
void concat_and_cache_mla(
@@ -428,6 +485,7 @@ void concat_and_cache_mla(
    int kv_c_stride = kv_c.stride(0);
    int k_pe_stride = k_pe.stride(0);
    int block_stride = kv_cache.stride(0);
+   int entry_stride = kv_cache.stride(1);

    dim3 grid(num_tokens);
    dim3 block(std::min(kv_lora_rank, 512));
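For intuition, a hedged sketch (made-up sizes) of how `block_stride` and `entry_stride` locate a slot in a possibly padded joint KV cache, mirroring the kernel's `dst_idx` computation:

import torch

kv_lora_rank, pe_dim, block_size = 512, 64, 16
padded_entry = 640                                  # example padded entry size
kv_cache = torch.zeros(8, block_size, padded_entry)[..., :kv_lora_rank + pe_dim]

block_stride = kv_cache.stride(0)                   # elements per block, incl. padding
entry_stride = kv_cache.stride(1)                   # elements per entry, incl. padding
slot = 37
block_idx, block_offset = divmod(slot, block_size)
dst_base = block_idx * block_stride + block_offset * entry_stride
# dst_base is an offset into kv_cache's underlying storage, matching the
# kernel's pointer arithmetic; with padding, entry_stride != kv_lora_rank + pe_dim.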

View File

@@ -450,6 +450,10 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
      "Tensor block_mapping) -> ()");
  cache_ops.impl("copy_blocks", torch::kCUDA, &copy_blocks);

+ cache_ops.def(
+     "copy_blocks_mla(Tensor(a!)[] kv_caches, Tensor block_mapping) -> ()");
+ cache_ops.impl("copy_blocks_mla", torch::kCUDA, &copy_blocks_mla);
+
  // Reshape the key and value tensors and cache them.
  cache_ops.def(
      "reshape_and_cache(Tensor key, Tensor value,"

View File

@@ -9,6 +9,7 @@ import torch
from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck
from vllm import _custom_ops as ops
from vllm.platforms import current_platform
+from vllm.utils import align_to_256bytes

COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')]
DTYPES = [torch.half, torch.bfloat16, torch.float]
@@ -18,6 +19,13 @@ NUM_HEADS = [8]  # Arbitrary values for testing
HEAD_SIZES = [64, 80, 120, 256]
BLOCK_SIZES = [8, 16, 32]

+# Parameters for MLA tests.
+KV_LORA_RANKS = [512]
+QK_ROPE_HEAD_DIMS = [64]
+NUM_TOKENS_MLA = [42]
+BLOCK_SIZES_MLA = [16]
+NUM_BLOCKS_MLA = [8]
+
# Arbitrary values for testing
# don't make it too large. e.g. [1024, 36000] will OOM
NUM_BLOCKS = [1024, 10000]
@@ -432,3 +440,257 @@ def test_fp8_e4m3_conversion(
    ops.convert_fp8(converted_cache, cache_fp8)
    torch.testing.assert_close(cache, converted_cache, atol=0.001, rtol=0.1)
def _create_mla_cache(
num_blocks: int,
block_size: int,
entry_size: int,
dtype: torch.dtype,
kv_cache_dtype: str,
device: str,
align_cache: bool,
) -> torch.Tensor:
cache_dtype = torch.uint8 if kv_cache_dtype == "fp8" else dtype
if align_cache:
alloc_entry_size = align_to_256bytes(entry_size, cache_dtype)
alloc_shape = (num_blocks, block_size, alloc_entry_size)
cache_full = torch.zeros(alloc_shape, dtype=cache_dtype, device=device)
cache = cache_full[..., :entry_size]
else:
cache = torch.zeros(num_blocks,
block_size,
entry_size,
dtype=cache_dtype,
device=device)
return cache
def _fill_mla_cache(cache: torch.Tensor, kv_cache_dtype: str):
rand_dtype = torch.float16 if kv_cache_dtype == "fp8" else cache.dtype
vals = torch.randn(*cache.shape, device=cache.device, dtype=rand_dtype)
if kv_cache_dtype == "fp8":
temp = torch.zeros_like(cache)
ops.convert_fp8(temp, vals, 1.0, kv_dtype=kv_cache_dtype)
vals = temp
cache.copy_(vals)
@pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS)
@pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS_MLA)
@pytest.mark.parametrize("block_size", BLOCK_SIZES_MLA)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS_MLA)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@pytest.mark.parametrize("align_cache", [False])
@torch.inference_mode()
def test_concat_and_cache_mla(
kv_lora_rank: int,
qk_rope_head_dim: int,
num_tokens: int,
block_size: int,
num_blocks: int,
dtype: torch.dtype,
seed: int,
device: str,
kv_cache_dtype: str,
align_cache: bool,
) -> None:
current_platform.seed_everything(seed)
torch.set_default_device(device)
total_slots = num_blocks * block_size
slot_mapping_lst = random.sample(range(total_slots), num_tokens)
slot_mapping = torch.tensor(slot_mapping_lst,
dtype=torch.long,
device=device)
kv_c = torch.randn(num_tokens, kv_lora_rank, dtype=dtype, device=device)
k_pe = torch.randn(num_tokens,
qk_rope_head_dim,
dtype=dtype,
device=device)
entry_size = kv_lora_rank + qk_rope_head_dim
scale = torch.tensor(0.1, dtype=torch.float32, device=device)
kv_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
kv_cache_dtype, device, align_cache)
ref_temp = torch.zeros(*kv_cache.shape, dtype=dtype, device=device)
for i in range(num_tokens):
slot = slot_mapping[i].item()
block_idx = slot // block_size
block_offset = slot % block_size
ref_temp[block_idx, block_offset, :kv_lora_rank] = kv_c[i]
ref_temp[block_idx, block_offset, kv_lora_rank:] = k_pe[i]
if kv_cache_dtype == "fp8":
ref_kv_cache = torch.empty_like(ref_temp, dtype=kv_cache.dtype)
ops.convert_fp8(ref_kv_cache,
ref_temp,
scale.item(),
kv_dtype=kv_cache_dtype)
else:
ref_kv_cache = ref_temp
opcheck(
torch.ops._C_cache_ops.concat_and_cache_mla,
(kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale),
test_utils=DEFAULT_OPCHECK_TEST_UTILS,
)
ops.concat_and_cache_mla(kv_c, k_pe, kv_cache, slot_mapping,
kv_cache_dtype, scale)
if kv_cache_dtype == "fp8":
result_temp = torch.empty_like(kv_cache, dtype=torch.float16)
ops.convert_fp8(result_temp,
kv_cache.contiguous(),
scale.item(),
kv_dtype=kv_cache_dtype)
expected_temp = torch.empty_like(ref_kv_cache, dtype=torch.float16)
ops.convert_fp8(expected_temp,
ref_kv_cache,
scale.item(),
kv_dtype=kv_cache_dtype)
torch.testing.assert_close(result_temp,
expected_temp,
atol=0.001,
rtol=0.1)
else:
torch.testing.assert_close(kv_cache, ref_kv_cache)
@pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS)
@pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS)
@pytest.mark.parametrize("block_size", BLOCK_SIZES_MLA)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS_MLA)
@pytest.mark.parametrize("num_layers", NUM_LAYERS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@pytest.mark.parametrize("align_cache", [False, True])
@torch.inference_mode()
def test_copy_blocks_mla(
kv_lora_rank: int,
qk_rope_head_dim: int,
block_size: int,
num_blocks: int,
num_layers: int,
dtype: torch.dtype,
seed: int,
device: str,
kv_cache_dtype: str,
align_cache: bool,
) -> None:
current_platform.seed_everything(seed)
torch.set_default_device(device)
entry_size = kv_lora_rank + qk_rope_head_dim
kv_caches = []
for _ in range(num_layers):
kv_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
kv_cache_dtype, device, align_cache)
_fill_mla_cache(kv_cache, kv_cache_dtype=kv_cache_dtype)
kv_caches.append(kv_cache)
ref_caches = [kv_cache.clone() for kv_cache in kv_caches]
num_mappings = min(2, num_blocks // 2)
src_blocks = random.sample(range(num_blocks), num_mappings)
remaining = list(set(range(num_blocks)) - set(src_blocks))
dst_blocks = random.sample(remaining, 2 * num_mappings)
block_mapping = []
for i in range(num_mappings):
src = src_blocks[i]
dst1 = dst_blocks[2 * i]
dst2 = dst_blocks[2 * i + 1]
block_mapping.append((src, dst1))
block_mapping.append((src, dst2))
block_mapping_tensor = torch.tensor(block_mapping,
dtype=torch.int64,
device=device).view(-1, 2)
for src, dst in block_mapping:
for ref_cache in ref_caches:
ref_cache[dst].copy_(ref_cache[src])
opcheck(
torch.ops._C_cache_ops.copy_blocks_mla,
(kv_caches, block_mapping_tensor),
test_utils=DEFAULT_OPCHECK_TEST_UTILS,
)
ops.copy_blocks_mla(kv_caches, block_mapping_tensor)
for kv_cache, ref_cache in zip(kv_caches, ref_caches):
torch.testing.assert_close(kv_cache, ref_cache)
@pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS)
@pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS)
@pytest.mark.parametrize("block_size", BLOCK_SIZES_MLA)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS_MLA)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@pytest.mark.parametrize("align_cache", [False, True])
@torch.inference_mode()
def test_swap_blocks_mla(
kv_lora_rank: int,
qk_rope_head_dim: int,
block_size: int,
num_blocks: int,
dtype: torch.dtype,
seed: int,
device: str,
kv_cache_dtype: str,
align_cache: bool,
) -> None:
current_platform.seed_everything(seed)
torch.set_default_device(device)
entry_size = kv_lora_rank + qk_rope_head_dim
src_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
kv_cache_dtype, device, align_cache)
dst_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
kv_cache_dtype, device, align_cache)
_fill_mla_cache(src_cache, kv_cache_dtype)
_fill_mla_cache(dst_cache, kv_cache_dtype)
src_cache_clone = src_cache.clone()
num_mappings = min(2, num_blocks // 2)
src_blocks = random.sample(range(num_blocks), num_mappings)
remaining_blocks = list(set(range(num_blocks)) - set(src_blocks))
dst_blocks = random.sample(remaining_blocks, num_mappings)
block_mapping = list(zip(src_blocks, dst_blocks))
block_mapping_tensor = torch.tensor(block_mapping,
dtype=torch.int64,
device="cpu").view(-1, 2)
opcheck(
torch.ops._C_cache_ops.swap_blocks,
(src_cache, dst_cache, block_mapping_tensor),
test_utils=DEFAULT_OPCHECK_TEST_UTILS,
cond=(kv_lora_rank == KV_LORA_RANKS[0]
and qk_rope_head_dim == QK_ROPE_HEAD_DIMS[0]),
)
ops.swap_blocks(src_cache, dst_cache, block_mapping_tensor)
for src, dst in block_mapping:
torch.testing.assert_close(
src_cache_clone[src].cpu(),
dst_cache[dst].cpu(),
msg=f"Block {src} from src should have been swapped to block "
f"{dst} in dst_cache.")

View File

@@ -1037,6 +1037,11 @@ def copy_blocks(key_caches: List[torch.Tensor],
    torch.ops._C_cache_ops.copy_blocks(key_caches, value_caches, block_mapping)

+def copy_blocks_mla(kv_caches: List[torch.Tensor],
+                    block_mapping: torch.Tensor) -> None:
+    torch.ops._C_cache_ops.copy_blocks_mla(kv_caches, block_mapping)
+
def swap_blocks(src: torch.Tensor, dst: torch.Tensor,
                block_mapping: torch.Tensor) -> None:
    torch.ops._C_cache_ops.swap_blocks(src, dst, block_mapping)

View File

@@ -26,7 +26,6 @@ from vllm.attention.backends.mla.utils import MLACommonImpl, MLACommonMetadata
from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
                                            compute_slot_mapping_start_idx,
                                            is_block_tables_empty)
-from vllm.attention.ops.paged_attn import PagedAttention
from vllm.attention.ops.triton_decode_attention import decode_attention_fwd
from vllm.utils import async_tensor_h2d, make_tensor_with_pad
@@ -72,14 +71,14 @@ class TritonMLABackend(AttentionBackend):
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
-       PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
+       ops.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
-       PagedAttention.copy_blocks(kv_caches, src_to_dists)
+       ops.copy_blocks_mla(kv_caches, src_to_dists)

    @staticmethod
    def get_supported_head_sizes() -> List[int]:

View File

@@ -204,10 +204,10 @@ def _decode_att_m_fwd(
        Req_to_tokens.stride(0),
        q.stride(0),
        q.stride(1),
-       k_buffer.stride(-2),
-       k_buffer.stride(-1),
-       v_buffer.stride(-2),
-       v_buffer.stride(-1),
+       k_buffer.stride(-3),  # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
+       k_buffer.stride(-2),  # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
+       v_buffer.stride(-3),  # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
+       v_buffer.stride(-2),  # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
        att_out.stride(0),
        att_out.stride(1),
        att_out.stride(2),
@@ -438,10 +438,10 @@ def _decode_grouped_att_m_fwd(
        Req_to_tokens.stride(0),
        q.stride(0),
        q.stride(1),
-       k_buffer.stride(-2),
-       k_buffer.stride(-1),
-       v_buffer.stride(-2),
-       v_buffer.stride(-1),
+       k_buffer.stride(-3),  # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
+       k_buffer.stride(-2),  # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
+       v_buffer.stride(-3),  # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
+       v_buffer.stride(-2),  # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
        att_out.stride(0),
        att_out.stride(1),
        att_out.stride(2),
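A hedged illustration of the stride convention assumed in the comments above (shapes are hypothetical): for a buffer viewed as (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM), stride(-3) and stride(-2) give the per-token and per-head strides, and they stay correct when the last dimension is backed by padded storage:

import torch

num_pages, page_size, num_heads, head_dim = 8, 16, 1, 576
padded = torch.zeros(num_pages, page_size, num_heads, 640)  # padded storage
k_buffer = padded[..., :head_dim]                           # logical view
print(k_buffer.stride(-3), k_buffer.stride(-2), k_buffer.stride(-1))
# -> 640 640 1: the token and head strides reflect the padding; stride(-1) does not.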

View File

@@ -82,6 +82,7 @@ if TYPE_CHECKING:
    VLLM_MLA_DISABLE: bool = False
    VLLM_MLA_PERFORM_MATRIX_ABSORPTION: bool = True
    VLLM_MLA_DISABLE_REQUANTIZATION: bool = False
+   VLLM_CUDA_MEM_ALIGN_KV_CACHE: bool = True
    VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON: bool = False
@@ -539,6 +540,15 @@ environment_variables: Dict[str, Callable[[], Any]] = {
    "VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON":
    lambda: bool(int(os.getenv("VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON", "0"))
    ),

+   # On NVIDIA GPUs, align individual KV-cache entries (within a page) to 256
+   # bytes for better performance, at the cost of some extra cache memory.
+   # This matches the alignment the CUDA runtime uses for its allocations.
+   # Currently this only matters for MLA, which otherwise produces entries
+   # that are not 256-byte aligned; most other models are already naturally
+   # aligned to 256 bytes.
+   "VLLM_CUDA_MEM_ALIGN_KV_CACHE":
+   lambda: bool(int(os.getenv("VLLM_CUDA_MEM_ALIGN_KV_CACHE", "1"))),
}

# end-env-vars-definition
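Since the flag defaults to on, disabling it for an experiment is just a matter of setting the environment variable before the setting is read (a trivial sketch):

import os

# Turn off 256-byte alignment of KV-cache entries for this process.
os.environ["VLLM_CUDA_MEM_ALIGN_KV_CACHE"] = "0"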

View File

@@ -563,6 +563,10 @@ def cdiv(a: int, b: int) -> int:
    return -(a // -b)

+def round_up(x: int, y: int) -> int:
+    return ((x + y - 1) // y) * y
+
def _generate_random_fp8(
    tensor: torch.Tensor,
    low: float,
@@ -794,6 +798,12 @@ def get_dtype_size(dtype: torch.dtype) -> int:
    return torch.tensor([], dtype=dtype).element_size()

+def align_to_256bytes(extent: int, dtype: torch.dtype) -> int:
+    dtype_size = get_dtype_size(dtype)
+    eles_per_256bytes = 256 // dtype_size
+    return round_up(extent, eles_per_256bytes)
+
# `collections` helpers

def is_list_of(
    value: object,
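A small worked example of the helper above (values chosen to match the MLA entry size used elsewhere in this change):

import torch
from vllm.utils import align_to_256bytes

# MLA joint entry: kv_lora_rank + qk_rope_head_dim = 512 + 64 = 576 elements.
print(align_to_256bytes(576, torch.bfloat16))  # 640: 576 * 2 B = 1152 B, padded to 1280 B
print(align_to_256bytes(576, torch.float32))   # 576: 2304 B is already a multiple of 256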

View File

@@ -2,13 +2,17 @@
"""CacheEngine class for managing the KV cache."""
from typing import List

+import numpy as np
import torch

+from vllm import envs
from vllm.attention import get_attn_backend
from vllm.config import CacheConfig, DeviceConfig, ModelConfig, ParallelConfig
from vllm.logger import init_logger
+from vllm.platforms import current_platform
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType,
-                        get_dtype_size, is_pin_memory_available)
+                        align_to_256bytes, get_dtype_size,
+                        is_pin_memory_available)

logger = init_logger(__name__)
@@ -38,6 +42,7 @@ class CacheEngine:
        self.num_attention_layers = model_config.get_num_layers_by_block_type(
            parallel_config, LayerBlockType.attention)
        self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)
+       self.align_cache = self._align_cache(model_config)

        self.block_size = cache_config.block_size
        self.num_gpu_blocks = cache_config.num_gpu_blocks
@@ -75,15 +80,39 @@ class CacheEngine:
            num_blocks, self.block_size, self.num_kv_heads, self.head_size)
        pin_memory = is_pin_memory_available() if device == "cpu" else False
        kv_cache: List[torch.Tensor] = []

+       # Align entries so they are 256-byte aligned for better performance.
+       # This primarily targets MLA, whose entries otherwise end up only
+       # 128-byte aligned.
+       if self.align_cache:
+           # We assume the cache shape is:
+           #   (TOTAL_PAGES, PAGE_SIZE, entry_shape...)
+           # NOTE: this assumption currently only holds for MLA, so we only
+           # apply this optimization when `use_mla` is true.
+           entry_shape = kv_cache_shape[2:]
+           entry_size = np.prod(entry_shape)
+           alloc_entry_size = align_to_256bytes(entry_size, self.dtype)
+           alloc_shape = (*kv_cache_shape[:2], alloc_entry_size)
+       else:
+           alloc_shape = kv_cache_shape

        for _ in range(self.num_attention_layers):
            # null block in CpuGpuBlockAllocator requires at least that
            # block to be zeroed-out.
            # We zero-out everything for simplicity.
-           kv_cache.append(
-               torch.zeros(kv_cache_shape,
-                           dtype=self.dtype,
-                           pin_memory=pin_memory,
-                           device=device))
+           layer_kv_cache = torch.zeros(alloc_shape,
+                                        dtype=self.dtype,
+                                        pin_memory=pin_memory,
+                                        device=device)
+
+           # If we allocated with padding for alignment reasons, truncate the
+           # entry dimension while preserving the aligned stride.
+           if self.align_cache:
+               layer_kv_cache = layer_kv_cache[..., :entry_size]
+
+           # View back to (TOTAL_PAGES, PAGE_SIZE, entry_shape...) for cases
+           # where entry_shape is more than 1-D.
+           kv_cache.append(layer_kv_cache.view(kv_cache_shape))
        return kv_cache
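The allocation above can be exercised in isolation; a minimal sketch with illustrative sizes and standalone tensors rather than a real CacheEngine:

import torch
from vllm.utils import align_to_256bytes

num_blocks, block_size, entry_size = 8, 16, 576
dtype = torch.bfloat16
alloc_entry = align_to_256bytes(entry_size, dtype)            # 640 for bf16
buf = torch.zeros(num_blocks, block_size, alloc_entry, dtype=dtype)
kv_cache = buf[..., :entry_size]                              # logical entry size restored
assert kv_cache.shape == (num_blocks, block_size, entry_size)
assert kv_cache.stride(1) == alloc_entry                      # stride keeps the 256 B alignment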
    def swap_in(self, src_to_dst: torch.Tensor) -> None:
@@ -99,6 +128,14 @@ class CacheEngine:
    def copy(self, src_to_dsts: torch.Tensor) -> None:
        self.attn_backend.copy_blocks(self.gpu_cache, src_to_dsts)

+   @staticmethod
+   def _align_cache(model_config: ModelConfig):
+       # Currently align_cache only applies to MLA models, since the other
+       # cache kernels have not yet been updated to support non-contiguous
+       # tensors.
+       return model_config.use_mla and current_platform.is_cuda() \
+           and envs.VLLM_CUDA_MEM_ALIGN_KV_CACHE

    @staticmethod
    def get_cache_block_size(
        cache_config: CacheConfig,
@@ -110,14 +147,21 @@ class CacheEngine:
        num_attention_layers = model_config.get_num_layers_by_block_type(
            parallel_config, LayerBlockType.attention)

-       key_cache_block = cache_config.block_size * num_heads * head_size
-       # For MLA there is no value cache, since the latent vector
-       # is joint keys and values.
-       value_cache_block = key_cache_block if not model_config.use_mla else 0
-       total = num_attention_layers * (key_cache_block + value_cache_block)
        if cache_config.cache_dtype == "auto":
            dtype = model_config.dtype
        else:
            dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]

+       key_cache_entry = num_heads * head_size
+       if CacheEngine._align_cache(model_config):
+           key_cache_entry = align_to_256bytes(key_cache_entry,
+                                               model_config.dtype)
+
+       # For MLA there is no value cache, since the latent vector
+       # is joint keys and values.
+       value_cache_entry = key_cache_entry if not model_config.use_mla else 0
+       total = num_attention_layers * cache_config.block_size * \
+           (key_cache_entry + value_cache_entry)

        dtype_size = get_dtype_size(dtype)
        return dtype_size * total
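For concreteness, a hedged back-of-the-envelope instance of the computation above for an MLA model (illustrative numbers: one joint latent entry of 576 elements, bf16, block_size 16, 61 attention layers):

import torch
from vllm.utils import align_to_256bytes

num_heads, head_size = 1, 576            # MLA: single joint latent "head"
block_size, num_attention_layers = 16, 61
key_cache_entry = align_to_256bytes(num_heads * head_size, torch.bfloat16)  # 640
value_cache_entry = 0                    # MLA has no separate value cache
total = num_attention_layers * block_size * (key_cache_entry + value_cache_entry)
print(total * 2)                         # bf16: 1,249,280 bytes per block across layers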