[Attention] Get rid of mla cache alignment (#14842)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
Lucas Wilkinson 2025-03-15 01:08:25 -04:00 committed by GitHub
parent a2ae496589
commit 5952d8ab61
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 14 additions and 83 deletions

View File

@ -8,7 +8,6 @@ import torch
from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import align_to_256bytes
COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')] COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')]
DTYPES = [torch.half, torch.bfloat16, torch.float] DTYPES = [torch.half, torch.bfloat16, torch.float]
@ -450,22 +449,13 @@ def _create_mla_cache(
dtype: torch.dtype, dtype: torch.dtype,
kv_cache_dtype: str, kv_cache_dtype: str,
device: str, device: str,
align_cache: bool,
) -> torch.Tensor: ) -> torch.Tensor:
cache_dtype = torch.uint8 if kv_cache_dtype == "fp8" else dtype cache_dtype = torch.uint8 if kv_cache_dtype == "fp8" else dtype
return torch.zeros(num_blocks,
if align_cache: block_size,
alloc_entry_size = align_to_256bytes(entry_size, cache_dtype) entry_size,
alloc_shape = (num_blocks, block_size, alloc_entry_size) dtype=cache_dtype,
cache_full = torch.zeros(alloc_shape, dtype=cache_dtype, device=device) device=device)
cache = cache_full[..., :entry_size]
else:
cache = torch.zeros(num_blocks,
block_size,
entry_size,
dtype=cache_dtype,
device=device)
return cache
def _fill_mla_cache(cache: torch.Tensor, kv_cache_dtype: str): def _fill_mla_cache(cache: torch.Tensor, kv_cache_dtype: str):
@ -488,7 +478,6 @@ def _fill_mla_cache(cache: torch.Tensor, kv_cache_dtype: str):
@pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE) @pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@pytest.mark.parametrize("align_cache", [False])
@torch.inference_mode() @torch.inference_mode()
def test_concat_and_cache_mla( def test_concat_and_cache_mla(
kv_lora_rank: int, kv_lora_rank: int,
@ -500,7 +489,6 @@ def test_concat_and_cache_mla(
seed: int, seed: int,
device: str, device: str,
kv_cache_dtype: str, kv_cache_dtype: str,
align_cache: bool,
) -> None: ) -> None:
current_platform.seed_everything(seed) current_platform.seed_everything(seed)
torch.set_default_device(device) torch.set_default_device(device)
@ -520,7 +508,7 @@ def test_concat_and_cache_mla(
scale = torch.tensor(0.1, dtype=torch.float32, device=device) scale = torch.tensor(0.1, dtype=torch.float32, device=device)
kv_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype, kv_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
kv_cache_dtype, device, align_cache) kv_cache_dtype, device)
ref_temp = torch.zeros(*kv_cache.shape, dtype=dtype, device=device) ref_temp = torch.zeros(*kv_cache.shape, dtype=dtype, device=device)
for i in range(num_tokens): for i in range(num_tokens):
@ -576,7 +564,6 @@ def test_concat_and_cache_mla(
@pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE) @pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@pytest.mark.parametrize("align_cache", [False, True])
@torch.inference_mode() @torch.inference_mode()
def test_copy_blocks_mla( def test_copy_blocks_mla(
kv_lora_rank: int, kv_lora_rank: int,
@ -588,7 +575,6 @@ def test_copy_blocks_mla(
seed: int, seed: int,
device: str, device: str,
kv_cache_dtype: str, kv_cache_dtype: str,
align_cache: bool,
) -> None: ) -> None:
current_platform.seed_everything(seed) current_platform.seed_everything(seed)
torch.set_default_device(device) torch.set_default_device(device)
@ -598,7 +584,7 @@ def test_copy_blocks_mla(
kv_caches = [] kv_caches = []
for _ in range(num_layers): for _ in range(num_layers):
kv_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype, kv_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
kv_cache_dtype, device, align_cache) kv_cache_dtype, device)
_fill_mla_cache(kv_cache, kv_cache_dtype=kv_cache_dtype) _fill_mla_cache(kv_cache, kv_cache_dtype=kv_cache_dtype)
kv_caches.append(kv_cache) kv_caches.append(kv_cache)
@ -642,7 +628,6 @@ def test_copy_blocks_mla(
@pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE) @pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@pytest.mark.parametrize("align_cache", [False, True])
@torch.inference_mode() @torch.inference_mode()
def test_swap_blocks_mla( def test_swap_blocks_mla(
kv_lora_rank: int, kv_lora_rank: int,
@ -653,7 +638,6 @@ def test_swap_blocks_mla(
seed: int, seed: int,
device: str, device: str,
kv_cache_dtype: str, kv_cache_dtype: str,
align_cache: bool,
) -> None: ) -> None:
current_platform.seed_everything(seed) current_platform.seed_everything(seed)
torch.set_default_device(device) torch.set_default_device(device)
@ -661,9 +645,9 @@ def test_swap_blocks_mla(
entry_size = kv_lora_rank + qk_rope_head_dim entry_size = kv_lora_rank + qk_rope_head_dim
src_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype, src_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
kv_cache_dtype, device, align_cache) kv_cache_dtype, device)
dst_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype, dst_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
kv_cache_dtype, device, align_cache) kv_cache_dtype, device)
_fill_mla_cache(src_cache, kv_cache_dtype) _fill_mla_cache(src_cache, kv_cache_dtype)
_fill_mla_cache(dst_cache, kv_cache_dtype) _fill_mla_cache(dst_cache, kv_cache_dtype)
@ -704,15 +688,14 @@ def test_swap_blocks_mla(
@pytest.mark.parametrize("dtype", [torch.float32]) @pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize("kv_cache_dtype", @pytest.mark.parametrize("kv_cache_dtype",
["auto"]) # You can also test "fp8" if needed. ["auto"]) # You can also test "fp8" if needed.
@pytest.mark.parametrize("align_cache", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode() @torch.inference_mode()
def test_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size, def test_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size,
num_blocks, max_seq_len, batch_size, dtype, num_blocks, max_seq_len, batch_size, dtype,
kv_cache_dtype, align_cache, device): kv_cache_dtype, device):
entry_size = kv_lora_rank + qk_rope_head_dim entry_size = kv_lora_rank + qk_rope_head_dim
src_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype, src_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
kv_cache_dtype, device, align_cache) kv_cache_dtype, device)
_fill_mla_cache(src_cache, kv_cache_dtype=kv_cache_dtype) _fill_mla_cache(src_cache, kv_cache_dtype=kv_cache_dtype)
seq_len_tensor = torch.randint(0, seq_len_tensor = torch.randint(0,

View File

@ -84,7 +84,6 @@ if TYPE_CHECKING:
VLLM_SERVER_DEV_MODE: bool = False VLLM_SERVER_DEV_MODE: bool = False
VLLM_V1_OUTPUT_PROC_CHUNK_SIZE: int = 128 VLLM_V1_OUTPUT_PROC_CHUNK_SIZE: int = 128
VLLM_MLA_DISABLE: bool = False VLLM_MLA_DISABLE: bool = False
VLLM_MLA_CUDA_MEM_ALIGN_KV_CACHE: bool = True
VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON: bool = False VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON: bool = False
VLLM_RAY_PER_WORKER_GPUS: float = 1.0 VLLM_RAY_PER_WORKER_GPUS: float = 1.0
VLLM_RAY_BUNDLE_INDICES: str = "" VLLM_RAY_BUNDLE_INDICES: str = ""
@ -580,15 +579,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_RAY_BUNDLE_INDICES": "VLLM_RAY_BUNDLE_INDICES":
lambda: os.getenv("VLLM_RAY_BUNDLE_INDICES", ""), lambda: os.getenv("VLLM_RAY_BUNDLE_INDICES", ""),
# On an NVIDIA GPU, pad individual entries (within a page) so that they are
# 256-byte aligned for better performance; this increases the memory usage
# of the cache. This matches the alignment the CUDA runtime uses for all
# allocations. Currently this primarily affects MLA, which results in
# entries that are not 256-byte aligned; for most other models the entries
# are already naturally aligned to 256 bytes.
"VLLM_CUDA_MEM_ALIGN_KV_CACHE":
lambda: bool(int(os.getenv("VLLM_CUDA_MEM_ALIGN_KV_CACHE", "1"))),
# On some systems, find_loaded_library() may not work. So we allow users to # On some systems, find_loaded_library() may not work. So we allow users to
# specify the path through environment variable VLLM_CUDART_SO_PATH. # specify the path through environment variable VLLM_CUDART_SO_PATH.
"VLLM_CUDART_SO_PATH": "VLLM_CUDART_SO_PATH":

View File

@ -827,12 +827,6 @@ def get_dtype_size(dtype: torch.dtype) -> int:
return torch.tensor([], dtype=dtype).element_size() return torch.tensor([], dtype=dtype).element_size()
def align_to_256bytes(extent: int, dtype: torch.dtype) -> int:
    """Round an element count up to a whole number of 256-byte chunks.

    Args:
        extent: Number of elements of ``dtype`` to be aligned.
        dtype: Element type; its size in bytes (via ``get_dtype_size``)
            determines how many elements fit in 256 bytes.

    Returns:
        The result of ``round_up(extent, 256 // <size of dtype in bytes>)``.
        NOTE(review): ``round_up`` is defined elsewhere; presumably it
        rounds ``extent`` up to the next multiple of its second argument.
    """
    # Number of ``dtype`` elements that fit in one 256-byte chunk.
    elements_per_chunk = 256 // get_dtype_size(dtype)
    return round_up(extent, elements_per_chunk)
# `collections` helpers # `collections` helpers
def is_list_of( def is_list_of(
value: object, value: object,

View File

@ -1,18 +1,14 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
"""CacheEngine class for managing the KV cache.""" """CacheEngine class for managing the KV cache."""
from math import prod
from typing import List from typing import List
import torch import torch
from vllm import envs
from vllm.attention import get_attn_backend from vllm.attention import get_attn_backend
from vllm.config import CacheConfig, DeviceConfig, ModelConfig, ParallelConfig from vllm.config import CacheConfig, DeviceConfig, ModelConfig, ParallelConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType, from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType,
align_to_256bytes, get_dtype_size, get_dtype_size, is_pin_memory_available)
is_pin_memory_available)
logger = init_logger(__name__) logger = init_logger(__name__)
@ -42,7 +38,6 @@ class CacheEngine:
self.num_attention_layers = model_config.get_num_layers_by_block_type( self.num_attention_layers = model_config.get_num_layers_by_block_type(
parallel_config, LayerBlockType.attention) parallel_config, LayerBlockType.attention)
self.num_kv_heads = model_config.get_num_kv_heads(parallel_config) self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)
self.align_cache = self._align_cache(model_config)
self.block_size = cache_config.block_size self.block_size = cache_config.block_size
self.num_gpu_blocks = cache_config.num_gpu_blocks self.num_gpu_blocks = cache_config.num_gpu_blocks
@ -81,38 +76,18 @@ class CacheEngine:
pin_memory = is_pin_memory_available() if device == "cpu" else False pin_memory = is_pin_memory_available() if device == "cpu" else False
kv_cache: List[torch.Tensor] = [] kv_cache: List[torch.Tensor] = []
# Pad entries so they are 256-byte aligned for better performance.
# This primarily targets MLA, whose entries typically end up only
# 128-byte aligned.
if self.align_cache:
# We assume the cache shape is:
# (TOTAL_PAGES, PAGE_SIZE, entry_shape...)
# NOTE this assumption currently only holds for MLA so we only apply
# this optimization when `use_mla` is true
entry_shape = kv_cache_shape[2:]
entry_size = prod(entry_shape)
alloc_entry_size = align_to_256bytes(entry_size, self.dtype)
alloc_shape = (*kv_cache_shape[:2], alloc_entry_size)
else:
alloc_shape = kv_cache_shape
for _ in range(self.num_attention_layers): for _ in range(self.num_attention_layers):
# null block in CpuGpuBlockAllocator requires at least that # null block in CpuGpuBlockAllocator requires at least that
# block to be zeroed-out. # block to be zeroed-out.
# We zero-out everything for simplicity. # We zero-out everything for simplicity.
layer_kv_cache = torch.zeros(alloc_shape, layer_kv_cache = torch.zeros(kv_cache_shape,
dtype=self.dtype, dtype=self.dtype,
pin_memory=pin_memory, pin_memory=pin_memory,
device=device) device=device)
# If we allocated with padding for alignment reasons truncate the
# shape while preserving the aligned stride
if self.align_cache:
layer_kv_cache = layer_kv_cache[..., :entry_size]
# view back to (TOTAL_PAGES, PAGE_SIZE, entry_shape...) for cases # view back to (TOTAL_PAGES, PAGE_SIZE, entry_shape...) for cases
# when entry_shape is higher than 1D # when entry_shape is higher than 1D
kv_cache.append(layer_kv_cache.view(kv_cache_shape)) kv_cache.append(layer_kv_cache)
return kv_cache return kv_cache
def swap_in(self, src_to_dst: torch.Tensor) -> None: def swap_in(self, src_to_dst: torch.Tensor) -> None:
@ -128,14 +103,6 @@ class CacheEngine:
def copy(self, src_to_dsts: torch.Tensor) -> None: def copy(self, src_to_dsts: torch.Tensor) -> None:
self.attn_backend.copy_blocks(self.gpu_cache, src_to_dsts) self.attn_backend.copy_blocks(self.gpu_cache, src_to_dsts)
@staticmethod
def _align_cache(model_config: ModelConfig):
    """Tell whether KV-cache entries should be allocated with alignment padding.

    The result is truthy only when the model uses MLA, the current platform
    reports CUDA, and ``envs.VLLM_CUDA_MEM_ALIGN_KV_CACHE`` is set.

    Args:
        model_config: Model configuration; only ``use_mla`` is read.

    Returns:
        The value of the short-circuited ``and`` chain over the three
        conditions above (the first falsy operand, or the last operand).
    """
    # Currently align_cache only applies to MLA models since the other
    # cache kernels haven't been updated yet to support non-contiguous
    # tensors
    return (model_config.use_mla and current_platform.is_cuda()
            and envs.VLLM_CUDA_MEM_ALIGN_KV_CACHE)
@staticmethod @staticmethod
def get_cache_block_size( def get_cache_block_size(
cache_config: CacheConfig, cache_config: CacheConfig,
@ -153,9 +120,6 @@ class CacheEngine:
dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype] dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
key_cache_entry = num_heads * head_size key_cache_entry = num_heads * head_size
if CacheEngine._align_cache(model_config):
key_cache_entry = align_to_256bytes(key_cache_entry,
model_config.dtype)
# For MLA there is no value cache, since the latent vector # For MLA there is no value cache, since the latent vector
# is joint keys and values. # is joint keys and values.