From 5952d8ab61a39eefd3617b7d46b7a6bd87f51259 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Sat, 15 Mar 2025 01:08:25 -0400 Subject: [PATCH] [Attention] Get rid of mla cache alignment (#14842) Signed-off-by: Lucas Wilkinson --- tests/kernels/test_cache.py | 39 ++++++++++------------------------ vllm/envs.py | 10 --------- vllm/utils.py | 6 ------ vllm/worker/cache_engine.py | 42 +++---------------------------------- 4 files changed, 14 insertions(+), 83 deletions(-) diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py index b55ebd967fd7c..f7936989c9639 100644 --- a/tests/kernels/test_cache.py +++ b/tests/kernels/test_cache.py @@ -8,7 +8,6 @@ import torch from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck from vllm import _custom_ops as ops from vllm.platforms import current_platform -from vllm.utils import align_to_256bytes COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')] DTYPES = [torch.half, torch.bfloat16, torch.float] @@ -450,22 +449,13 @@ def _create_mla_cache( dtype: torch.dtype, kv_cache_dtype: str, device: str, - align_cache: bool, ) -> torch.Tensor: cache_dtype = torch.uint8 if kv_cache_dtype == "fp8" else dtype - - if align_cache: - alloc_entry_size = align_to_256bytes(entry_size, cache_dtype) - alloc_shape = (num_blocks, block_size, alloc_entry_size) - cache_full = torch.zeros(alloc_shape, dtype=cache_dtype, device=device) - cache = cache_full[..., :entry_size] - else: - cache = torch.zeros(num_blocks, - block_size, - entry_size, - dtype=cache_dtype, - device=device) - return cache + return torch.zeros(num_blocks, + block_size, + entry_size, + dtype=cache_dtype, + device=device) def _fill_mla_cache(cache: torch.Tensor, kv_cache_dtype: str): @@ -488,7 +478,6 @@ def _fill_mla_cache(cache: torch.Tensor, kv_cache_dtype: str): @pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE) -@pytest.mark.parametrize("align_cache", [False]) @torch.inference_mode() def test_concat_and_cache_mla( kv_lora_rank: int, @@ -500,7 +489,6 @@ def test_concat_and_cache_mla( seed: int, device: str, kv_cache_dtype: str, - align_cache: bool, ) -> None: current_platform.seed_everything(seed) torch.set_default_device(device) @@ -520,7 +508,7 @@ def test_concat_and_cache_mla( scale = torch.tensor(0.1, dtype=torch.float32, device=device) kv_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype, - kv_cache_dtype, device, align_cache) + kv_cache_dtype, device) ref_temp = torch.zeros(*kv_cache.shape, dtype=dtype, device=device) for i in range(num_tokens): @@ -576,7 +564,6 @@ def test_concat_and_cache_mla( @pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE) -@pytest.mark.parametrize("align_cache", [False, True]) @torch.inference_mode() def test_copy_blocks_mla( kv_lora_rank: int, @@ -588,7 +575,6 @@ def test_copy_blocks_mla( seed: int, device: str, kv_cache_dtype: str, - align_cache: bool, ) -> None: current_platform.seed_everything(seed) torch.set_default_device(device) @@ -598,7 +584,7 @@ def test_copy_blocks_mla( kv_caches = [] for _ in range(num_layers): kv_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype, - kv_cache_dtype, device, align_cache) + kv_cache_dtype, device) _fill_mla_cache(kv_cache, kv_cache_dtype=kv_cache_dtype) kv_caches.append(kv_cache) @@ -642,7 +628,6 @@ def test_copy_blocks_mla( @pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE) -@pytest.mark.parametrize("align_cache", [False, True]) @torch.inference_mode() def test_swap_blocks_mla( kv_lora_rank: int, @@ -653,7 +638,6 @@ def test_swap_blocks_mla( seed: int, device: str, kv_cache_dtype: str, - align_cache: bool, ) -> None: current_platform.seed_everything(seed) torch.set_default_device(device) @@ -661,9 +645,9 @@ def test_swap_blocks_mla( entry_size = kv_lora_rank + qk_rope_head_dim src_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype, - kv_cache_dtype, device, align_cache) + kv_cache_dtype, device) dst_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype, - kv_cache_dtype, device, align_cache) + kv_cache_dtype, device) _fill_mla_cache(src_cache, kv_cache_dtype) _fill_mla_cache(dst_cache, kv_cache_dtype) @@ -704,15 +688,14 @@ def test_swap_blocks_mla( @pytest.mark.parametrize("dtype", [torch.float32]) @pytest.mark.parametrize("kv_cache_dtype", ["auto"]) # You can also test "fp8" if needed. -@pytest.mark.parametrize("align_cache", [True, False]) @pytest.mark.parametrize("device", CUDA_DEVICES) @torch.inference_mode() def test_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size, num_blocks, max_seq_len, batch_size, dtype, - kv_cache_dtype, align_cache, device): + kv_cache_dtype, device): entry_size = kv_lora_rank + qk_rope_head_dim src_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype, - kv_cache_dtype, device, align_cache) + kv_cache_dtype, device) _fill_mla_cache(src_cache, kv_cache_dtype=kv_cache_dtype) seq_len_tensor = torch.randint(0, diff --git a/vllm/envs.py b/vllm/envs.py index 7e079006b273c..463059dc06704 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -84,7 +84,6 @@ if TYPE_CHECKING: VLLM_SERVER_DEV_MODE: bool = False VLLM_V1_OUTPUT_PROC_CHUNK_SIZE: int = 128 VLLM_MLA_DISABLE: bool = False - VLLM_MLA_CUDA_MEM_ALIGN_KV_CACHE: bool = True VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON: bool = False VLLM_RAY_PER_WORKER_GPUS: float = 1.0 VLLM_RAY_BUNDLE_INDICES: str = "" @@ -580,15 +579,6 @@ environment_variables: dict[str, Callable[[], Any]] = { "VLLM_RAY_BUNDLE_INDICES": lambda: os.getenv("VLLM_RAY_BUNDLE_INDICES", ""), - # When on a Nvidia GPU aligns single entries (within a page) so they are 256 - # byte aligned for better performance, this increases the memory usage of - # the cache. Currently this only affects MLA that results in non-256 - # byte aligned entries. This matches the alignment the CUDA runtime uses - # for all allocations. Currently this primarily affects MLA, for most other - # models the alignment is already naturally aligned to 256 bytes. - "VLLM_CUDA_MEM_ALIGN_KV_CACHE": - lambda: bool(int(os.getenv("VLLM_CUDA_MEM_ALIGN_KV_CACHE", "1"))), - # In some system, find_loaded_library() may not work. So we allow users to # specify the path through environment variable VLLM_CUDART_SO_PATH. "VLLM_CUDART_SO_PATH": diff --git a/vllm/utils.py b/vllm/utils.py index a8eba27dbcdbd..9334741225008 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -827,12 +827,6 @@ def get_dtype_size(dtype: torch.dtype) -> int: return torch.tensor([], dtype=dtype).element_size() -def align_to_256bytes(extent: int, dtype: torch.dtype) -> int: - dtype_size = get_dtype_size(dtype) - eles_per_256bytes = 256 // dtype_size - return round_up(extent, eles_per_256bytes) - - # `collections` helpers def is_list_of( value: object, diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py index 004b4e4b757fd..85ebe8121e524 100644 --- a/vllm/worker/cache_engine.py +++ b/vllm/worker/cache_engine.py @@ -1,18 +1,14 @@ # SPDX-License-Identifier: Apache-2.0 """CacheEngine class for managing the KV cache.""" -from math import prod from typing import List import torch -from vllm import envs from vllm.attention import get_attn_backend from vllm.config import CacheConfig, DeviceConfig, ModelConfig, ParallelConfig from vllm.logger import init_logger -from vllm.platforms import current_platform from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType, - align_to_256bytes, get_dtype_size, - is_pin_memory_available) + get_dtype_size, is_pin_memory_available) logger = init_logger(__name__) @@ -42,7 +38,6 @@ class CacheEngine: self.num_attention_layers = model_config.get_num_layers_by_block_type( parallel_config, LayerBlockType.attention) self.num_kv_heads = model_config.get_num_kv_heads(parallel_config) - self.align_cache = self._align_cache(model_config) self.block_size = cache_config.block_size self.num_gpu_blocks = cache_config.num_gpu_blocks @@ -81,38 +76,18 @@ class CacheEngine: pin_memory = is_pin_memory_available() if device == "cpu" else False kv_cache: List[torch.Tensor] = [] - # Align entries so they are 256 byte aligned for better performance - # Primarily targets MLA as this typically only ends up having entries - # be 128 byte aligned. - if self.align_cache: - # We assume the cache shape is: - # (TOTAL_PAGES, PAGE_SIZE, entry_shape...) - # NOTE this assumption currently only holds for MLA so we only apply - # this optimization when `use_mla` is true - entry_shape = kv_cache_shape[2:] - entry_size = prod(entry_shape) - alloc_entry_size = align_to_256bytes(entry_size, self.dtype) - alloc_shape = (*kv_cache_shape[:2], alloc_entry_size) - else: - alloc_shape = kv_cache_shape - for _ in range(self.num_attention_layers): # null block in CpuGpuBlockAllocator requires at least that # block to be zeroed-out. # We zero-out everything for simplicity. - layer_kv_cache = torch.zeros(alloc_shape, + layer_kv_cache = torch.zeros(kv_cache_shape, dtype=self.dtype, pin_memory=pin_memory, device=device) - # If we allocated with padding for alignment reasons truncate the - # shape while preserving the aligned stride - if self.align_cache: - layer_kv_cache = layer_kv_cache[..., :entry_size] - # view back to (TOTAL_PAGES, PAGE_SIZE, entry_shape...) for cases # when entry_shape is higher than 1D - kv_cache.append(layer_kv_cache.view(kv_cache_shape)) + kv_cache.append(layer_kv_cache) return kv_cache def swap_in(self, src_to_dst: torch.Tensor) -> None: @@ -128,14 +103,6 @@ class CacheEngine: def copy(self, src_to_dsts: torch.Tensor) -> None: self.attn_backend.copy_blocks(self.gpu_cache, src_to_dsts) - @staticmethod - def _align_cache(model_config: ModelConfig): - # Currently align_cache only applies to MLA models since the other - # cache kernels haven't been updated yet to support non-continguous - # tensors - return model_config.use_mla and current_platform.is_cuda() \ - and envs.VLLM_CUDA_MEM_ALIGN_KV_CACHE - @staticmethod def get_cache_block_size( cache_config: CacheConfig, @@ -153,9 +120,6 @@ class CacheEngine: dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype] key_cache_entry = num_heads * head_size - if CacheEngine._align_cache(model_config): - key_cache_entry = align_to_256bytes(key_cache_entry, - model_config.dtype) # For MLA there is no value cache, since the latent vector # is joint keys and values.