mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-07 13:42:16 +08:00
[Attention] Get rid of mla cache alignment (#14842)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
parent
a2ae496589
commit
5952d8ab61
@ -8,7 +8,6 @@ import torch
|
|||||||
from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck
|
from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck
|
||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.utils import align_to_256bytes
|
|
||||||
|
|
||||||
COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')]
|
COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')]
|
||||||
DTYPES = [torch.half, torch.bfloat16, torch.float]
|
DTYPES = [torch.half, torch.bfloat16, torch.float]
|
||||||
@ -450,22 +449,13 @@ def _create_mla_cache(
|
|||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
kv_cache_dtype: str,
|
kv_cache_dtype: str,
|
||||||
device: str,
|
device: str,
|
||||||
align_cache: bool,
|
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
cache_dtype = torch.uint8 if kv_cache_dtype == "fp8" else dtype
|
cache_dtype = torch.uint8 if kv_cache_dtype == "fp8" else dtype
|
||||||
|
return torch.zeros(num_blocks,
|
||||||
if align_cache:
|
block_size,
|
||||||
alloc_entry_size = align_to_256bytes(entry_size, cache_dtype)
|
entry_size,
|
||||||
alloc_shape = (num_blocks, block_size, alloc_entry_size)
|
dtype=cache_dtype,
|
||||||
cache_full = torch.zeros(alloc_shape, dtype=cache_dtype, device=device)
|
device=device)
|
||||||
cache = cache_full[..., :entry_size]
|
|
||||||
else:
|
|
||||||
cache = torch.zeros(num_blocks,
|
|
||||||
block_size,
|
|
||||||
entry_size,
|
|
||||||
dtype=cache_dtype,
|
|
||||||
device=device)
|
|
||||||
return cache
|
|
||||||
|
|
||||||
|
|
||||||
def _fill_mla_cache(cache: torch.Tensor, kv_cache_dtype: str):
|
def _fill_mla_cache(cache: torch.Tensor, kv_cache_dtype: str):
|
||||||
@ -488,7 +478,6 @@ def _fill_mla_cache(cache: torch.Tensor, kv_cache_dtype: str):
|
|||||||
@pytest.mark.parametrize("seed", SEEDS)
|
@pytest.mark.parametrize("seed", SEEDS)
|
||||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||||
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
|
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
|
||||||
@pytest.mark.parametrize("align_cache", [False])
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def test_concat_and_cache_mla(
|
def test_concat_and_cache_mla(
|
||||||
kv_lora_rank: int,
|
kv_lora_rank: int,
|
||||||
@ -500,7 +489,6 @@ def test_concat_and_cache_mla(
|
|||||||
seed: int,
|
seed: int,
|
||||||
device: str,
|
device: str,
|
||||||
kv_cache_dtype: str,
|
kv_cache_dtype: str,
|
||||||
align_cache: bool,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
current_platform.seed_everything(seed)
|
current_platform.seed_everything(seed)
|
||||||
torch.set_default_device(device)
|
torch.set_default_device(device)
|
||||||
@ -520,7 +508,7 @@ def test_concat_and_cache_mla(
|
|||||||
|
|
||||||
scale = torch.tensor(0.1, dtype=torch.float32, device=device)
|
scale = torch.tensor(0.1, dtype=torch.float32, device=device)
|
||||||
kv_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
|
kv_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
|
||||||
kv_cache_dtype, device, align_cache)
|
kv_cache_dtype, device)
|
||||||
ref_temp = torch.zeros(*kv_cache.shape, dtype=dtype, device=device)
|
ref_temp = torch.zeros(*kv_cache.shape, dtype=dtype, device=device)
|
||||||
|
|
||||||
for i in range(num_tokens):
|
for i in range(num_tokens):
|
||||||
@ -576,7 +564,6 @@ def test_concat_and_cache_mla(
|
|||||||
@pytest.mark.parametrize("seed", SEEDS)
|
@pytest.mark.parametrize("seed", SEEDS)
|
||||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||||
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
|
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
|
||||||
@pytest.mark.parametrize("align_cache", [False, True])
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def test_copy_blocks_mla(
|
def test_copy_blocks_mla(
|
||||||
kv_lora_rank: int,
|
kv_lora_rank: int,
|
||||||
@ -588,7 +575,6 @@ def test_copy_blocks_mla(
|
|||||||
seed: int,
|
seed: int,
|
||||||
device: str,
|
device: str,
|
||||||
kv_cache_dtype: str,
|
kv_cache_dtype: str,
|
||||||
align_cache: bool,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
current_platform.seed_everything(seed)
|
current_platform.seed_everything(seed)
|
||||||
torch.set_default_device(device)
|
torch.set_default_device(device)
|
||||||
@ -598,7 +584,7 @@ def test_copy_blocks_mla(
|
|||||||
kv_caches = []
|
kv_caches = []
|
||||||
for _ in range(num_layers):
|
for _ in range(num_layers):
|
||||||
kv_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
|
kv_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
|
||||||
kv_cache_dtype, device, align_cache)
|
kv_cache_dtype, device)
|
||||||
_fill_mla_cache(kv_cache, kv_cache_dtype=kv_cache_dtype)
|
_fill_mla_cache(kv_cache, kv_cache_dtype=kv_cache_dtype)
|
||||||
kv_caches.append(kv_cache)
|
kv_caches.append(kv_cache)
|
||||||
|
|
||||||
@ -642,7 +628,6 @@ def test_copy_blocks_mla(
|
|||||||
@pytest.mark.parametrize("seed", SEEDS)
|
@pytest.mark.parametrize("seed", SEEDS)
|
||||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||||
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
|
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
|
||||||
@pytest.mark.parametrize("align_cache", [False, True])
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def test_swap_blocks_mla(
|
def test_swap_blocks_mla(
|
||||||
kv_lora_rank: int,
|
kv_lora_rank: int,
|
||||||
@ -653,7 +638,6 @@ def test_swap_blocks_mla(
|
|||||||
seed: int,
|
seed: int,
|
||||||
device: str,
|
device: str,
|
||||||
kv_cache_dtype: str,
|
kv_cache_dtype: str,
|
||||||
align_cache: bool,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
current_platform.seed_everything(seed)
|
current_platform.seed_everything(seed)
|
||||||
torch.set_default_device(device)
|
torch.set_default_device(device)
|
||||||
@ -661,9 +645,9 @@ def test_swap_blocks_mla(
|
|||||||
entry_size = kv_lora_rank + qk_rope_head_dim
|
entry_size = kv_lora_rank + qk_rope_head_dim
|
||||||
|
|
||||||
src_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
|
src_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
|
||||||
kv_cache_dtype, device, align_cache)
|
kv_cache_dtype, device)
|
||||||
dst_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
|
dst_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
|
||||||
kv_cache_dtype, device, align_cache)
|
kv_cache_dtype, device)
|
||||||
|
|
||||||
_fill_mla_cache(src_cache, kv_cache_dtype)
|
_fill_mla_cache(src_cache, kv_cache_dtype)
|
||||||
_fill_mla_cache(dst_cache, kv_cache_dtype)
|
_fill_mla_cache(dst_cache, kv_cache_dtype)
|
||||||
@ -704,15 +688,14 @@ def test_swap_blocks_mla(
|
|||||||
@pytest.mark.parametrize("dtype", [torch.float32])
|
@pytest.mark.parametrize("dtype", [torch.float32])
|
||||||
@pytest.mark.parametrize("kv_cache_dtype",
|
@pytest.mark.parametrize("kv_cache_dtype",
|
||||||
["auto"]) # You can also test "fp8" if needed.
|
["auto"]) # You can also test "fp8" if needed.
|
||||||
@pytest.mark.parametrize("align_cache", [True, False])
|
|
||||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def test_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size,
|
def test_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size,
|
||||||
num_blocks, max_seq_len, batch_size, dtype,
|
num_blocks, max_seq_len, batch_size, dtype,
|
||||||
kv_cache_dtype, align_cache, device):
|
kv_cache_dtype, device):
|
||||||
entry_size = kv_lora_rank + qk_rope_head_dim
|
entry_size = kv_lora_rank + qk_rope_head_dim
|
||||||
src_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
|
src_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
|
||||||
kv_cache_dtype, device, align_cache)
|
kv_cache_dtype, device)
|
||||||
_fill_mla_cache(src_cache, kv_cache_dtype=kv_cache_dtype)
|
_fill_mla_cache(src_cache, kv_cache_dtype=kv_cache_dtype)
|
||||||
|
|
||||||
seq_len_tensor = torch.randint(0,
|
seq_len_tensor = torch.randint(0,
|
||||||
|
|||||||
10
vllm/envs.py
10
vllm/envs.py
@ -84,7 +84,6 @@ if TYPE_CHECKING:
|
|||||||
VLLM_SERVER_DEV_MODE: bool = False
|
VLLM_SERVER_DEV_MODE: bool = False
|
||||||
VLLM_V1_OUTPUT_PROC_CHUNK_SIZE: int = 128
|
VLLM_V1_OUTPUT_PROC_CHUNK_SIZE: int = 128
|
||||||
VLLM_MLA_DISABLE: bool = False
|
VLLM_MLA_DISABLE: bool = False
|
||||||
VLLM_MLA_CUDA_MEM_ALIGN_KV_CACHE: bool = True
|
|
||||||
VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON: bool = False
|
VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON: bool = False
|
||||||
VLLM_RAY_PER_WORKER_GPUS: float = 1.0
|
VLLM_RAY_PER_WORKER_GPUS: float = 1.0
|
||||||
VLLM_RAY_BUNDLE_INDICES: str = ""
|
VLLM_RAY_BUNDLE_INDICES: str = ""
|
||||||
@ -580,15 +579,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
|||||||
"VLLM_RAY_BUNDLE_INDICES":
|
"VLLM_RAY_BUNDLE_INDICES":
|
||||||
lambda: os.getenv("VLLM_RAY_BUNDLE_INDICES", ""),
|
lambda: os.getenv("VLLM_RAY_BUNDLE_INDICES", ""),
|
||||||
|
|
||||||
# When on a Nvidia GPU aligns single entries (within a page) so they are 256
|
|
||||||
# byte aligned for better performance, this increases the memory usage of
|
|
||||||
# the cache. Currently this only affects MLA that results in non-256
|
|
||||||
# byte aligned entries. This matches the alignment the CUDA runtime uses
|
|
||||||
# for all allocations. Currently this primarily affects MLA, for most other
|
|
||||||
# models the alignment is already naturally aligned to 256 bytes.
|
|
||||||
"VLLM_CUDA_MEM_ALIGN_KV_CACHE":
|
|
||||||
lambda: bool(int(os.getenv("VLLM_CUDA_MEM_ALIGN_KV_CACHE", "1"))),
|
|
||||||
|
|
||||||
# In some system, find_loaded_library() may not work. So we allow users to
|
# In some system, find_loaded_library() may not work. So we allow users to
|
||||||
# specify the path through environment variable VLLM_CUDART_SO_PATH.
|
# specify the path through environment variable VLLM_CUDART_SO_PATH.
|
||||||
"VLLM_CUDART_SO_PATH":
|
"VLLM_CUDART_SO_PATH":
|
||||||
|
|||||||
@ -827,12 +827,6 @@ def get_dtype_size(dtype: torch.dtype) -> int:
|
|||||||
return torch.tensor([], dtype=dtype).element_size()
|
return torch.tensor([], dtype=dtype).element_size()
|
||||||
|
|
||||||
|
|
||||||
def align_to_256bytes(extent: int, dtype: torch.dtype) -> int:
|
|
||||||
dtype_size = get_dtype_size(dtype)
|
|
||||||
eles_per_256bytes = 256 // dtype_size
|
|
||||||
return round_up(extent, eles_per_256bytes)
|
|
||||||
|
|
||||||
|
|
||||||
# `collections` helpers
|
# `collections` helpers
|
||||||
def is_list_of(
|
def is_list_of(
|
||||||
value: object,
|
value: object,
|
||||||
|
|||||||
@ -1,18 +1,14 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
"""CacheEngine class for managing the KV cache."""
|
"""CacheEngine class for managing the KV cache."""
|
||||||
from math import prod
|
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm import envs
|
|
||||||
from vllm.attention import get_attn_backend
|
from vllm.attention import get_attn_backend
|
||||||
from vllm.config import CacheConfig, DeviceConfig, ModelConfig, ParallelConfig
|
from vllm.config import CacheConfig, DeviceConfig, ModelConfig, ParallelConfig
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.platforms import current_platform
|
|
||||||
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType,
|
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType,
|
||||||
align_to_256bytes, get_dtype_size,
|
get_dtype_size, is_pin_memory_available)
|
||||||
is_pin_memory_available)
|
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@ -42,7 +38,6 @@ class CacheEngine:
|
|||||||
self.num_attention_layers = model_config.get_num_layers_by_block_type(
|
self.num_attention_layers = model_config.get_num_layers_by_block_type(
|
||||||
parallel_config, LayerBlockType.attention)
|
parallel_config, LayerBlockType.attention)
|
||||||
self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)
|
self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)
|
||||||
self.align_cache = self._align_cache(model_config)
|
|
||||||
|
|
||||||
self.block_size = cache_config.block_size
|
self.block_size = cache_config.block_size
|
||||||
self.num_gpu_blocks = cache_config.num_gpu_blocks
|
self.num_gpu_blocks = cache_config.num_gpu_blocks
|
||||||
@ -81,38 +76,18 @@ class CacheEngine:
|
|||||||
pin_memory = is_pin_memory_available() if device == "cpu" else False
|
pin_memory = is_pin_memory_available() if device == "cpu" else False
|
||||||
kv_cache: List[torch.Tensor] = []
|
kv_cache: List[torch.Tensor] = []
|
||||||
|
|
||||||
# Align entries so they are 256 byte aligned for better performance
|
|
||||||
# Primarily targets MLA as this typically only ends up having entries
|
|
||||||
# be 128 byte aligned.
|
|
||||||
if self.align_cache:
|
|
||||||
# We assume the cache shape is:
|
|
||||||
# (TOTAL_PAGES, PAGE_SIZE, entry_shape...)
|
|
||||||
# NOTE this assumption currently only holds for MLA so we only apply
|
|
||||||
# this optimization when `use_mla` is true
|
|
||||||
entry_shape = kv_cache_shape[2:]
|
|
||||||
entry_size = prod(entry_shape)
|
|
||||||
alloc_entry_size = align_to_256bytes(entry_size, self.dtype)
|
|
||||||
alloc_shape = (*kv_cache_shape[:2], alloc_entry_size)
|
|
||||||
else:
|
|
||||||
alloc_shape = kv_cache_shape
|
|
||||||
|
|
||||||
for _ in range(self.num_attention_layers):
|
for _ in range(self.num_attention_layers):
|
||||||
# null block in CpuGpuBlockAllocator requires at least that
|
# null block in CpuGpuBlockAllocator requires at least that
|
||||||
# block to be zeroed-out.
|
# block to be zeroed-out.
|
||||||
# We zero-out everything for simplicity.
|
# We zero-out everything for simplicity.
|
||||||
layer_kv_cache = torch.zeros(alloc_shape,
|
layer_kv_cache = torch.zeros(kv_cache_shape,
|
||||||
dtype=self.dtype,
|
dtype=self.dtype,
|
||||||
pin_memory=pin_memory,
|
pin_memory=pin_memory,
|
||||||
device=device)
|
device=device)
|
||||||
|
|
||||||
# If we allocated with padding for alignment reasons truncate the
|
|
||||||
# shape while preserving the aligned stride
|
|
||||||
if self.align_cache:
|
|
||||||
layer_kv_cache = layer_kv_cache[..., :entry_size]
|
|
||||||
|
|
||||||
# view back to (TOTAL_PAGES, PAGE_SIZE, entry_shape...) for cases
|
# view back to (TOTAL_PAGES, PAGE_SIZE, entry_shape...) for cases
|
||||||
# when entry_shape is higher than 1D
|
# when entry_shape is higher than 1D
|
||||||
kv_cache.append(layer_kv_cache.view(kv_cache_shape))
|
kv_cache.append(layer_kv_cache)
|
||||||
return kv_cache
|
return kv_cache
|
||||||
|
|
||||||
def swap_in(self, src_to_dst: torch.Tensor) -> None:
|
def swap_in(self, src_to_dst: torch.Tensor) -> None:
|
||||||
@ -128,14 +103,6 @@ class CacheEngine:
|
|||||||
def copy(self, src_to_dsts: torch.Tensor) -> None:
|
def copy(self, src_to_dsts: torch.Tensor) -> None:
|
||||||
self.attn_backend.copy_blocks(self.gpu_cache, src_to_dsts)
|
self.attn_backend.copy_blocks(self.gpu_cache, src_to_dsts)
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _align_cache(model_config: ModelConfig):
|
|
||||||
# Currently align_cache only applies to MLA models since the other
|
|
||||||
# cache kernels haven't been updated yet to support non-continguous
|
|
||||||
# tensors
|
|
||||||
return model_config.use_mla and current_platform.is_cuda() \
|
|
||||||
and envs.VLLM_CUDA_MEM_ALIGN_KV_CACHE
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_cache_block_size(
|
def get_cache_block_size(
|
||||||
cache_config: CacheConfig,
|
cache_config: CacheConfig,
|
||||||
@ -153,9 +120,6 @@ class CacheEngine:
|
|||||||
dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
|
dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
|
||||||
|
|
||||||
key_cache_entry = num_heads * head_size
|
key_cache_entry = num_heads * head_size
|
||||||
if CacheEngine._align_cache(model_config):
|
|
||||||
key_cache_entry = align_to_256bytes(key_cache_entry,
|
|
||||||
model_config.dtype)
|
|
||||||
|
|
||||||
# For MLA there is no value cache, since the latent vector
|
# For MLA there is no value cache, since the latent vector
|
||||||
# is joint keys and values.
|
# is joint keys and values.
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user