[KVConnector] Keep KVTransferParams as a dict (#18033)

This commit is contained in:
Nick Hill 2025-05-14 08:05:57 -07:00 committed by GitHub
parent d066e52013
commit 59dd311cf5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 64 additions and 157 deletions

View File

@ -1,13 +1,11 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import Optional from typing import Any, Optional
import torch import torch
from vllm import SamplingParams from vllm import SamplingParams
from vllm.config import (CacheConfig, DeviceConfig, KVTransferConfig, from vllm.config import (CacheConfig, DeviceConfig, KVTransferConfig,
ModelConfig, SchedulerConfig, VllmConfig) ModelConfig, SchedulerConfig, VllmConfig)
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
NixlKVTransferParams)
from vllm.v1.core.sched.scheduler import Scheduler from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec) KVCacheGroupSpec)
@ -124,20 +122,20 @@ def create_request(
) -> Request: ) -> Request:
"""Make dummy request for testing.""" """Make dummy request for testing."""
kv_transfer_params: Optional[dict[str, Any]] = None
if do_remote_decode: if do_remote_decode:
assert not do_remote_prefill assert not do_remote_prefill
kv_transfer_params = NixlKVTransferParams(do_remote_prefill=False, kv_transfer_params = dict(do_remote_prefill=False,
do_remote_decode=True) do_remote_decode=True)
elif do_remote_prefill: elif do_remote_prefill:
kv_transfer_params = NixlKVTransferParams( kv_transfer_params = dict(do_remote_prefill=True,
do_remote_prefill=True,
do_remote_decode=False, do_remote_decode=False,
remote_engine_id="my-engine-id", remote_engine_id="my-engine-id",
remote_block_ids=list(range(num_remote_blocks)), remote_block_ids=list(
range(num_remote_blocks)),
remote_host="my-host", remote_host="my-host",
remote_port=1234) remote_port=1234)
else:
kv_transfer_params = None
max_tokens = 1 if do_remote_decode else max_tokens max_tokens = 1 if do_remote_decode else max_tokens
sampling_params = SamplingParams(max_tokens=max_tokens) sampling_params = SamplingParams(max_tokens=max_tokens)

View File

@ -1,5 +1,5 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from vllm.distributed.kv_transfer.kv_connector.v1.base import ( from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorRole, KVTransferParams) KVConnectorBase_V1, KVConnectorRole)
__all__ = ["KVConnectorRole", "KVConnectorBase_V1", "KVTransferParams"] __all__ = ["KVConnectorRole", "KVConnectorBase_V1"]

View File

@ -48,23 +48,6 @@ class KVConnectorRole(enum.Enum):
WORKER = 1 WORKER = 1
class KVTransferParams:
"""
Abstract KVTransferParams used to send KVTransfer
parameters between instances of vLLM.
Specific instances of KVConnector customize this
method for serializing / deserializing msgs sent
via the HTTP protocol.
"""
@staticmethod
def from_raw_dict(
raw_dict: Optional[dict[str,
Any]]) -> Optional["KVTransferParams"]:
return None
@dataclass @dataclass
class KVConnectorMetadata: class KVConnectorMetadata:
""" """
@ -75,7 +58,6 @@ class KVConnectorMetadata:
class KVConnectorBase_V1(ABC): class KVConnectorBase_V1(ABC):
_KVTransferParams = KVTransferParams
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
logger.warning( logger.warning(
@ -213,13 +195,6 @@ class KVConnectorBase_V1(ABC):
# Scheduler-side methods # Scheduler-side methods
# ============================== # ==============================
def set_kv_transfer_params(self, request: "Request"):
"""Parse raw KV Transfer params."""
assert request.kv_transfer_params is None
kv_transfer_params = self._KVTransferParams.from_raw_dict(
request.raw_kv_transfer_params)
request.kv_transfer_params = kv_transfer_params
@abstractmethod @abstractmethod
def get_num_new_matched_tokens( def get_num_new_matched_tokens(
self, self,

View File

@ -16,7 +16,7 @@ import zmq
from vllm import envs from vllm import envs
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import ( from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole, KVTransferParams) KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
from vllm.distributed.parallel_state import ( from vllm.distributed.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size,
get_tp_group) get_tp_group)
@ -44,56 +44,6 @@ except ImportError:
NixlWrapper = None NixlWrapper = None
@dataclass
class NixlKVTransferParams(KVTransferParams):
def __init__(
self,
do_remote_prefill: bool,
do_remote_decode: bool,
remote_block_ids: Optional[list[int]] = None,
remote_host: Optional[str] = None,
remote_port: Optional[int] = None,
remote_engine_id: Optional[str] = None,
):
self.do_remote_prefill = do_remote_prefill
self.do_remote_decode = do_remote_decode
self.remote_block_ids = remote_block_ids
self.remote_host = remote_host
self.remote_port = remote_port
self.remote_engine_id = remote_engine_id
@staticmethod
def from_raw_dict(
raw_dict: Optional[dict[str,
Any]]) -> Optional["NixlKVTransferParams"]:
# If no raw transfer params passed, return None.
if raw_dict is None:
return None
# Validate the request is formatted properly.
if (("do_remote_prefill" not in raw_dict)
or ("do_remote_decode" not in raw_dict)
or ("remote_block_ids" not in raw_dict)
or ("remote_host" not in raw_dict)
or ("remote_port" not in raw_dict)
or ("remote_engine_id" not in raw_dict)):
logger.warning(
"Got invalid KVTransferParams: %s. This "
"request will not utilize KVTransfer", raw_dict)
return None
return NixlKVTransferParams(
do_remote_prefill=raw_dict["do_remote_prefill"],
do_remote_decode=raw_dict["do_remote_decode"],
remote_block_ids=raw_dict["remote_block_ids"],
remote_host=raw_dict["remote_host"],
remote_port=raw_dict["remote_port"],
remote_engine_id=raw_dict["remote_engine_id"],
)
class NixlAgentMetadata( class NixlAgentMetadata(
msgspec.Struct, msgspec.Struct,
omit_defaults=True, # type: ignore[call-arg] omit_defaults=True, # type: ignore[call-arg]
@ -123,25 +73,18 @@ class NixlConnectorMetadata(KVConnectorMetadata):
self, self,
request_id: str, request_id: str,
local_block_ids: list[int], local_block_ids: list[int],
kv_transfer_params: NixlKVTransferParams, kv_transfer_params: dict[str, Any],
): ):
assert request_id not in self.requests
assert kv_transfer_params.remote_block_ids is not None
assert kv_transfer_params.remote_engine_id is not None
assert kv_transfer_params.remote_host is not None
assert kv_transfer_params.remote_port is not None
self.requests[request_id] = ReqMeta( self.requests[request_id] = ReqMeta(
local_block_ids=local_block_ids, local_block_ids=local_block_ids,
remote_block_ids=kv_transfer_params.remote_block_ids, remote_block_ids=kv_transfer_params["remote_block_ids"],
remote_engine_id=kv_transfer_params.remote_engine_id, remote_engine_id=kv_transfer_params["remote_engine_id"],
remote_host=kv_transfer_params.remote_host, remote_host=kv_transfer_params["remote_host"],
remote_port=kv_transfer_params.remote_port, remote_port=kv_transfer_params["remote_port"],
) )
class NixlConnector(KVConnectorBase_V1): class NixlConnector(KVConnectorBase_V1):
_KVTransferParams: type[NixlKVTransferParams] = NixlKVTransferParams
def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole): def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole):
assert vllm_config.kv_transfer_config is not None assert vllm_config.kv_transfer_config is not None
@ -253,52 +196,52 @@ class NixlConnectorScheduler:
asynchronously (between scheduler steps). asynchronously (between scheduler steps).
""" """
params = request.kv_transfer_params
logger.debug( logger.debug(
"NIXLConnector get_num_new_matched_tokens: " "NIXLConnector get_num_new_matched_tokens: "
"num_computed_tokens=%s, kv_transfer_params=%s", "num_computed_tokens=%s, kv_transfer_params=%s",
num_computed_tokens, request.kv_transfer_params) num_computed_tokens, params)
# No KVTransfer for this request.
if request.kv_transfer_params is None:
return 0, False
assert isinstance(request.kv_transfer_params, NixlKVTransferParams)
if params is not None and params.get("do_remote_prefill"):
# Remote prefill: get all prompt blocks from remote. # Remote prefill: get all prompt blocks from remote.
if request.kv_transfer_params.do_remote_prefill:
assert num_computed_tokens % self.block_size == 0 assert num_computed_tokens % self.block_size == 0
rounded_num_prompt_tokens = round_down( rounded_num_prompt_tokens = round_down(
len(request.prompt_token_ids), self.block_size) len(request.prompt_token_ids), self.block_size)
count = max(rounded_num_prompt_tokens - num_computed_tokens, 0) count = max(rounded_num_prompt_tokens - num_computed_tokens, 0)
return count, count > 0 return count, count > 0
# No remote prefill for this request.
return 0, False return 0, False
def update_state_after_alloc(self, request: "Request", def update_state_after_alloc(self, request: "Request",
blocks: "KVCacheBlocks", blocks: "KVCacheBlocks",
num_external_tokens: int): num_external_tokens: int):
params = request.kv_transfer_params
logger.debug( logger.debug(
"NIXLConnector update_state_after_alloc: " "NIXLConnector update_state_after_alloc: "
"num_external_tokens=%s, kv_transfer_params=%s", "num_external_tokens=%s, kv_transfer_params=%s",
num_external_tokens, request.kv_transfer_params) num_external_tokens, params)
if request.kv_transfer_params is None: if params is not None and params.get("do_remote_prefill"):
return
assert isinstance(request.kv_transfer_params, NixlKVTransferParams)
if request.kv_transfer_params.do_remote_prefill:
# NOTE(rob): if prompt < block_size, no remote blocks # NOTE(rob): if prompt < block_size, no remote blocks
# since the remote only sends fully computed blocks, so # since the remote only sends fully computed blocks, so
# skip recving for this request. num_external_tokens # skip recving for this request. num_external_tokens
# should be 0 if there are no remote blocks. # should be 0 if there are no remote blocks.
if request.kv_transfer_params.remote_block_ids: if params.get("remote_block_ids"):
if all(p in params for p in ("remote_engine_id", "remote_host",
"remote_port")):
# Get unhashed blocks to pull from remote. # Get unhashed blocks to pull from remote.
self._reqs_need_recv[request.request_id] = ( self._reqs_need_recv[request.request_id] = (
request, blocks.get_unhashed_block_ids()) request, blocks.get_unhashed_block_ids())
else:
logger.warning(
"Got invalid KVTransferParams: %s. This "
"request will not utilize KVTransfer", params)
else: else:
assert num_external_tokens == 0 assert num_external_tokens == 0
# Only trigger 1 KV transfer per request. # Only trigger 1 KV transfer per request.
request.kv_transfer_params.do_remote_prefill = False params["do_remote_prefill"] = False
def build_connector_meta( def build_connector_meta(
self, self,
@ -308,7 +251,7 @@ class NixlConnectorScheduler:
# Loop through scheduled reqs and convert to ReqMeta. # Loop through scheduled reqs and convert to ReqMeta.
for req_id, (req, block_ids) in self._reqs_need_recv.items(): for req_id, (req, block_ids) in self._reqs_need_recv.items():
assert isinstance(req.kv_transfer_params, NixlKVTransferParams) assert req.kv_transfer_params is not None
meta.add_new_req( meta.add_new_req(
request_id=req_id, request_id=req_id,
local_block_ids=block_ids, local_block_ids=block_ids,
@ -330,34 +273,30 @@ class NixlConnectorScheduler:
should be freed now or will be sent asynchronously and freed later. should be freed now or will be sent asynchronously and freed later.
""" """
params = request.kv_transfer_params
logger.debug( logger.debug(
"NIXLConnector request_finished, " "NIXLConnector request_finished, request_status=%s, "
"request_status=%s, kv_transfer_params=%s", request.status, "kv_transfer_params=%s", request.status, params)
request.kv_transfer_params)
if request.kv_transfer_params is None: if (params is None or not params.get("do_remote_decode")
return False, None or request.status != RequestStatus.FINISHED_LENGTH_CAPPED):
assert isinstance(request.kv_transfer_params, NixlKVTransferParams)
if ((not request.kv_transfer_params.do_remote_decode)
or (request.status != RequestStatus.FINISHED_LENGTH_CAPPED)):
return False, None return False, None
# Get computed blocks. # Get computed blocks.
all_full = request.num_computed_tokens % self.block_size == 0 all_full = request.num_computed_tokens % self.block_size == 0
computed_block_ids = (block_ids if all_full else block_ids[:-1]) computed_block_ids = block_ids if all_full else block_ids[:-1]
# If prompt < block_size, no xfer so free blocks immediately. # If prompt < block_size, no xfer so free blocks immediately.
delay_free_blocks = len(computed_block_ids) > 0 delay_free_blocks = len(computed_block_ids) > 0
return delay_free_blocks, NixlKVTransferParams( return delay_free_blocks, dict(
do_remote_prefill=True, do_remote_prefill=True,
do_remote_decode=False, do_remote_decode=False,
remote_block_ids=computed_block_ids, remote_block_ids=computed_block_ids,
remote_engine_id=self.engine_id, remote_engine_id=self.engine_id,
remote_host=envs.VLLM_NIXL_SIDE_CHANNEL_HOST, remote_host=envs.VLLM_NIXL_SIDE_CHANNEL_HOST,
remote_port=envs.VLLM_NIXL_SIDE_CHANNEL_PORT, remote_port=envs.VLLM_NIXL_SIDE_CHANNEL_PORT,
).__dict__ )
class NixlConnectorWorker: class NixlConnectorWorker:

View File

@ -12,8 +12,7 @@ from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch
from vllm.distributed.kv_transfer.kv_connector.factory import ( from vllm.distributed.kv_transfer.kv_connector.factory import (
KVConnectorFactory) KVConnectorFactory)
from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1, from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1,
KVConnectorRole, KVConnectorRole)
KVTransferParams)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager, from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager,
@ -931,8 +930,13 @@ class Scheduler(SchedulerInterface):
return self.connector return self.connector
def _connector_finished( def _connector_finished(
self, request: Request) -> tuple[bool, Optional[KVTransferParams]]: self, request: Request) -> tuple[bool, Optional[dict[str, Any]]]:
"""Invoke the KV connector request_finished() method if applicable.""" """
Invoke the KV connector request_finished() method if applicable.
Returns optional kv transfer parameters to be included with the
request outputs.
"""
if self.connector is None: if self.connector is None:
return False, None return False, None
block_ids = self.kv_cache_manager.get_block_ids(request.request_id) block_ids = self.kv_cache_manager.get_block_ids(request.request_id)

View File

@ -182,13 +182,9 @@ class EngineCore:
# Start grammar compilation asynchronously # Start grammar compilation asynchronously
self.structured_output_manager.grammar_init(req) self.structured_output_manager.grammar_init(req)
if req.raw_kv_transfer_params is not None: if req.kv_transfer_params is not None and (
if (kv_connector := self.scheduler.get_kv_connector()): not self.scheduler.get_kv_connector()):
# Parse raw KV transfer params via connector. logger.warning("Got kv_transfer_params, but no KVConnector found. "
kv_connector.set_kv_transfer_params(req)
else:
logger.warning(
"Got KVTransferParams, but no KVConnector found. "
"Disabling KVTransfer for this request.") "Disabling KVTransfer for this request.")
self.scheduler.add_request(req) self.scheduler.add_request(req)

View File

@ -3,7 +3,6 @@
import enum import enum
from typing import TYPE_CHECKING, Any, Optional, Union from typing import TYPE_CHECKING, Any, Optional, Union
from vllm.distributed.kv_transfer.kv_connector.v1 import KVTransferParams
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.utils import is_list_of from vllm.utils import is_list_of
@ -62,14 +61,10 @@ class Request:
self.num_encoder_inputs = len(self.mm_inputs) self.num_encoder_inputs = len(self.mm_inputs)
self.has_encoder_inputs = self.num_encoder_inputs > 0 self.has_encoder_inputs = self.num_encoder_inputs > 0
# P/D: KV transfer parameters (raw and parsed). # P/D: Connector-specific KV transfer parameters.
raw_params = (None if sampling_params.extra_args is None kv_params = (None if sampling_params.extra_args is None else
else sampling_params.extra_args.get( sampling_params.extra_args.get("kv_transfer_params"))
"kv_transfer_params", None)) self.kv_transfer_params: Optional[dict[str, Any]] = kv_params
self.raw_kv_transfer_params: Optional[dict[str, Any]] = raw_params
# Each connector parses the raw dictionary and sets this
# attr the first time that the request is processed.
self.kv_transfer_params: Optional[KVTransferParams] = None
# Sanity check # Sanity check
assert len(self.mm_inputs) == len(self.mm_positions) assert len(self.mm_inputs) == len(self.mm_positions)