[v1] Add cross-attention KV cache support for encoder-decoder models (#23664)

Signed-off-by: Russell Bryant <rbryant@redhat.com>
Russell Bryant 2025-08-26 14:49:06 -04:00 committed by GitHub
parent 227e231b55
commit 98aa16ff41
6 changed files with 153 additions and 14 deletions

View File

@@ -372,3 +372,22 @@ class MultiModalRegistry:
)
return dummy_data
def get_encdec_max_encoder_len(self, model_config: "ModelConfig") -> int:
"""
Get the maximum length of the encoder input for encoder-decoder models.
"""
if not model_config.is_encoder_decoder:
return 0
max_tokens = self.\
get_max_tokens_per_item_by_nonzero_modality(model_config)
if not max_tokens:
# TODO - this function assumes encoder-decoder models are
# multimodal. This will need to change when adding support for more
# than whisper.
return 0
assert len(max_tokens) == 1, "Encoder-decoder models are expected \
to implement the multimodal interface with at most one modality."
first_modality = next(iter(max_tokens))
return max_tokens[first_modality]
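The helper above reduces to a small lookup: for an encoder-decoder model it returns the single modality's maximum item length, otherwise 0. A standalone sketch of that logic, under the assumption (stated in the TODO) that encoder-decoder models are multimodal with exactly one modality; the names here are illustrative, not vLLM API:

def encdec_max_encoder_len(is_encoder_decoder: bool,
                           max_tokens_by_modality: dict[str, int]) -> int:
    # Decoder-only models never need cross-attention KV blocks.
    if not is_encoder_decoder:
        return 0
    # Assumes encoder-decoder models are multimodal (true for Whisper,
    # the only model targeted by this change).
    if not max_tokens_by_modality:
        return 0
    assert len(max_tokens_by_modality) == 1, (
        "expected at most one modality for encoder-decoder models")
    return next(iter(max_tokens_by_modality.values()))

print(encdec_max_encoder_len(True, {"audio": 1500}))  # -> 1500
print(encdec_max_encoder_len(False, {}))              # -> 0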

View File

@@ -6,7 +6,7 @@ from typing import Optional
from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock
from vllm.v1.core.single_type_kv_cache_manager import (
FullAttentionManager, get_manager_for_kv_cache_spec)
CrossAttentionManager, FullAttentionManager, get_manager_for_kv_cache_spec)
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheSpec)
from vllm.v1.request import Request
@@ -42,9 +42,10 @@ class KVCacheCoordinator(ABC):
) for i, kv_cache_group in enumerate(
self.kv_cache_config.kv_cache_groups))
def get_num_blocks_to_allocate(
self, request_id: str, num_tokens: int,
new_computed_blocks: tuple[list[KVCacheBlock], ...]) -> int:
def get_num_blocks_to_allocate(self, request_id: str, num_tokens: int,
new_computed_blocks: tuple[
list[KVCacheBlock], ...],
num_encoder_tokens: int) -> int:
"""
Get the number of blocks needed to be allocated for the request.
@@ -54,14 +55,22 @@ class KVCacheCoordinator(ABC):
tokens that are already allocated).
new_computed_blocks: The new computed blocks just hitting the
prefix caching.
num_encoder_tokens: The number of encoder tokens for allocating
blocks for cross-attention.
Returns:
The number of blocks.
"""
num_blocks_to_allocate = 0
for i, manager in enumerate(self.single_type_managers):
num_blocks_to_allocate += manager.get_num_blocks_to_allocate(
request_id, num_tokens, new_computed_blocks[i])
if isinstance(manager, CrossAttentionManager):
# For cross-attention, we issue a single static allocation
# of blocks based on the number of encoder input tokens.
num_blocks_to_allocate += manager.get_num_blocks_to_allocate(
request_id, num_encoder_tokens, [])
else:
num_blocks_to_allocate += manager.get_num_blocks_to_allocate(
request_id, num_tokens, new_computed_blocks[i])
return num_blocks_to_allocate
def save_new_computed_blocks(
@@ -79,8 +88,11 @@ class KVCacheCoordinator(ABC):
manager.save_new_computed_blocks(request_id,
new_computed_blocks[i])
def allocate_new_blocks(self, request_id: str,
num_tokens: int) -> tuple[list[KVCacheBlock], ...]:
def allocate_new_blocks(
self,
request_id: str,
num_tokens: int,
num_encoder_tokens: int = 0) -> tuple[list[KVCacheBlock], ...]:
"""
Allocate new blocks for the request to give it at least `num_tokens`
token slots.
@@ -89,12 +101,16 @@ class KVCacheCoordinator(ABC):
request_id: The request ID.
num_tokens: The total number of tokens that need a slot (including
tokens that are already allocated).
num_encoder_tokens: The number of encoder tokens for allocating
blocks for cross-attention.
Returns:
The new allocated blocks.
"""
return tuple(
manager.allocate_new_blocks(request_id, num_tokens)
manager.allocate_new_blocks(
request_id, num_encoder_tokens if isinstance(
manager, CrossAttentionManager) else num_tokens)
for manager in self.single_type_managers)
def cache_blocks(self, request: Request, num_computed_tokens: int) -> None:
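Taken together, the coordinator changes route the encoder length to cross-attention managers and the decoder length to every other manager. A self-contained sketch of that dispatch, using stand-in manager classes and an assumed block size of 16 rather than vLLM's real types:

import math

BLOCK_SIZE = 16  # assumed for the example

class SelfAttnManagerStub:
    def get_num_blocks_to_allocate(self, num_tokens: int) -> int:
        return math.ceil(num_tokens / BLOCK_SIZE)

class CrossAttnManagerStub(SelfAttnManagerStub):
    # Same arithmetic, but sized once from the encoder input length.
    pass

def num_blocks_to_allocate(managers, num_decoder_tokens, num_encoder_tokens):
    total = 0
    for manager in managers:
        if isinstance(manager, CrossAttnManagerStub):
            # Single static allocation based on the encoder input length.
            total += manager.get_num_blocks_to_allocate(num_encoder_tokens)
        else:
            total += manager.get_num_blocks_to_allocate(num_decoder_tokens)
    return total

# e.g. 32 decoder tokens plus Whisper's padded 1500-token encoder input:
print(num_blocks_to_allocate(
    [SelfAttnManagerStub(), CrossAttnManagerStub()], 32, 1500))  # 2 + 94 = 96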

View File

@@ -187,6 +187,7 @@ class KVCacheManager:
new_computed_blocks: Optional[KVCacheBlocks] = None,
num_lookahead_tokens: int = 0,
delay_cache_blocks: bool = False,
num_encoder_tokens: int = 0,
) -> Optional[KVCacheBlocks]:
"""Add slots for a request with new tokens to append.
@@ -253,6 +254,7 @@ class KVCacheManager:
request_id=request.request_id,
num_tokens=num_tokens_need_slot,
new_computed_blocks=new_computed_block_list,
num_encoder_tokens=num_encoder_tokens,
)
if num_blocks_to_allocate > self.block_pool.get_num_free_blocks():
@@ -273,7 +275,7 @@ class KVCacheManager:
new_computed_block_list)
new_blocks = self.coordinator.allocate_new_blocks(
request.request_id, num_tokens_need_slot)
request.request_id, num_tokens_need_slot, num_encoder_tokens)
# P/D: delay caching blocks if we have to recv from
# remote. Update state for locally cached blocks.
@@ -292,7 +294,7 @@ class KVCacheManager:
def free(self, request: Request) -> None:
"""Free the blocks allocated for the request.
We free the blocks in reverse order so that he tail blocks are evicted
We free the blocks in reverse order so that the tail blocks are evicted
first when caching is enabled.
Args:
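On the KVCacheManager side the new argument is only threaded through: the admission check adds the one-shot cross-attention allocation to the decoder-side growth before comparing against the free block pool. A toy restatement of that check (standalone, not vLLM code):

def can_allocate(num_decoder_blocks: int, num_encoder_blocks: int,
                 num_free_blocks: int) -> bool:
    # If both allocations do not fit, allocate_slots returns None and the
    # scheduler leaves the request waiting.
    return num_decoder_blocks + num_encoder_blocks <= num_free_blocks

print(can_allocate(2, 94, 100))   # True: request can be scheduled
print(can_allocate(10, 94, 100))  # False: not enough free blocks yet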

View File

@@ -58,6 +58,7 @@ class Scheduler(SchedulerInterface):
self.parallel_config = vllm_config.parallel_config
self.log_stats = log_stats
self.structured_output_manager = structured_output_manager
self.is_encoder_decoder = vllm_config.model_config.is_encoder_decoder
# include_finished_set controls whether a separate set of finished
# request ids should be included in the EngineCoreOutputs returned
@@ -83,6 +84,9 @@ class Scheduler(SchedulerInterface):
assert len(self.kv_cache_config.kv_cache_groups) == 1, (
"Multiple KV cache groups are not currently supported "
"with KV connectors")
assert not self.is_encoder_decoder, (
"Encoder-decoder models are not currently supported "
"with KV connectors")
self.connector = KVConnectorFactory.create_connector(
config=self.vllm_config, role=KVConnectorRole.SCHEDULER)
@@ -431,6 +435,22 @@ class Scheduler(SchedulerInterface):
== 0 else
self.num_lookahead_tokens)
# Determine if we need to allocate cross-attention blocks.
if self.is_encoder_decoder and request.has_encoder_inputs:
# TODO(russellb): For Whisper, we know that the input is
# always padded to the maximum length. If we support other
# encoder-decoder models, this will need to be updated if we
# want to only allocate what is needed.
assert ("whisper"
in self.vllm_config.model_config.model.lower()), (
"Whisper is the only supported "
"encoder-decoder model.")
num_encoder_tokens = MULTIMODAL_REGISTRY.\
get_encdec_max_encoder_len(
self.vllm_config.model_config)
else:
num_encoder_tokens = 0
new_blocks = self.kv_cache_manager.allocate_slots(
request,
num_new_tokens + num_external_computed_tokens,
@@ -438,6 +458,7 @@ class Scheduler(SchedulerInterface):
new_computed_blocks,
num_lookahead_tokens=effective_lookahead_tokens,
delay_cache_blocks=load_kv_async,
num_encoder_tokens=num_encoder_tokens,
)
if new_blocks is None:
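The scheduler sizes the cross-attention allocation from the registry helper rather than from the actual clip length, because Whisper always pads its input to the maximum. The value it effectively retrieves follows from Whisper's published architecture; the constants below are hard-coded for illustration, not read from vLLM config:

AUDIO_SECONDS = 30            # Whisper pads/trims audio to 30 s windows
MEL_FRAMES_PER_SECOND = 100   # 16 kHz audio with a hop length of 160
CONV_STRIDE = 2               # the second encoder conv layer halves the length

num_encoder_tokens = AUDIO_SECONDS * MEL_FRAMES_PER_SECOND // CONV_STRIDE
print(num_encoder_tokens)     # 1500, independent of the actual clip length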
@@ -703,7 +724,21 @@ class Scheduler(SchedulerInterface):
# The encoder input is not needed in this step.
break
if start_pos + num_encoder_tokens <= num_computed_tokens:
if self.is_encoder_decoder and num_computed_tokens > 0:
assert start_pos == 0, (
"Encoder input should be processed at the beginning of "
"the sequence when encoder-decoder models are used.")
# Encoder input has already been computed
# The calculation here is a bit different. We don't turn encoder
# output into tokens that get processed by the decoder and
# reflected in num_computed_tokens. Instead, start_pos reflects
# the position where we need to ensure we calculate encoder
# inputs. This should always be 0 to ensure we calculate encoder
# inputs before running the decoder. Once we've calculated some
# decoder tokens (num_computed_tokens > 0), then we know we
# already calculated encoder inputs and can skip here.
continue
elif start_pos + num_encoder_tokens <= num_computed_tokens:
# The encoder input is already computed and stored
# in the decoder's KV cache.
continue
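The new branch boils down to this: an encoder-decoder model runs its encoder exactly once, at position 0, before any decoder tokens exist, while the existing decoder-only multimodal path still skips inputs whose embeddings are already in the KV cache. A compact restatement (illustrative helper, not a vLLM function):

def needs_encoder_run(is_encoder_decoder: bool, start_pos: int,
                      num_encoder_tokens: int,
                      num_computed_tokens: int) -> bool:
    if is_encoder_decoder and num_computed_tokens > 0:
        # Decoder tokens exist, so the encoder already ran for this request.
        assert start_pos == 0
        return False
    if start_pos + num_encoder_tokens <= num_computed_tokens:
        # Decoder-only multimodal case: already in the decoder's KV cache.
        return False
    return True

print(needs_encoder_run(True, 0, 1500, 0))   # True: first step, run the encoder
print(needs_encoder_run(True, 0, 1500, 8))   # False: decoding has started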

View File

@@ -8,8 +8,9 @@ from vllm.utils import cdiv
from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock
from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec,
FullAttentionSpec, KVCacheSpec,
MambaSpec, SlidingWindowSpec)
CrossAttentionSpec, FullAttentionSpec,
KVCacheSpec, MambaSpec,
SlidingWindowSpec)
from vllm.v1.request import Request
@@ -552,11 +553,62 @@ class MambaManager(SingleTypeKVCacheManager):
return new_blocks
class CrossAttentionManager(SingleTypeKVCacheManager):
"""Manager for cross-attention KV cache in encoder-decoder models."""
def save_new_computed_blocks(
self, request_id: str,
new_computed_blocks: list[KVCacheBlock]) -> None:
# We do not cache blocks for cross-attention to be shared between
# requests, so `new_computed_blocks` should always be empty.
assert len(new_computed_blocks) == 0
def cache_blocks(self, request: Request, num_tokens: int) -> None:
# We do not cache blocks for cross-attention to be shared between
# requests, so this method is not relevant.
raise ValueError("Should not be called as prefix caching is disabled.")
def get_num_common_prefix_blocks(self, request_id: str,
num_running_requests: int) -> int:
# Cross-attention blocks contain request-specific encoder states
# and are not shared between different requests
return 0
@classmethod
def find_longest_cache_hit(
cls,
block_hashes: list[BlockHash],
max_length: int,
kv_cache_group_ids: list[int],
block_pool: BlockPool,
kv_cache_spec: KVCacheSpec,
use_eagle: bool,
) -> tuple[list[KVCacheBlock], ...]:
assert isinstance(kv_cache_spec, CrossAttentionSpec), (
"CrossAttentionManager can only be used for cross-attention groups"
)
# Cross-attention does not benefit from prefix caching since:
# 1. Encoder states are unique per request (different audio/image
# inputs)
# 2. Encoder states are computed once per request, not incrementally
# 3. No reusable prefix exists between different multimodal inputs
# Return empty blocks to indicate no cache hits
raise NotImplementedError(
"CrossAttentionManager does not support caching")
def remove_skipped_blocks(self, request_id: str,
num_computed_tokens: int) -> None:
# Cross-attention blocks represent encoder states which are needed
# for the entire decoding process, so no blocks should be skipped
pass
spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = {
FullAttentionSpec: FullAttentionManager,
SlidingWindowSpec: SlidingWindowManager,
ChunkedLocalAttentionSpec: ChunkedLocalAttentionManager,
MambaSpec: MambaManager,
CrossAttentionSpec: CrossAttentionManager,
}
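CrossAttentionManager mostly switches off machinery that only makes sense for prefix-cached, incrementally growing caches: encoder states are per-request, so lookups never hit and nothing is shared. A minimal sketch of that design choice with placeholder classes (the committed manager raises NotImplementedError on lookup rather than returning an empty result, since the path is never taken when prefix caching is disabled):

class FullAttentionManagerStub:
    def find_longest_cache_hit(self, block_hashes: list[str]) -> list[str]:
        # Full attention can reuse any previously cached prefix.
        return block_hashes

class CrossAttentionManagerStub:
    def find_longest_cache_hit(self, block_hashes: list[str]) -> list[str]:
        # Encoder states are unique per request: no prefix reuse, ever.
        return []

    def get_num_common_prefix_blocks(self, num_running_requests: int) -> int:
        # Cross-attention blocks are never shared across requests.
        return 0

SPEC_TO_MANAGER = {
    "full_attention": FullAttentionManagerStub,
    "cross_attention": CrossAttentionManagerStub,
}

manager = SPEC_TO_MANAGER["cross_attention"]()
print(manager.find_longest_cache_hit(["h1", "h2"]))  # [] -> no cache hits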

View File

@@ -11,6 +11,7 @@ from typing_extensions import Self
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.utils import cdiv, get_dtype_size
logger = init_logger(__name__)
@@ -211,6 +212,20 @@ class EncoderOnlyAttentionSpec(AttentionSpec):
return 0
@dataclass(frozen=True)
class CrossAttentionSpec(AttentionSpec):
"""
KV cache spec for cross-attention layers in encoder-decoder models.
"""
def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
# For cross-attention, we need to cache encoder states
# Get encoder length (e.g., 1500 for Whisper).
max_encoder_len = MULTIMODAL_REGISTRY.\
get_encdec_max_encoder_len(vllm_config.model_config)
return cdiv(max_encoder_len, self.block_size) * self.page_size_bytes
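For a sense of scale, the per-layer, per-request cost of this spec can be worked out with assumed Whisper-large-v3-style dimensions (20 KV heads, head size 64, fp16) and the conventional K-plus-V page layout; the constants below are illustrative, not values read from vLLM:

import math

max_encoder_len = 1500   # padded Whisper encoder output length
block_size = 16
num_kv_heads, head_size = 20, 64
dtype_bytes = 2          # fp16

page_size_bytes = 2 * block_size * num_kv_heads * head_size * dtype_bytes
num_blocks = math.ceil(max_encoder_len / block_size)   # cdiv(1500, 16) = 94
print(num_blocks, num_blocks * page_size_bytes)  # 94 blocks, ~7.3 MiB per layer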
@dataclass
class KVCacheTensor:
"""