mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-24 18:25:40 +08:00
[v1] Add cross-attention KV cache support for encoder-decoder models (#23664)
Signed-off-by: Russell Bryant <rbryant@redhat.com>
This commit is contained in:
parent
227e231b55
commit
98aa16ff41
@ -372,3 +372,22 @@ class MultiModalRegistry:
|
||||
)
|
||||
|
||||
return dummy_data
|
||||
|
||||
def get_encdec_max_encoder_len(self, model_config: "ModelConfig") -> int:
|
||||
"""
|
||||
Get the maximum length of the encoder input for encoder-decoder models.
|
||||
"""
|
||||
if not model_config.is_encoder_decoder:
|
||||
return 0
|
||||
max_tokens = self.\
|
||||
get_max_tokens_per_item_by_nonzero_modality(model_config)
|
||||
if not max_tokens:
|
||||
# TODO - this function assumes encoder-decoder models are
|
||||
# multimodal. This will need to change when adding support for more
|
||||
# than whisper.
|
||||
return 0
|
||||
assert len(max_tokens) == 1, "Encoder-decoder models are expected \
|
||||
to implement the multimodal interface with at most one modality."
|
||||
|
||||
first_modality = next(iter(max_tokens))
|
||||
return max_tokens[first_modality]
|
||||
|
||||
@ -6,7 +6,7 @@ from typing import Optional
|
||||
from vllm.v1.core.block_pool import BlockPool
|
||||
from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock
|
||||
from vllm.v1.core.single_type_kv_cache_manager import (
|
||||
FullAttentionManager, get_manager_for_kv_cache_spec)
|
||||
CrossAttentionManager, FullAttentionManager, get_manager_for_kv_cache_spec)
|
||||
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
||||
KVCacheSpec)
|
||||
from vllm.v1.request import Request
|
||||
@ -42,9 +42,10 @@ class KVCacheCoordinator(ABC):
|
||||
) for i, kv_cache_group in enumerate(
|
||||
self.kv_cache_config.kv_cache_groups))
|
||||
|
||||
def get_num_blocks_to_allocate(
|
||||
self, request_id: str, num_tokens: int,
|
||||
new_computed_blocks: tuple[list[KVCacheBlock], ...]) -> int:
|
||||
def get_num_blocks_to_allocate(self, request_id: str, num_tokens: int,
|
||||
new_computed_blocks: tuple[
|
||||
list[KVCacheBlock], ...],
|
||||
num_encoder_tokens: int) -> int:
|
||||
"""
|
||||
Get the number of blocks needed to be allocated for the request.
|
||||
|
||||
@ -54,14 +55,22 @@ class KVCacheCoordinator(ABC):
|
||||
tokens that are already allocated).
|
||||
new_computed_blocks: The new computed blocks just hitting the
|
||||
prefix caching.
|
||||
num_encoder_tokens: The number of encoder tokens for allocating
|
||||
blocks for cross-attention.
|
||||
|
||||
Returns:
|
||||
The number of blocks.
|
||||
"""
|
||||
num_blocks_to_allocate = 0
|
||||
for i, manager in enumerate(self.single_type_managers):
|
||||
num_blocks_to_allocate += manager.get_num_blocks_to_allocate(
|
||||
request_id, num_tokens, new_computed_blocks[i])
|
||||
if isinstance(manager, CrossAttentionManager):
|
||||
# For cross-attention, we issue a single static allocation
|
||||
# of blocks based on the number of encoder input tokens.
|
||||
num_blocks_to_allocate += manager.get_num_blocks_to_allocate(
|
||||
request_id, num_encoder_tokens, [])
|
||||
else:
|
||||
num_blocks_to_allocate += manager.get_num_blocks_to_allocate(
|
||||
request_id, num_tokens, new_computed_blocks[i])
|
||||
return num_blocks_to_allocate
|
||||
|
||||
def save_new_computed_blocks(
|
||||
@ -79,8 +88,11 @@ class KVCacheCoordinator(ABC):
|
||||
manager.save_new_computed_blocks(request_id,
|
||||
new_computed_blocks[i])
|
||||
|
||||
def allocate_new_blocks(self, request_id: str,
|
||||
num_tokens: int) -> tuple[list[KVCacheBlock], ...]:
|
||||
def allocate_new_blocks(
|
||||
self,
|
||||
request_id: str,
|
||||
num_tokens: int,
|
||||
num_encoder_tokens: int = 0) -> tuple[list[KVCacheBlock], ...]:
|
||||
"""
|
||||
Allocate new blocks for the request to give it at least `num_tokens`
|
||||
token slots.
|
||||
@ -89,12 +101,16 @@ class KVCacheCoordinator(ABC):
|
||||
request_id: The request ID.
|
||||
num_tokens: The total number of tokens that need a slot (including
|
||||
tokens that are already allocated).
|
||||
num_encoder_tokens: The number of encoder tokens for allocating
|
||||
blocks for cross-attention.
|
||||
|
||||
Returns:
|
||||
The new allocated blocks.
|
||||
"""
|
||||
return tuple(
|
||||
manager.allocate_new_blocks(request_id, num_tokens)
|
||||
manager.allocate_new_blocks(
|
||||
request_id, num_encoder_tokens if isinstance(
|
||||
manager, CrossAttentionManager) else num_tokens)
|
||||
for manager in self.single_type_managers)
|
||||
|
||||
def cache_blocks(self, request: Request, num_computed_tokens: int) -> None:
|
||||
|
||||
@ -187,6 +187,7 @@ class KVCacheManager:
|
||||
new_computed_blocks: Optional[KVCacheBlocks] = None,
|
||||
num_lookahead_tokens: int = 0,
|
||||
delay_cache_blocks: bool = False,
|
||||
num_encoder_tokens: int = 0,
|
||||
) -> Optional[KVCacheBlocks]:
|
||||
"""Add slots for a request with new tokens to append.
|
||||
|
||||
@ -253,6 +254,7 @@ class KVCacheManager:
|
||||
request_id=request.request_id,
|
||||
num_tokens=num_tokens_need_slot,
|
||||
new_computed_blocks=new_computed_block_list,
|
||||
num_encoder_tokens=num_encoder_tokens,
|
||||
)
|
||||
|
||||
if num_blocks_to_allocate > self.block_pool.get_num_free_blocks():
|
||||
@ -273,7 +275,7 @@ class KVCacheManager:
|
||||
new_computed_block_list)
|
||||
|
||||
new_blocks = self.coordinator.allocate_new_blocks(
|
||||
request.request_id, num_tokens_need_slot)
|
||||
request.request_id, num_tokens_need_slot, num_encoder_tokens)
|
||||
|
||||
# P/D: delay caching blocks if we have to recv from
|
||||
# remote. Update state for locally cached blocks.
|
||||
@ -292,7 +294,7 @@ class KVCacheManager:
|
||||
|
||||
def free(self, request: Request) -> None:
|
||||
"""Free the blocks allocated for the request.
|
||||
We free the blocks in reverse order so that he tail blocks are evicted
|
||||
We free the blocks in reverse order so that the tail blocks are evicted
|
||||
first when caching is enabled.
|
||||
|
||||
Args:
|
||||
|
||||
@ -58,6 +58,7 @@ class Scheduler(SchedulerInterface):
|
||||
self.parallel_config = vllm_config.parallel_config
|
||||
self.log_stats = log_stats
|
||||
self.structured_output_manager = structured_output_manager
|
||||
self.is_encoder_decoder = vllm_config.model_config.is_encoder_decoder
|
||||
|
||||
# include_finished_set controls whether a separate set of finished
|
||||
# request ids should be included in the EngineCoreOutputs returned
|
||||
@ -83,6 +84,9 @@ class Scheduler(SchedulerInterface):
|
||||
assert len(self.kv_cache_config.kv_cache_groups) == 1, (
|
||||
"Multiple KV cache groups are not currently supported "
|
||||
"with KV connectors")
|
||||
assert not self.is_encoder_decoder, (
|
||||
"Encoder-decoder models are not currently supported "
|
||||
"with KV connectors")
|
||||
self.connector = KVConnectorFactory.create_connector(
|
||||
config=self.vllm_config, role=KVConnectorRole.SCHEDULER)
|
||||
|
||||
@ -431,6 +435,22 @@ class Scheduler(SchedulerInterface):
|
||||
== 0 else
|
||||
self.num_lookahead_tokens)
|
||||
|
||||
# Determine if we need to allocate cross-attention blocks.
|
||||
if self.is_encoder_decoder and request.has_encoder_inputs:
|
||||
# TODO(russellb): For Whisper, we know that the input is
|
||||
# always padded to the maximum length. If we support other
|
||||
# encoder-decoder models, this will need to be updated if we
|
||||
# want to only allocate what is needed.
|
||||
assert ("whisper"
|
||||
in self.vllm_config.model_config.model.lower()), (
|
||||
"Whisper is the only supported "
|
||||
"encoder-decoder model.")
|
||||
num_encoder_tokens = MULTIMODAL_REGISTRY.\
|
||||
get_encdec_max_encoder_len(
|
||||
self.vllm_config.model_config)
|
||||
else:
|
||||
num_encoder_tokens = 0
|
||||
|
||||
new_blocks = self.kv_cache_manager.allocate_slots(
|
||||
request,
|
||||
num_new_tokens + num_external_computed_tokens,
|
||||
@ -438,6 +458,7 @@ class Scheduler(SchedulerInterface):
|
||||
new_computed_blocks,
|
||||
num_lookahead_tokens=effective_lookahead_tokens,
|
||||
delay_cache_blocks=load_kv_async,
|
||||
num_encoder_tokens=num_encoder_tokens,
|
||||
)
|
||||
|
||||
if new_blocks is None:
|
||||
@ -703,7 +724,21 @@ class Scheduler(SchedulerInterface):
|
||||
# The encoder input is not needed in this step.
|
||||
break
|
||||
|
||||
if start_pos + num_encoder_tokens <= num_computed_tokens:
|
||||
if self.is_encoder_decoder and num_computed_tokens > 0:
|
||||
assert start_pos == 0, (
|
||||
"Encoder input should be processed at the beginning of "
|
||||
"the sequence when encoder-decoder models are used.")
|
||||
# Encoder input has already been computed
|
||||
# The calculation here is a bit different. We don't turn encoder
|
||||
# output into tokens that get processed by the decoder and
|
||||
# reflected in num_computed_tokens. Instead, start_pos reflects
|
||||
# the position where we need to ensure we calculate encoder
|
||||
# inputs. This should always be 0 to ensure we calculate encoder
|
||||
# inputs before running the decoder. Once we've calculated some
|
||||
# decoder tokens (num_computed_tokens > 0), then we know we
|
||||
# already calculated encoder inputs and can skip here.
|
||||
continue
|
||||
elif start_pos + num_encoder_tokens <= num_computed_tokens:
|
||||
# The encoder input is already computed and stored
|
||||
# in the decoder's KV cache.
|
||||
continue
|
||||
|
||||
@ -8,8 +8,9 @@ from vllm.utils import cdiv
|
||||
from vllm.v1.core.block_pool import BlockPool
|
||||
from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock
|
||||
from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec,
|
||||
FullAttentionSpec, KVCacheSpec,
|
||||
MambaSpec, SlidingWindowSpec)
|
||||
CrossAttentionSpec, FullAttentionSpec,
|
||||
KVCacheSpec, MambaSpec,
|
||||
SlidingWindowSpec)
|
||||
from vllm.v1.request import Request
|
||||
|
||||
|
||||
@ -552,11 +553,62 @@ class MambaManager(SingleTypeKVCacheManager):
|
||||
return new_blocks
|
||||
|
||||
|
||||
class CrossAttentionManager(SingleTypeKVCacheManager):
|
||||
"""Manager for cross-attention KV cache in encoder-decoder models."""
|
||||
|
||||
def save_new_computed_blocks(
|
||||
self, request_id: str,
|
||||
new_computed_blocks: list[KVCacheBlock]) -> None:
|
||||
# We do not cache blocks for cross-attention to be shared between
|
||||
# requests, so `new_computed_blocks` should always be empty.
|
||||
assert len(new_computed_blocks) == 0
|
||||
|
||||
def cache_blocks(self, request: Request, num_tokens: int) -> None:
|
||||
# We do not cache blocks for cross-attention to be shared between
|
||||
# requests, so this method is not relevant.
|
||||
raise ValueError("Should not be called as prefix caching is disabled.")
|
||||
|
||||
def get_num_common_prefix_blocks(self, request_id: str,
|
||||
num_running_requests: int) -> int:
|
||||
# Cross-attention blocks contain request-specific encoder states
|
||||
# and are not shared between different requests
|
||||
return 0
|
||||
|
||||
@classmethod
|
||||
def find_longest_cache_hit(
|
||||
cls,
|
||||
block_hashes: list[BlockHash],
|
||||
max_length: int,
|
||||
kv_cache_group_ids: list[int],
|
||||
block_pool: BlockPool,
|
||||
kv_cache_spec: KVCacheSpec,
|
||||
use_eagle: bool,
|
||||
) -> tuple[list[KVCacheBlock], ...]:
|
||||
assert isinstance(kv_cache_spec, CrossAttentionSpec), (
|
||||
"CrossAttentionManager can only be used for cross-attention groups"
|
||||
)
|
||||
# Cross-attention does not benefit from prefix caching since:
|
||||
# 1. Encoder states are unique per request (different audio/image
|
||||
# inputs)
|
||||
# 2. Encoder states are computed once per request, not incrementally
|
||||
# 3. No reusable prefix exists between different multimodal inputs
|
||||
# Return empty blocks to indicate no cache hits
|
||||
raise NotImplementedError(
|
||||
"CrossAttentionManager does not support caching")
|
||||
|
||||
def remove_skipped_blocks(self, request_id: str,
|
||||
num_computed_tokens: int) -> None:
|
||||
# Cross-attention blocks represent encoder states which are needed
|
||||
# for the entire decoding process, so no blocks should be skipped
|
||||
pass
|
||||
|
||||
|
||||
spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = {
|
||||
FullAttentionSpec: FullAttentionManager,
|
||||
SlidingWindowSpec: SlidingWindowManager,
|
||||
ChunkedLocalAttentionSpec: ChunkedLocalAttentionManager,
|
||||
MambaSpec: MambaManager,
|
||||
CrossAttentionSpec: CrossAttentionManager,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@ -11,6 +11,7 @@ from typing_extensions import Self
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.utils import cdiv, get_dtype_size
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -211,6 +212,20 @@ class EncoderOnlyAttentionSpec(AttentionSpec):
|
||||
return 0
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CrossAttentionSpec(AttentionSpec):
|
||||
"""
|
||||
KV cache spec for cross-attention layers in encoder-decoder models.
|
||||
"""
|
||||
|
||||
def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
|
||||
# For cross-attention, we need to cache encoder states
|
||||
# Get encoder length (e.g., 1500 for Whisper).
|
||||
max_encoder_len = MULTIMODAL_REGISTRY.\
|
||||
get_encdec_max_encoder_len(vllm_config.model_config)
|
||||
return cdiv(max_encoder_len, self.block_size) * self.page_size_bytes
|
||||
|
||||
|
||||
@dataclass
|
||||
class KVCacheTensor:
|
||||
"""
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user