[Core] Reuse empty block lists whenever possible in KVCacheBlocks to mitigate GC costs (#24964)

Signed-off-by: Jialin Ouyang <Jialin.Ouyang@gmail.com>
This commit is contained in:
Jialin Ouyang 2025-10-14 12:58:43 -07:00 committed by GitHub
parent 82af928c41
commit acaa2c0a4a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 53 additions and 26 deletions

View File

@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable from collections.abc import Iterable, Sequence
from typing import Any from typing import Any
from vllm.distributed.kv_events import ( from vllm.distributed.kv_events import (
@ -328,7 +328,7 @@ class BlockPool:
) )
return True return True
def touch(self, blocks: tuple[list[KVCacheBlock], ...]) -> None: def touch(self, blocks: tuple[Sequence[KVCacheBlock], ...]) -> None:
"""Touch a block increases its reference count by 1, and may remove """Touch a block increases its reference count by 1, and may remove
the block from the free queue. This is used when a block is hit by the block from the free queue. This is used when a block is hit by
another request with the same prefix. another request with the same prefix.

View File

@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Sequence
from vllm.v1.core.block_pool import BlockPool from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock
@ -51,7 +52,7 @@ class KVCacheCoordinator(ABC):
self, self,
request_id: str, request_id: str,
num_tokens: int, num_tokens: int,
new_computed_blocks: tuple[list[KVCacheBlock], ...], new_computed_blocks: tuple[Sequence[KVCacheBlock], ...],
num_encoder_tokens: int, num_encoder_tokens: int,
) -> int: ) -> int:
""" """
@ -84,7 +85,7 @@ class KVCacheCoordinator(ABC):
return num_blocks_to_allocate return num_blocks_to_allocate
def save_new_computed_blocks( def save_new_computed_blocks(
self, request_id: str, new_computed_blocks: tuple[list[KVCacheBlock], ...] self, request_id: str, new_computed_blocks: tuple[Sequence[KVCacheBlock], ...]
) -> None: ) -> None:
""" """
Add the new computed blocks to the request. Add the new computed blocks to the request.

View File

@ -1,6 +1,8 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools
from collections.abc import Sequence
from dataclasses import dataclass from dataclasses import dataclass
from typing import Literal, overload from typing import Literal, overload
@ -23,7 +25,7 @@ class KVCacheBlocks:
structure from the Scheduler. structure from the Scheduler.
""" """
blocks: tuple[list[KVCacheBlock], ...] blocks: tuple[Sequence[KVCacheBlock], ...]
""" """
`blocks[i][j]` refers to the i-th kv_cache_group `blocks[i][j]` refers to the i-th kv_cache_group
and the j-th block of tokens.We don't use block of and the j-th block of tokens.We don't use block of
@ -31,12 +33,20 @@ class KVCacheBlocks:
kv_cache_groups have the same number of blocks, which is true for now but kv_cache_groups have the same number of blocks, which is true for now but
will be broken if we want to give different block_size to different will be broken if we want to give different block_size to different
kv_cache_groups in the future. kv_cache_groups in the future.
Each single type KVCacheBlocks could be represented as:
- list[KVCacheBlock] for more than one KVCacheBlock
- an empty tuple for requests without KVCacheBlock
(a precomputed KVCacheBlocks is in KVCacheManager to avoid GC overhead)
""" """
def __add__(self, other: "KVCacheBlocks") -> "KVCacheBlocks": def __add__(self, other: "KVCacheBlocks") -> "KVCacheBlocks":
"""Adds two KVCacheBlocks instances.""" """Adds two KVCacheBlocks instances."""
return KVCacheBlocks( return KVCacheBlocks(
tuple(blk1 + blk2 for blk1, blk2 in zip(self.blocks, other.blocks)) tuple(
list(itertools.chain(blk1, blk2))
for blk1, blk2 in zip(self.blocks, other.blocks)
)
) )
@overload @overload
@ -74,8 +84,10 @@ class KVCacheBlocks:
return [block.block_id for block in self.blocks[0] if block.block_hash is None] return [block.block_id for block in self.blocks[0] if block.block_hash is None]
def new_empty(self) -> "KVCacheBlocks": def new_empty(self) -> "KVCacheBlocks":
"""Creates a new KVCacheBlocks instance with no blocks.""" """
return KVCacheBlocks(tuple([] for _ in range(len(self.blocks)))) Creates a new KVCacheBlocks instance with no blocks.
"""
return KVCacheBlocks(tuple(() for _ in range(len(self.blocks))))
class KVCacheManager: class KVCacheManager:
@ -131,6 +143,15 @@ class KVCacheManager:
self.block_pool = self.coordinator.block_pool self.block_pool = self.coordinator.block_pool
self.kv_cache_config = kv_cache_config self.kv_cache_config = kv_cache_config
# Pre-constructed KVCacheBlocks with no blocks, callers should use this
# via create_kv_cache_blocks instead of creating new ones to avoid GC
# overhead.
#
# We use nested tuples to ensure the empty KVCacheBlocks is immutable.
self.empty_kv_cache_blocks = KVCacheBlocks(
tuple(() for _ in range(self.num_kv_cache_groups))
)
@property @property
def usage(self) -> float: def usage(self) -> float:
"""Get the KV cache usage. """Get the KV cache usage.
@ -170,7 +191,7 @@ class KVCacheManager:
request.sampling_params is not None request.sampling_params is not None
and request.sampling_params.prompt_logprobs is not None and request.sampling_params.prompt_logprobs is not None
): ):
return self.create_empty_block_list(), 0 return self.empty_kv_cache_blocks, 0
# NOTE: When all tokens hit the cache, we must recompute the last token # NOTE: When all tokens hit the cache, we must recompute the last token
# to obtain logits. Thus, set max_cache_hit_length to prompt_length - 1. # to obtain logits. Thus, set max_cache_hit_length to prompt_length - 1.
@ -198,7 +219,7 @@ class KVCacheManager:
self.prefix_cache_stats.queries += request.num_tokens self.prefix_cache_stats.queries += request.num_tokens
self.prefix_cache_stats.hits += num_new_computed_tokens self.prefix_cache_stats.hits += num_new_computed_tokens
return KVCacheBlocks(computed_blocks), num_new_computed_tokens return (self.create_kv_cache_blocks(computed_blocks), num_new_computed_tokens)
def allocate_slots( def allocate_slots(
self, self,
@ -251,9 +272,7 @@ class KVCacheManager:
if new_computed_blocks is not None: if new_computed_blocks is not None:
new_computed_block_list = new_computed_blocks.blocks new_computed_block_list = new_computed_blocks.blocks
else: else:
new_computed_block_list = tuple( new_computed_block_list = self.empty_kv_cache_blocks.blocks
[] for _ in range(len(self.kv_cache_config.kv_cache_groups))
)
# Free the blocks that are skipped during the attention computation # Free the blocks that are skipped during the attention computation
# (e.g., tokens outside the sliding window). # (e.g., tokens outside the sliding window).
@ -305,7 +324,7 @@ class KVCacheManager:
# P/D: delay caching blocks if we have to recv from # P/D: delay caching blocks if we have to recv from
# remote. Update state for locally cached blocks. # remote. Update state for locally cached blocks.
if not self.enable_caching or delay_cache_blocks: if not self.enable_caching or delay_cache_blocks:
return KVCacheBlocks(new_blocks) return self.create_kv_cache_blocks(new_blocks)
# NOTE(woosuk): We want to commit (cache) up to num_computed_tokens + # NOTE(woosuk): We want to commit (cache) up to num_computed_tokens +
# num_new_tokens, but must exclude "non-committable" tokens (e.g., # num_new_tokens, but must exclude "non-committable" tokens (e.g.,
@ -316,7 +335,7 @@ class KVCacheManager:
) )
self.coordinator.cache_blocks(request, num_tokens_to_cache) self.coordinator.cache_blocks(request, num_tokens_to_cache)
return KVCacheBlocks(new_blocks) return self.create_kv_cache_blocks(new_blocks)
def free(self, request: Request) -> None: def free(self, request: Request) -> None:
"""Free the blocks allocated for the request. """Free the blocks allocated for the request.
@ -388,7 +407,7 @@ class KVCacheManager:
def get_blocks(self, request_id: str) -> KVCacheBlocks: def get_blocks(self, request_id: str) -> KVCacheBlocks:
"""Get the blocks of a request.""" """Get the blocks of a request."""
return KVCacheBlocks(self.coordinator.get_blocks(request_id)) return self.create_kv_cache_blocks(self.coordinator.get_blocks(request_id))
def get_block_ids(self, request_id: str) -> tuple[list[int], ...]: def get_block_ids(self, request_id: str) -> tuple[list[int], ...]:
"""Get the block ids of a request.""" """Get the block ids of a request."""
@ -399,6 +418,8 @@ class KVCacheManager:
if self.enable_caching: if self.enable_caching:
self.coordinator.cache_blocks(request, num_computed_tokens) self.coordinator.cache_blocks(request, num_computed_tokens)
def create_empty_block_list(self) -> KVCacheBlocks: def create_kv_cache_blocks(
"""Creates a new KVCacheBlocks instance with no blocks.""" self, blocks: tuple[list[KVCacheBlock], ...]
return KVCacheBlocks(tuple([] for _ in range(self.num_kv_cache_groups))) ) -> KVCacheBlocks:
# Only create new KVCacheBlocks for non-empty blocks
return KVCacheBlocks(blocks) if any(blocks) else self.empty_kv_cache_blocks

View File

@ -421,9 +421,7 @@ class Scheduler(SchedulerInterface):
# KVTransfer: WAITING reqs have num_computed_tokens > 0 # KVTransfer: WAITING reqs have num_computed_tokens > 0
# after async KV recvs are completed. # after async KV recvs are completed.
else: else:
new_computed_blocks = ( new_computed_blocks = self.kv_cache_manager.empty_kv_cache_blocks
self.kv_cache_manager.create_empty_block_list()
)
num_new_local_computed_tokens = 0 num_new_local_computed_tokens = 0
num_computed_tokens = request.num_computed_tokens num_computed_tokens = request.num_computed_tokens

View File

@ -3,6 +3,7 @@
import itertools import itertools
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections import defaultdict from collections import defaultdict
from collections.abc import Sequence
from vllm.utils import cdiv from vllm.utils import cdiv
from vllm.v1.core.block_pool import BlockPool from vllm.v1.core.block_pool import BlockPool
@ -61,7 +62,10 @@ class SingleTypeKVCacheManager(ABC):
self._null_block = block_pool.null_block self._null_block = block_pool.null_block
def get_num_blocks_to_allocate( def get_num_blocks_to_allocate(
self, request_id: str, num_tokens: int, new_computed_blocks: list[KVCacheBlock] self,
request_id: str,
num_tokens: int,
new_computed_blocks: Sequence[KVCacheBlock],
) -> int: ) -> int:
""" """
Get the number of blocks needed to be allocated for the request. Get the number of blocks needed to be allocated for the request.
@ -93,7 +97,7 @@ class SingleTypeKVCacheManager(ABC):
return num_new_blocks + num_evictable_computed_blocks return num_new_blocks + num_evictable_computed_blocks
def save_new_computed_blocks( def save_new_computed_blocks(
self, request_id: str, new_computed_blocks: list[KVCacheBlock] self, request_id: str, new_computed_blocks: Sequence[KVCacheBlock]
) -> None: ) -> None:
""" """
Add the new computed blocks to the request. Add the new computed blocks to the request.
@ -593,7 +597,10 @@ class MambaManager(SingleTypeKVCacheManager):
return 0 return 0
def get_num_blocks_to_allocate( def get_num_blocks_to_allocate(
self, request_id: str, num_tokens: int, new_computed_blocks: list[KVCacheBlock] self,
request_id: str,
num_tokens: int,
new_computed_blocks: Sequence[KVCacheBlock],
) -> int: ) -> int:
# Allocate extra `num_speculative_blocks` blocks for # Allocate extra `num_speculative_blocks` blocks for
# speculative decoding (MTP/EAGLE) with linear attention. # speculative decoding (MTP/EAGLE) with linear attention.
@ -625,7 +632,7 @@ class CrossAttentionManager(SingleTypeKVCacheManager):
"""Manager for cross-attention KV cache in encoder-decoder models.""" """Manager for cross-attention KV cache in encoder-decoder models."""
def save_new_computed_blocks( def save_new_computed_blocks(
self, request_id: str, new_computed_blocks: list[KVCacheBlock] self, request_id: str, new_computed_blocks: Sequence[KVCacheBlock]
) -> None: ) -> None:
# We do not cache blocks for cross-attention to be shared between # We do not cache blocks for cross-attention to be shared between
# requests, so `new_computed_blocks` should always be empty. # requests, so `new_computed_blocks` should always be empty.