[KVConnector] Keep KVTransferParams as a dict (#18033)

This commit is contained in:
Nick Hill 2025-05-14 08:05:57 -07:00 committed by GitHub
parent d066e52013
commit 59dd311cf5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 64 additions and 157 deletions

View File

@ -1,13 +1,11 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Optional
from typing import Any, Optional
import torch
from vllm import SamplingParams
from vllm.config import (CacheConfig, DeviceConfig, KVTransferConfig,
ModelConfig, SchedulerConfig, VllmConfig)
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
NixlKVTransferParams)
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec)
@ -124,20 +122,20 @@ def create_request(
) -> Request:
"""Make dummy request for testing."""
kv_transfer_params: Optional[dict[str, Any]] = None
if do_remote_decode:
assert not do_remote_prefill
kv_transfer_params = NixlKVTransferParams(do_remote_prefill=False,
do_remote_decode=True)
kv_transfer_params = dict(do_remote_prefill=False,
do_remote_decode=True)
elif do_remote_prefill:
kv_transfer_params = NixlKVTransferParams(
do_remote_prefill=True,
do_remote_decode=False,
remote_engine_id="my-engine-id",
remote_block_ids=list(range(num_remote_blocks)),
remote_host="my-host",
remote_port=1234)
else:
kv_transfer_params = None
kv_transfer_params = dict(do_remote_prefill=True,
do_remote_decode=False,
remote_engine_id="my-engine-id",
remote_block_ids=list(
range(num_remote_blocks)),
remote_host="my-host",
remote_port=1234)
max_tokens = 1 if do_remote_decode else max_tokens
sampling_params = SamplingParams(max_tokens=max_tokens)

View File

@ -1,5 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorRole, KVTransferParams)
KVConnectorBase_V1, KVConnectorRole)
__all__ = ["KVConnectorRole", "KVConnectorBase_V1", "KVTransferParams"]
__all__ = ["KVConnectorRole", "KVConnectorBase_V1"]

View File

@ -48,23 +48,6 @@ class KVConnectorRole(enum.Enum):
WORKER = 1
class KVTransferParams:
"""
Abstract KVTransferParams used to send KVTransfer
parameters between instances of vLLM.
Specific instances of KVConnector customize this
method for serializing / deserializing msgs sent
via the HTTP protocol.
"""
@staticmethod
def from_raw_dict(
raw_dict: Optional[dict[str,
Any]]) -> Optional["KVTransferParams"]:
return None
@dataclass
class KVConnectorMetadata:
"""
@ -75,7 +58,6 @@ class KVConnectorMetadata:
class KVConnectorBase_V1(ABC):
_KVTransferParams = KVTransferParams
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
logger.warning(
@ -213,13 +195,6 @@ class KVConnectorBase_V1(ABC):
# Scheduler-side methods
# ==============================
def set_kv_transfer_params(self, request: "Request"):
"""Parse raw KV Transfer params."""
assert request.kv_transfer_params is None
kv_transfer_params = self._KVTransferParams.from_raw_dict(
request.raw_kv_transfer_params)
request.kv_transfer_params = kv_transfer_params
@abstractmethod
def get_num_new_matched_tokens(
self,

View File

@ -16,7 +16,7 @@ import zmq
from vllm import envs
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole, KVTransferParams)
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
from vllm.distributed.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size,
get_tp_group)
@ -44,56 +44,6 @@ except ImportError:
NixlWrapper = None
@dataclass
class NixlKVTransferParams(KVTransferParams):
def __init__(
self,
do_remote_prefill: bool,
do_remote_decode: bool,
remote_block_ids: Optional[list[int]] = None,
remote_host: Optional[str] = None,
remote_port: Optional[int] = None,
remote_engine_id: Optional[str] = None,
):
self.do_remote_prefill = do_remote_prefill
self.do_remote_decode = do_remote_decode
self.remote_block_ids = remote_block_ids
self.remote_host = remote_host
self.remote_port = remote_port
self.remote_engine_id = remote_engine_id
@staticmethod
def from_raw_dict(
raw_dict: Optional[dict[str,
Any]]) -> Optional["NixlKVTransferParams"]:
# If no raw transfer params passed, return None.
if raw_dict is None:
return None
# Validate the request is formatted properly.
if (("do_remote_prefill" not in raw_dict)
or ("do_remote_decode" not in raw_dict)
or ("remote_block_ids" not in raw_dict)
or ("remote_host" not in raw_dict)
or ("remote_port" not in raw_dict)
or ("remote_engine_id" not in raw_dict)):
logger.warning(
"Got invalid KVTransferParams: %s. This "
"request will not utilize KVTransfer", raw_dict)
return None
return NixlKVTransferParams(
do_remote_prefill=raw_dict["do_remote_prefill"],
do_remote_decode=raw_dict["do_remote_decode"],
remote_block_ids=raw_dict["remote_block_ids"],
remote_host=raw_dict["remote_host"],
remote_port=raw_dict["remote_port"],
remote_engine_id=raw_dict["remote_engine_id"],
)
class NixlAgentMetadata(
msgspec.Struct,
omit_defaults=True, # type: ignore[call-arg]
@ -123,25 +73,18 @@ class NixlConnectorMetadata(KVConnectorMetadata):
self,
request_id: str,
local_block_ids: list[int],
kv_transfer_params: NixlKVTransferParams,
kv_transfer_params: dict[str, Any],
):
assert request_id not in self.requests
assert kv_transfer_params.remote_block_ids is not None
assert kv_transfer_params.remote_engine_id is not None
assert kv_transfer_params.remote_host is not None
assert kv_transfer_params.remote_port is not None
self.requests[request_id] = ReqMeta(
local_block_ids=local_block_ids,
remote_block_ids=kv_transfer_params.remote_block_ids,
remote_engine_id=kv_transfer_params.remote_engine_id,
remote_host=kv_transfer_params.remote_host,
remote_port=kv_transfer_params.remote_port,
remote_block_ids=kv_transfer_params["remote_block_ids"],
remote_engine_id=kv_transfer_params["remote_engine_id"],
remote_host=kv_transfer_params["remote_host"],
remote_port=kv_transfer_params["remote_port"],
)
class NixlConnector(KVConnectorBase_V1):
_KVTransferParams: type[NixlKVTransferParams] = NixlKVTransferParams
def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole):
assert vllm_config.kv_transfer_config is not None
@ -253,52 +196,52 @@ class NixlConnectorScheduler:
asynchronously (between scheduler steps).
"""
params = request.kv_transfer_params
logger.debug(
"NIXLConnector get_num_new_matched_tokens: "
"num_computed_tokens=%s, kv_transfer_params=%s",
num_computed_tokens, request.kv_transfer_params)
num_computed_tokens, params)
# No KVTransfer for this request.
if request.kv_transfer_params is None:
return 0, False
assert isinstance(request.kv_transfer_params, NixlKVTransferParams)
# Remote prefill: get all prompt blocks from remote.
if request.kv_transfer_params.do_remote_prefill:
if params is not None and params.get("do_remote_prefill"):
# Remote prefill: get all prompt blocks from remote.
assert num_computed_tokens % self.block_size == 0
rounded_num_prompt_tokens = round_down(
len(request.prompt_token_ids), self.block_size)
count = max(rounded_num_prompt_tokens - num_computed_tokens, 0)
return count, count > 0
# No remote prefill for this request.
return 0, False
def update_state_after_alloc(self, request: "Request",
blocks: "KVCacheBlocks",
num_external_tokens: int):
params = request.kv_transfer_params
logger.debug(
"NIXLConnector update_state_after_alloc: "
"num_external_tokens=%s, kv_transfer_params=%s",
num_external_tokens, request.kv_transfer_params)
num_external_tokens, params)
if request.kv_transfer_params is None:
return
assert isinstance(request.kv_transfer_params, NixlKVTransferParams)
if request.kv_transfer_params.do_remote_prefill:
if params is not None and params.get("do_remote_prefill"):
# NOTE(rob): if prompt < block_size, no remote blocks
# since the remote only sends fully computed blocks, so
# skip recving for this request. num_external_tokens
# should be 0 if there are no remote blocks.
if request.kv_transfer_params.remote_block_ids:
# Get unhashed blocks to pull from remote.
self._reqs_need_recv[request.request_id] = (
request, blocks.get_unhashed_block_ids())
if params.get("remote_block_ids"):
if all(p in params for p in ("remote_engine_id", "remote_host",
"remote_port")):
# Get unhashed blocks to pull from remote.
self._reqs_need_recv[request.request_id] = (
request, blocks.get_unhashed_block_ids())
else:
logger.warning(
"Got invalid KVTransferParams: %s. This "
"request will not utilize KVTransfer", params)
else:
assert num_external_tokens == 0
# Only trigger 1 KV transfer per request.
request.kv_transfer_params.do_remote_prefill = False
params["do_remote_prefill"] = False
def build_connector_meta(
self,
@ -308,7 +251,7 @@ class NixlConnectorScheduler:
# Loop through scheduled reqs and convert to ReqMeta.
for req_id, (req, block_ids) in self._reqs_need_recv.items():
assert isinstance(req.kv_transfer_params, NixlKVTransferParams)
assert req.kv_transfer_params is not None
meta.add_new_req(
request_id=req_id,
local_block_ids=block_ids,
@ -330,34 +273,30 @@ class NixlConnectorScheduler:
should be freed now or will be sent asynchronously and freed later.
"""
params = request.kv_transfer_params
logger.debug(
"NIXLConnector request_finished, "
"request_status=%s, kv_transfer_params=%s", request.status,
request.kv_transfer_params)
"NIXLConnector request_finished, request_status=%s, "
"kv_transfer_params=%s", request.status, params)
if request.kv_transfer_params is None:
return False, None
assert isinstance(request.kv_transfer_params, NixlKVTransferParams)
if ((not request.kv_transfer_params.do_remote_decode)
or (request.status != RequestStatus.FINISHED_LENGTH_CAPPED)):
if (params is None or not params.get("do_remote_decode")
or request.status != RequestStatus.FINISHED_LENGTH_CAPPED):
return False, None
# Get computed blocks.
all_full = request.num_computed_tokens % self.block_size == 0
computed_block_ids = (block_ids if all_full else block_ids[:-1])
computed_block_ids = block_ids if all_full else block_ids[:-1]
# If prompt < block_size, no xfer so free blocks immediately.
delay_free_blocks = len(computed_block_ids) > 0
return delay_free_blocks, NixlKVTransferParams(
return delay_free_blocks, dict(
do_remote_prefill=True,
do_remote_decode=False,
remote_block_ids=computed_block_ids,
remote_engine_id=self.engine_id,
remote_host=envs.VLLM_NIXL_SIDE_CHANNEL_HOST,
remote_port=envs.VLLM_NIXL_SIDE_CHANNEL_PORT,
).__dict__
)
class NixlConnectorWorker:

View File

@ -12,8 +12,7 @@ from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch
from vllm.distributed.kv_transfer.kv_connector.factory import (
KVConnectorFactory)
from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1,
KVConnectorRole,
KVTransferParams)
KVConnectorRole)
from vllm.logger import init_logger
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager,
@ -931,8 +930,13 @@ class Scheduler(SchedulerInterface):
return self.connector
def _connector_finished(
self, request: Request) -> tuple[bool, Optional[KVTransferParams]]:
"""Invoke the KV connector request_finished() method if applicable."""
self, request: Request) -> tuple[bool, Optional[dict[str, Any]]]:
"""
Invoke the KV connector request_finished() method if applicable.
Returns optional kv transfer parameters to be included with the
request outputs.
"""
if self.connector is None:
return False, None
block_ids = self.kv_cache_manager.get_block_ids(request.request_id)

View File

@ -182,14 +182,10 @@ class EngineCore:
# Start grammar compilation asynchronously
self.structured_output_manager.grammar_init(req)
if req.raw_kv_transfer_params is not None:
if (kv_connector := self.scheduler.get_kv_connector()):
# Parse raw KV transfer params via connector.
kv_connector.set_kv_transfer_params(req)
else:
logger.warning(
"Got KVTransferParams, but no KVConnector found. "
"Disabling KVTransfer for this request.")
if req.kv_transfer_params is not None and (
not self.scheduler.get_kv_connector()):
logger.warning("Got kv_transfer_params, but no KVConnector found. "
"Disabling KVTransfer for this request.")
self.scheduler.add_request(req)

View File

@ -3,7 +3,6 @@
import enum
from typing import TYPE_CHECKING, Any, Optional, Union
from vllm.distributed.kv_transfer.kv_connector.v1 import KVTransferParams
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.sampling_params import SamplingParams
from vllm.utils import is_list_of
@ -62,14 +61,10 @@ class Request:
self.num_encoder_inputs = len(self.mm_inputs)
self.has_encoder_inputs = self.num_encoder_inputs > 0
# P/D: KV transfer parameters (raw and parsed).
raw_params = (None if sampling_params.extra_args is None
else sampling_params.extra_args.get(
"kv_transfer_params", None))
self.raw_kv_transfer_params: Optional[dict[str, Any]] = raw_params
# Each connector parses the raw dictionary and sets this
# attr the first time that the request is processed.
self.kv_transfer_params: Optional[KVTransferParams] = None
# P/D: Connector-specific KV transfer parameters.
kv_params = (None if sampling_params.extra_args is None else
sampling_params.extra_args.get("kv_transfer_params"))
self.kv_transfer_params: Optional[dict[str, Any]] = kv_params
# Sanity check
assert len(self.mm_inputs) == len(self.mm_positions)