# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
KVConnectorBase_V1 Class for Distributed KV Cache & Hidden State
communication in vLLM v1

The class provides the following primitives:
    Scheduler-side: runs in the scheduler, binds metadata, which
    is used by the worker-side to load/save KV cache.
        get_num_new_matched_tokens() - get the number of new tokens
            that exist in the remote KV cache. Might be called multiple
            times for a given request and should be side-effect free.
        update_state_after_alloc() - update KVConnector state after
            temporary buffer alloc by the CacheManager.
        update_connector_output() - update KVConnector state after
            output is received from worker-side connectors.
        request_finished() - called once when a request is finished,
            with the computed kv cache blocks for the request.
            Returns whether the KV cache should be freed now, or if the
            connector assumes responsibility for freeing the blocks
            asynchronously. Also optionally returns KV transfer params.
        take_events() - returns new KV events that were collected
            by the connector since the last call.

    Worker-side: runs in each worker, loads/saves KV cache to/from
    the Connector based on the metadata.
        start_load_kv() - starts loading all KVs (maybe async)
        wait_for_layer_load() - blocks until layer i load is done

        save_kv_layer() - starts saving KV for layer i (maybe async)
        wait_for_save() - blocks until all saves are done

        get_finished() - called with ids of finished requests, returns
            ids of requests that have completed async sending/recving.
"""

import enum
from abc import ABC, abstractmethod
from collections.abc import Callable, Iterable
from typing import TYPE_CHECKING, Any, ClassVar, Literal, Optional

import torch

from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
from vllm.logger import init_logger
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.outputs import KVConnectorOutput

if TYPE_CHECKING:
    from vllm.config import VllmConfig
    from vllm.distributed.kv_events import KVCacheEvent
    from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
        KVConnectorPromMetrics,
        KVConnectorStats,
        PromMetric,
        PromMetricT,
    )
    from vllm.forward_context import ForwardContext
    from vllm.v1.core.kv_cache_manager import KVCacheBlocks
    from vllm.v1.kv_cache_interface import KVCacheConfig
    from vllm.v1.request import Request

# s_tensor_list, d_tensor_list, s_indices, d_indices, direction
CopyBlocksOp = Callable[
    [
        dict[str, torch.Tensor],
        dict[str, torch.Tensor],
        list[int],
        list[int],
        Literal["h2d", "d2h"],
    ],
    None,
]
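

# For illustration only: a minimal sketch of a function satisfying the
# CopyBlocksOp signature. The name "_example_copy_blocks" and the whole-block
# copy strategy are hypothetical assumptions, not vLLM's actual
# implementation. It assumes src/dst dicts share layer names and that dim 0
# of each KV tensor indexes blocks.
def _example_copy_blocks(
    src_tensors: dict[str, torch.Tensor],
    dst_tensors: dict[str, torch.Tensor],
    src_block_indices: list[int],
    dst_block_indices: list[int],
    direction: Literal["h2d", "d2h"],
) -> None:
    # h2d copies may overlap with compute; d2h copies block so the host
    # buffer is safe to read once this function returns.
    non_blocking = direction == "h2d"
    for layer_name, src in src_tensors.items():
        dst = dst_tensors[layer_name]
        for s, d in zip(src_block_indices, dst_block_indices):
            dst[d].copy_(src[s], non_blocking=non_blocking)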

logger = init_logger(__name__)


class SupportsHMA(ABC):
    """
    Mixin class indicating that the corresponding connector supports the
    hybrid memory allocator (HMA).
    Inheriting from this class is required to use the connector together
    with the hybrid memory allocator.
    """

    @abstractmethod
    def request_finished_all_groups(
        self,
        request: "Request",
        block_ids: tuple[list[int], ...],
    ) -> tuple[bool, dict[str, Any] | None]:
        """
        Called exactly once when a request has finished for all kv cache groups,
        before its blocks are freed for each group.

        NOTE(Kuntai): This function is only supported by connectors that support HMA.

        The connector may assume responsibility for freeing the blocks
        asynchronously by returning True.

        Returns:
            True if the request is being saved/sent asynchronously and blocks
            should not be freed until the request_id is returned from
            get_finished().
            Optional KVTransferParams to be included in the request outputs
            returned by the engine.
        """
        raise NotImplementedError


def supports_hma(connector: Any) -> bool:
    """Return True if the given connector class or instance supports HMA."""
    if isinstance(connector, type):
        return issubclass(connector, SupportsHMA)
    else:
        return isinstance(connector, SupportsHMA)
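

# For illustration only: supports_hma() accepts either a connector class or
# an instance. A hypothetical HMA-capable connector would opt in by
# inheriting from SupportsHMA alongside KVConnectorBase_V1:
#
#     class MyHMAConnector(KVConnectorBase_V1, SupportsHMA):
#         ...
#
#     supports_hma(MyHMAConnector)  # True (checked via issubclass)
#     supports_hma(my_connector)    # True (checked via isinstance)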


class KVConnectorRole(enum.Enum):
    # Connector running in the scheduler process
    SCHEDULER = 0

    # Connector running in the worker process
    WORKER = 1


class KVConnectorHandshakeMetadata(ABC):  # noqa: B024
    """
    Metadata used for the out-of-band connector handshake between
    P/D workers. This metadata needs to be serializable.
    """

    pass


class KVConnectorMetadata(ABC):  # noqa: B024
    """
    Abstract metadata used to communicate between the
    Scheduler KVConnector and Worker KVConnector.
    """

    pass


class KVConnectorBase_V1(ABC):
    """
    Base class for KV connectors.

    Attributes:
        prefer_cross_layer_blocks (bool): Indicates whether this connector
            prefers KV blocks that hold KV data for all layers (for speeding
            up KV data transfers).
            Defaults to False.
    """

    prefer_cross_layer_blocks: ClassVar[bool] = False

    def __init__(
        self,
        vllm_config: "VllmConfig",
        role: KVConnectorRole,
        kv_cache_config: Optional["KVCacheConfig"] = None,
    ):
        logger.warning(
            "Initializing KVConnectorBase_V1. This API is experimental and "
            "subject to change in the future as we iterate the design."
        )
        self._connector_metadata: KVConnectorMetadata | None = None
        self._vllm_config = vllm_config
        if vllm_config.kv_transfer_config is not None:
            self._kv_transfer_config = vllm_config.kv_transfer_config
        else:
            raise ValueError("kv_transfer_config must be set for KVConnectorBase_V1")
        self._kv_cache_config = kv_cache_config
        if self._kv_cache_config is None:
            logger.warning(
                "KVConnectorBase_V1 initialized without kv_cache_config. "
                "This is deprecated - please update your connector to accept "
                "kv_cache_config as the third constructor argument and pass it "
                "to super().__init__()."
            )
        self._role = role

    @property
    def role(self) -> KVConnectorRole:
        return self._role

    # ==============================
    # Worker-side methods
    # ==============================

    def bind_connector_metadata(self, connector_metadata: KVConnectorMetadata) -> None:
        """Set the connector metadata from the scheduler.

        This function should be called by the model runner every time
        before the model execution. The metadata will be used for runtime
        KV cache loading and saving.

        Args:
            connector_metadata (KVConnectorMetadata): the connector metadata.
        """
        self._connector_metadata = connector_metadata

    def clear_connector_metadata(self) -> None:
        """Clear the connector metadata.

        This function should be called by the model runner every time
        after the model execution.
        """
        self._connector_metadata = None
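
    # For illustration only: the expected metadata lifecycle on the worker
    # side, as hypothetical model-runner pseudocode (not vLLM's actual
    # model-runner code):
    #
    #     connector.bind_connector_metadata(metadata)  # before the step
    #     ...                                          # model forward pass
    #     connector.clear_connector_metadata()         # after the step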

    def _get_connector_metadata(self) -> KVConnectorMetadata:
        """Get the connector metadata.

        This function should only be called inside the connector.

        Returns:
            KVConnectorMetadata: the connector metadata.
        """
        # Should only be called while set to valid metadata.
        assert self._connector_metadata is not None
        return self._connector_metadata

    def has_connector_metadata(self) -> bool:
        """Check whether the connector metadata is currently set.

        Returns:
            bool: True if connector metadata exists, False otherwise.
        """
        return self._connector_metadata is not None

    def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
        """
        Initialize with the KV caches. Useful for pre-registering the
        KV Caches in the KVConnector (e.g. for NIXL).

        Args:
            kv_caches: dictionary mapping layer names to kv cache tensors
        """
        return

    def register_cross_layers_kv_cache(
        self, kv_cache: torch.Tensor, attn_backend: type[AttentionBackend]
    ):
        """
        Initialize with a single KV cache tensor used by all layers.
        The first dimension should be num_layers.
        This function will only be called for models with uniform layers,
        and only if prefer_cross_layer_blocks is set to True.
        Only one of the functions
        {register_kv_caches, register_cross_layers_kv_cache} will be called.

        Args:
            kv_cache: a cross-layers kv cache tensor
            attn_backend: The attention backend that corresponds to all layers
        """
        return

    def set_host_xfer_buffer_ops(self, copy_operation: CopyBlocksOp):
        """
        Set the xPU-specific ops for copying KV between host and device.
        Needed when a host buffer is used for kv transfer (e.g., in
        NixlConnector).
        """
        return
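
    # For illustration only (hypothetical wiring): a worker that stages KV
    # through a host buffer might register its copy op once at startup, e.g.
    #
    #     connector.set_host_xfer_buffer_ops(_example_copy_blocks)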

    @abstractmethod
    def start_load_kv(self, forward_context: "ForwardContext", **kwargs: Any) -> None:
        """
        Start loading the KV cache from the connector to vLLM's paged
        KV buffer. This is called from the forward context before the
        forward pass to enable async loading during model execution.

        Args:
            forward_context (ForwardContext): the forward context.
            **kwargs: additional arguments for the load operation

        Note:
            The number of elements in kv_caches and layer_names should be
            the same.

        """
        pass

    @abstractmethod
    def wait_for_layer_load(self, layer_name: str) -> None:
        """
        Block until the KV for a specific layer is loaded into vLLM's
        paged buffer. This is called from within the attention layer to
        ensure that the async copying from start_load_kv is complete.

        This interface will be useful for layer-by-layer pipelining.

        Args:
            layer_name: the name of that layer
        """
        pass

    @abstractmethod
    def save_kv_layer(
        self,
        layer_name: str,
        kv_layer: torch.Tensor,
        attn_metadata: "AttentionMetadata",
        **kwargs: Any,
    ) -> None:
        """
        Start saving a layer of KV cache from vLLM's paged buffer
        to the connector. This is called from within the attention layer to
        enable async copying during execution.

        Args:
            layer_name (str): the name of the layer.
            kv_layer (torch.Tensor): the paged KV buffer of the current
                layer in vLLM.
            attn_metadata (AttentionMetadata): the attention metadata.
            **kwargs: additional arguments for the save operation.
        """
        pass

    @abstractmethod
    def wait_for_save(self):
        """
        Block until all the save operations are done. This is called
        as the forward context exits to ensure that the async saving
        from save_kv_layer is complete before finishing the forward pass.

        This prevents the paged KV buffer from being overwritten before
        saving is done.
        """
        pass
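
    # For illustration only: the worker-side hooks above are driven roughly
    # in this order during a forward pass (hypothetical pseudocode, not
    # vLLM's actual attention code):
    #
    #     connector.start_load_kv(forward_context)
    #     for layer_name in layer_names:
    #         connector.wait_for_layer_load(layer_name)
    #         ...  # attention computation for this layer
    #         connector.save_kv_layer(layer_name, kv_layer, attn_metadata)
    #     connector.wait_for_save()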

    def get_finished(
        self, finished_req_ids: set[str]
    ) -> tuple[set[str] | None, set[str] | None]:
        """
        Notifies the worker-side connector of the ids of requests that have
        finished generating tokens on the worker.
        The scheduler process (via the Executors) will use this output
        to track which workers are done.

        Returns:
            ids of requests that have finished asynchronous transfer
            (requests that previously returned True from request_finished()),
            as a tuple of (sending/saving ids, recving/loading ids).
            The finished saves/sends req ids must belong to a set provided in
            a call to this method (this call or a prior one).
        """
        return None, None

    def get_block_ids_with_load_errors(self) -> set[int]:
        """
        Get the set of block IDs that failed to load.

        Returns:
            Set of block IDs that encountered load errors.
            Empty set if no load errors occurred.

        Notes:
            - Applies to both sync- and async-loading requests.
            - Async loading: failed blocks may be reported in any forward pass
              up to and including the pass where the request ID is returned by
              `get_finished()`. Even if failures occur, the request must still
              be reported via `get_finished()`, and the failed block IDs must
              appear here no later than that same pass.
            - Sync loading: failed blocks should be reported in the forward
              pass in which they are detected.
        """
        return set()

    def shutdown(self):
        """
        Shut down the connector. This is called when the worker process
        is shutting down, to ensure that all the async operations are
        completed and the connector is cleaned up properly.
        """
        return None

    def get_kv_connector_stats(self) -> Optional["KVConnectorStats"]:
        """
        Get the KV connector stats collected during the last interval.
        """
        return None

    def get_handshake_metadata(self) -> KVConnectorHandshakeMetadata | None:
        """
        Get the KVConnector handshake metadata for this connector.
        This metadata is used for the out-of-band connector handshake
        between P/D workers.

        Returns:
            KVConnectorHandshakeMetadata: the handshake metadata.
            None if no handshake metadata is available.
        """
        return None

    # ==============================
    # Scheduler-side methods
    # ==============================

    @abstractmethod
    def get_num_new_matched_tokens(
        self,
        request: "Request",
        num_computed_tokens: int,
    ) -> tuple[int | None, bool]:
        """
        Get the number of new tokens that can be loaded from the
        external KV cache beyond the num_computed_tokens.

        Args:
            request (Request): the request object.
            num_computed_tokens (int): the number of locally
                computed tokens for this request.

        Returns:
            A tuple with the following elements:
                - An optional number of tokens that can be loaded from the
                  external KV cache beyond what is already computed.
                  If None, it means that the connector needs more time to
                  determine the number of matched tokens, and the scheduler
                  should query for this request again later.
                - `True` if external KV cache tokens will be loaded
                  asynchronously (between scheduler steps). Must be
                  `False` if the first element is 0.

        Notes:
            The connector should only consider the largest prefix of prompt
            tokens for which KV cache is actually available at the time of the
            call. If the cache cannot be loaded for some tokens (e.g., due to
            connectivity issues or eviction), those tokens must not be taken
            into account.
        """
        pass
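
    # For illustration only, with hypothetical values: a connector that finds
    # two more full blocks of external KV beyond the computed prefix and
    # loads them asynchronously would return
    #
    #     return 2 * block_size, True
    #
    # whereas a connector with nothing to contribute returns
    #
    #     return 0, False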

    @abstractmethod
    def update_state_after_alloc(
        self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
    ):
        """
        Update KVConnector state after block allocation.

        If get_num_new_matched_tokens previously returned True (async load)
        for a request, this function may be called twice for that same
        request - first when blocks are allocated for the connector tokens
        to be asynchronously loaded into, and second when any additional
        blocks are allocated, after the load/transfer is complete.

        Args:
            request (Request): the request object.
            blocks (KVCacheBlocks): the blocks allocated for the request.
            num_external_tokens (int): the number of tokens that will be
                loaded from the external KV cache.
        """
        pass

    @abstractmethod
    def build_connector_meta(
        self, scheduler_output: SchedulerOutput
    ) -> KVConnectorMetadata:
        """
        Build the connector metadata for this step.

        This function should NOT modify fields in the scheduler_output.
        Also, calling this function will reset the state of the connector.

        Args:
            scheduler_output (SchedulerOutput): the scheduler output object.
        """
        pass

    def update_connector_output(self, connector_output: KVConnectorOutput):
        """
        Update KVConnector state from the worker-side connectors' output.

        Args:
            connector_output (KVConnectorOutput): the worker-side
                connectors' output.
        """
        return

    def request_finished(
        self,
        request: "Request",
        block_ids: list[int],
    ) -> tuple[bool, dict[str, Any] | None]:
        """
        Called exactly once when a request has finished, before its blocks are
        freed.

        The connector may assume responsibility for freeing the blocks
        asynchronously by returning True.

        Returns:
            True if the request is being saved/sent asynchronously and blocks
            should not be freed until the request_id is returned from
            get_finished().
            Optional KVTransferParams to be included in the request outputs
            returned by the engine.
        """
        return False, None
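
    # For illustration only: a connector whose saves complete synchronously
    # keeps the default (False, None), so blocks are freed immediately. A
    # connector still sending KV asynchronously might instead return
    # (the transfer-params dict below is hypothetical, connector-specific):
    #
    #     return True, {"remote_host": host, "remote_port": port}
    #
    # and later report the request id via get_finished() once the send
    # completes, at which point the blocks can be freed.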

    def take_events(self) -> Iterable["KVCacheEvent"]:
        """
        Take the KV cache events from the connector.

        Yields:
            New KV cache events since the last call.
        """
        return ()

    @classmethod
    def get_required_kvcache_layout(cls, vllm_config: "VllmConfig") -> str | None:
        """
        Get the required KV cache layout for this connector.

        Args:
            vllm_config (VllmConfig): the vllm config.

        Returns:
            str: the required KV cache layout, e.g. HND or NHD.
            None if the connector does not require a specific layout.
        """

        if cls is KVConnectorBase_V1:
            raise TypeError(
                "get_required_kvcache_layout should not be called "
                "on the abstract base class"
            )
        return None

    def get_finished_count(self) -> int | None:
        """
        Get the count of requests expected to complete send/receive operations
        via this connector. This method is used to initialize the
        KVOutputAggregator, overriding the default world_size.

        Returns:
            int: expected sending or receiving completion count.
        """

        return None

    @classmethod
    def build_kv_connector_stats(
        cls, data: dict[str, Any] | None = None
    ) -> Optional["KVConnectorStats"]:
        """
        KVConnectorStats resolution method. This method allows dynamically
        registered connectors to return their own KVConnectorStats object,
        which can implement custom aggregation logic on the data dict.
        """
        return None

    def set_xfer_handshake_metadata(
        self, metadata: dict[int, KVConnectorHandshakeMetadata]
    ) -> None:
        """
        Set the KV connector handshake metadata for this connector.

        Args:
            metadata (dict[int, KVConnectorHandshakeMetadata]): the handshake
                metadata to set.
        """
        return None

    @classmethod
    def build_prom_metrics(
        cls,
        vllm_config: "VllmConfig",
        metric_types: dict[type["PromMetric"], type["PromMetricT"]],
        labelnames: list[str],
        per_engine_labelvalues: dict[int, list[object]],
    ) -> Optional["KVConnectorPromMetrics"]:
        """
        Create a KVConnectorPromMetrics subclass which should register
        per-connector Prometheus metrics and implement observe() to
        expose connector transfer stats via Prometheus.
        """
        return None
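

# For illustration only: a minimal no-op connector sketch (hypothetical, not
# part of vLLM) showing the abstract methods a concrete subclass must
# implement. It never matches external tokens and transfers nothing.
class _ExampleNoOpConnector(KVConnectorBase_V1):
    # Worker-side hooks: nothing to load or save.
    def start_load_kv(self, forward_context: "ForwardContext", **kwargs: Any) -> None:
        pass

    def wait_for_layer_load(self, layer_name: str) -> None:
        pass

    def save_kv_layer(
        self,
        layer_name: str,
        kv_layer: torch.Tensor,
        attn_metadata: "AttentionMetadata",
        **kwargs: Any,
    ) -> None:
        pass

    def wait_for_save(self):
        pass

    # Scheduler-side hooks: report zero externally cached tokens.
    def get_num_new_matched_tokens(
        self, request: "Request", num_computed_tokens: int
    ) -> tuple[int | None, bool]:
        return 0, False

    def update_state_after_alloc(
        self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
    ):
        pass

    def build_connector_meta(
        self, scheduler_output: SchedulerOutput
    ) -> KVConnectorMetadata:
        # An empty metadata object; real connectors subclass
        # KVConnectorMetadata to carry per-step transfer state.
        return KVConnectorMetadata()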