mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 20:15:01 +08:00
[KVConnector] Keep KVTransferParams as a dict (#18033)
This commit is contained in:
parent
d066e52013
commit
59dd311cf5
@ -1,13 +1,11 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from typing import Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import SamplingParams
|
||||
from vllm.config import (CacheConfig, DeviceConfig, KVTransferConfig,
|
||||
ModelConfig, SchedulerConfig, VllmConfig)
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
|
||||
NixlKVTransferParams)
|
||||
from vllm.v1.core.sched.scheduler import Scheduler
|
||||
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
||||
KVCacheGroupSpec)
|
||||
@ -124,20 +122,20 @@ def create_request(
|
||||
) -> Request:
|
||||
"""Make dummy request for testing."""
|
||||
|
||||
kv_transfer_params: Optional[dict[str, Any]] = None
|
||||
|
||||
if do_remote_decode:
|
||||
assert not do_remote_prefill
|
||||
kv_transfer_params = NixlKVTransferParams(do_remote_prefill=False,
|
||||
do_remote_decode=True)
|
||||
kv_transfer_params = dict(do_remote_prefill=False,
|
||||
do_remote_decode=True)
|
||||
elif do_remote_prefill:
|
||||
kv_transfer_params = NixlKVTransferParams(
|
||||
do_remote_prefill=True,
|
||||
do_remote_decode=False,
|
||||
remote_engine_id="my-engine-id",
|
||||
remote_block_ids=list(range(num_remote_blocks)),
|
||||
remote_host="my-host",
|
||||
remote_port=1234)
|
||||
else:
|
||||
kv_transfer_params = None
|
||||
kv_transfer_params = dict(do_remote_prefill=True,
|
||||
do_remote_decode=False,
|
||||
remote_engine_id="my-engine-id",
|
||||
remote_block_ids=list(
|
||||
range(num_remote_blocks)),
|
||||
remote_host="my-host",
|
||||
remote_port=1234)
|
||||
|
||||
max_tokens = 1 if do_remote_decode else max_tokens
|
||||
sampling_params = SamplingParams(max_tokens=max_tokens)
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||
KVConnectorBase_V1, KVConnectorRole, KVTransferParams)
|
||||
KVConnectorBase_V1, KVConnectorRole)
|
||||
|
||||
__all__ = ["KVConnectorRole", "KVConnectorBase_V1", "KVTransferParams"]
|
||||
__all__ = ["KVConnectorRole", "KVConnectorBase_V1"]
|
||||
|
||||
@ -48,23 +48,6 @@ class KVConnectorRole(enum.Enum):
|
||||
WORKER = 1
|
||||
|
||||
|
||||
class KVTransferParams:
|
||||
"""
|
||||
Abstract KVTransferParams used to send KVTransfer
|
||||
parameters between instances of vLLM.
|
||||
|
||||
Specific instances of KVConnector customize this
|
||||
method for serializing / deserializing msgs sent
|
||||
via the HTTP protocol.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def from_raw_dict(
|
||||
raw_dict: Optional[dict[str,
|
||||
Any]]) -> Optional["KVTransferParams"]:
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class KVConnectorMetadata:
|
||||
"""
|
||||
@ -75,7 +58,6 @@ class KVConnectorMetadata:
|
||||
|
||||
|
||||
class KVConnectorBase_V1(ABC):
|
||||
_KVTransferParams = KVTransferParams
|
||||
|
||||
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
|
||||
logger.warning(
|
||||
@ -213,13 +195,6 @@ class KVConnectorBase_V1(ABC):
|
||||
# Scheduler-side methods
|
||||
# ==============================
|
||||
|
||||
def set_kv_transfer_params(self, request: "Request"):
|
||||
"""Parse raw KV Transfer params."""
|
||||
assert request.kv_transfer_params is None
|
||||
kv_transfer_params = self._KVTransferParams.from_raw_dict(
|
||||
request.raw_kv_transfer_params)
|
||||
request.kv_transfer_params = kv_transfer_params
|
||||
|
||||
@abstractmethod
|
||||
def get_num_new_matched_tokens(
|
||||
self,
|
||||
|
||||
@ -16,7 +16,7 @@ import zmq
|
||||
from vllm import envs
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole, KVTransferParams)
|
||||
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
|
||||
from vllm.distributed.parallel_state import (
|
||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size,
|
||||
get_tp_group)
|
||||
@ -44,56 +44,6 @@ except ImportError:
|
||||
NixlWrapper = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class NixlKVTransferParams(KVTransferParams):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
do_remote_prefill: bool,
|
||||
do_remote_decode: bool,
|
||||
remote_block_ids: Optional[list[int]] = None,
|
||||
remote_host: Optional[str] = None,
|
||||
remote_port: Optional[int] = None,
|
||||
remote_engine_id: Optional[str] = None,
|
||||
):
|
||||
self.do_remote_prefill = do_remote_prefill
|
||||
self.do_remote_decode = do_remote_decode
|
||||
self.remote_block_ids = remote_block_ids
|
||||
self.remote_host = remote_host
|
||||
self.remote_port = remote_port
|
||||
self.remote_engine_id = remote_engine_id
|
||||
|
||||
@staticmethod
|
||||
def from_raw_dict(
|
||||
raw_dict: Optional[dict[str,
|
||||
Any]]) -> Optional["NixlKVTransferParams"]:
|
||||
|
||||
# If no raw transfer params passed, return None.
|
||||
if raw_dict is None:
|
||||
return None
|
||||
|
||||
# Validate the request is formatted properly.
|
||||
if (("do_remote_prefill" not in raw_dict)
|
||||
or ("do_remote_decode" not in raw_dict)
|
||||
or ("remote_block_ids" not in raw_dict)
|
||||
or ("remote_host" not in raw_dict)
|
||||
or ("remote_port" not in raw_dict)
|
||||
or ("remote_engine_id" not in raw_dict)):
|
||||
logger.warning(
|
||||
"Got invalid KVTransferParams: %s. This "
|
||||
"request will not utilize KVTransfer", raw_dict)
|
||||
return None
|
||||
|
||||
return NixlKVTransferParams(
|
||||
do_remote_prefill=raw_dict["do_remote_prefill"],
|
||||
do_remote_decode=raw_dict["do_remote_decode"],
|
||||
remote_block_ids=raw_dict["remote_block_ids"],
|
||||
remote_host=raw_dict["remote_host"],
|
||||
remote_port=raw_dict["remote_port"],
|
||||
remote_engine_id=raw_dict["remote_engine_id"],
|
||||
)
|
||||
|
||||
|
||||
class NixlAgentMetadata(
|
||||
msgspec.Struct,
|
||||
omit_defaults=True, # type: ignore[call-arg]
|
||||
@ -123,25 +73,18 @@ class NixlConnectorMetadata(KVConnectorMetadata):
|
||||
self,
|
||||
request_id: str,
|
||||
local_block_ids: list[int],
|
||||
kv_transfer_params: NixlKVTransferParams,
|
||||
kv_transfer_params: dict[str, Any],
|
||||
):
|
||||
assert request_id not in self.requests
|
||||
assert kv_transfer_params.remote_block_ids is not None
|
||||
assert kv_transfer_params.remote_engine_id is not None
|
||||
assert kv_transfer_params.remote_host is not None
|
||||
assert kv_transfer_params.remote_port is not None
|
||||
|
||||
self.requests[request_id] = ReqMeta(
|
||||
local_block_ids=local_block_ids,
|
||||
remote_block_ids=kv_transfer_params.remote_block_ids,
|
||||
remote_engine_id=kv_transfer_params.remote_engine_id,
|
||||
remote_host=kv_transfer_params.remote_host,
|
||||
remote_port=kv_transfer_params.remote_port,
|
||||
remote_block_ids=kv_transfer_params["remote_block_ids"],
|
||||
remote_engine_id=kv_transfer_params["remote_engine_id"],
|
||||
remote_host=kv_transfer_params["remote_host"],
|
||||
remote_port=kv_transfer_params["remote_port"],
|
||||
)
|
||||
|
||||
|
||||
class NixlConnector(KVConnectorBase_V1):
|
||||
_KVTransferParams: type[NixlKVTransferParams] = NixlKVTransferParams
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole):
|
||||
assert vllm_config.kv_transfer_config is not None
|
||||
@ -253,52 +196,52 @@ class NixlConnectorScheduler:
|
||||
asynchronously (between scheduler steps).
|
||||
"""
|
||||
|
||||
params = request.kv_transfer_params
|
||||
logger.debug(
|
||||
"NIXLConnector get_num_new_matched_tokens: "
|
||||
"num_computed_tokens=%s, kv_transfer_params=%s",
|
||||
num_computed_tokens, request.kv_transfer_params)
|
||||
num_computed_tokens, params)
|
||||
|
||||
# No KVTransfer for this request.
|
||||
if request.kv_transfer_params is None:
|
||||
return 0, False
|
||||
assert isinstance(request.kv_transfer_params, NixlKVTransferParams)
|
||||
|
||||
# Remote prefill: get all prompt blocks from remote.
|
||||
if request.kv_transfer_params.do_remote_prefill:
|
||||
if params is not None and params.get("do_remote_prefill"):
|
||||
# Remote prefill: get all prompt blocks from remote.
|
||||
assert num_computed_tokens % self.block_size == 0
|
||||
rounded_num_prompt_tokens = round_down(
|
||||
len(request.prompt_token_ids), self.block_size)
|
||||
count = max(rounded_num_prompt_tokens - num_computed_tokens, 0)
|
||||
return count, count > 0
|
||||
|
||||
# No remote prefill for this request.
|
||||
return 0, False
|
||||
|
||||
def update_state_after_alloc(self, request: "Request",
|
||||
blocks: "KVCacheBlocks",
|
||||
num_external_tokens: int):
|
||||
|
||||
params = request.kv_transfer_params
|
||||
logger.debug(
|
||||
"NIXLConnector update_state_after_alloc: "
|
||||
"num_external_tokens=%s, kv_transfer_params=%s",
|
||||
num_external_tokens, request.kv_transfer_params)
|
||||
num_external_tokens, params)
|
||||
|
||||
if request.kv_transfer_params is None:
|
||||
return
|
||||
|
||||
assert isinstance(request.kv_transfer_params, NixlKVTransferParams)
|
||||
if request.kv_transfer_params.do_remote_prefill:
|
||||
if params is not None and params.get("do_remote_prefill"):
|
||||
# NOTE(rob): if prompt < block_size, no remote blocks
|
||||
# since the remote only sends fully computed blocks, so
|
||||
# skip recving for this request. num_external_tokens
|
||||
# should be 0 if there are no remote blocks.
|
||||
if request.kv_transfer_params.remote_block_ids:
|
||||
# Get unhashed blocks to pull from remote.
|
||||
self._reqs_need_recv[request.request_id] = (
|
||||
request, blocks.get_unhashed_block_ids())
|
||||
if params.get("remote_block_ids"):
|
||||
if all(p in params for p in ("remote_engine_id", "remote_host",
|
||||
"remote_port")):
|
||||
# Get unhashed blocks to pull from remote.
|
||||
self._reqs_need_recv[request.request_id] = (
|
||||
request, blocks.get_unhashed_block_ids())
|
||||
else:
|
||||
logger.warning(
|
||||
"Got invalid KVTransferParams: %s. This "
|
||||
"request will not utilize KVTransfer", params)
|
||||
else:
|
||||
assert num_external_tokens == 0
|
||||
# Only trigger 1 KV transfer per request.
|
||||
request.kv_transfer_params.do_remote_prefill = False
|
||||
params["do_remote_prefill"] = False
|
||||
|
||||
def build_connector_meta(
|
||||
self,
|
||||
@ -308,7 +251,7 @@ class NixlConnectorScheduler:
|
||||
|
||||
# Loop through scheduled reqs and convert to ReqMeta.
|
||||
for req_id, (req, block_ids) in self._reqs_need_recv.items():
|
||||
assert isinstance(req.kv_transfer_params, NixlKVTransferParams)
|
||||
assert req.kv_transfer_params is not None
|
||||
meta.add_new_req(
|
||||
request_id=req_id,
|
||||
local_block_ids=block_ids,
|
||||
@ -330,34 +273,30 @@ class NixlConnectorScheduler:
|
||||
should be freed now or will be sent asynchronously and freed later.
|
||||
"""
|
||||
|
||||
params = request.kv_transfer_params
|
||||
logger.debug(
|
||||
"NIXLConnector request_finished, "
|
||||
"request_status=%s, kv_transfer_params=%s", request.status,
|
||||
request.kv_transfer_params)
|
||||
"NIXLConnector request_finished, request_status=%s, "
|
||||
"kv_transfer_params=%s", request.status, params)
|
||||
|
||||
if request.kv_transfer_params is None:
|
||||
return False, None
|
||||
assert isinstance(request.kv_transfer_params, NixlKVTransferParams)
|
||||
|
||||
if ((not request.kv_transfer_params.do_remote_decode)
|
||||
or (request.status != RequestStatus.FINISHED_LENGTH_CAPPED)):
|
||||
if (params is None or not params.get("do_remote_decode")
|
||||
or request.status != RequestStatus.FINISHED_LENGTH_CAPPED):
|
||||
return False, None
|
||||
|
||||
# Get computed blocks.
|
||||
all_full = request.num_computed_tokens % self.block_size == 0
|
||||
computed_block_ids = (block_ids if all_full else block_ids[:-1])
|
||||
computed_block_ids = block_ids if all_full else block_ids[:-1]
|
||||
|
||||
# If prompt < block_size, no xfer so free blocks immediately.
|
||||
delay_free_blocks = len(computed_block_ids) > 0
|
||||
|
||||
return delay_free_blocks, NixlKVTransferParams(
|
||||
return delay_free_blocks, dict(
|
||||
do_remote_prefill=True,
|
||||
do_remote_decode=False,
|
||||
remote_block_ids=computed_block_ids,
|
||||
remote_engine_id=self.engine_id,
|
||||
remote_host=envs.VLLM_NIXL_SIDE_CHANNEL_HOST,
|
||||
remote_port=envs.VLLM_NIXL_SIDE_CHANNEL_PORT,
|
||||
).__dict__
|
||||
)
|
||||
|
||||
|
||||
class NixlConnectorWorker:
|
||||
|
||||
@ -12,8 +12,7 @@ from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch
|
||||
from vllm.distributed.kv_transfer.kv_connector.factory import (
|
||||
KVConnectorFactory)
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1,
|
||||
KVConnectorRole,
|
||||
KVTransferParams)
|
||||
KVConnectorRole)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
|
||||
from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager,
|
||||
@ -931,8 +930,13 @@ class Scheduler(SchedulerInterface):
|
||||
return self.connector
|
||||
|
||||
def _connector_finished(
|
||||
self, request: Request) -> tuple[bool, Optional[KVTransferParams]]:
|
||||
"""Invoke the KV connector request_finished() method if applicable."""
|
||||
self, request: Request) -> tuple[bool, Optional[dict[str, Any]]]:
|
||||
"""
|
||||
Invoke the KV connector request_finished() method if applicable.
|
||||
|
||||
Returns optional kv transfer parameters to be included with the
|
||||
request outputs.
|
||||
"""
|
||||
if self.connector is None:
|
||||
return False, None
|
||||
block_ids = self.kv_cache_manager.get_block_ids(request.request_id)
|
||||
|
||||
@ -182,14 +182,10 @@ class EngineCore:
|
||||
# Start grammar compilation asynchronously
|
||||
self.structured_output_manager.grammar_init(req)
|
||||
|
||||
if req.raw_kv_transfer_params is not None:
|
||||
if (kv_connector := self.scheduler.get_kv_connector()):
|
||||
# Parse raw KV transfer params via connector.
|
||||
kv_connector.set_kv_transfer_params(req)
|
||||
else:
|
||||
logger.warning(
|
||||
"Got KVTransferParams, but no KVConnector found. "
|
||||
"Disabling KVTransfer for this request.")
|
||||
if req.kv_transfer_params is not None and (
|
||||
not self.scheduler.get_kv_connector()):
|
||||
logger.warning("Got kv_transfer_params, but no KVConnector found. "
|
||||
"Disabling KVTransfer for this request.")
|
||||
|
||||
self.scheduler.add_request(req)
|
||||
|
||||
|
||||
@ -3,7 +3,6 @@
|
||||
import enum
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1 import KVTransferParams
|
||||
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.utils import is_list_of
|
||||
@ -62,14 +61,10 @@ class Request:
|
||||
self.num_encoder_inputs = len(self.mm_inputs)
|
||||
self.has_encoder_inputs = self.num_encoder_inputs > 0
|
||||
|
||||
# P/D: KV transfer parameters (raw and parsed).
|
||||
raw_params = (None if sampling_params.extra_args is None
|
||||
else sampling_params.extra_args.get(
|
||||
"kv_transfer_params", None))
|
||||
self.raw_kv_transfer_params: Optional[dict[str, Any]] = raw_params
|
||||
# Each connector parses the raw dictionary and sets this
|
||||
# attr the first time that the request is processed.
|
||||
self.kv_transfer_params: Optional[KVTransferParams] = None
|
||||
# P/D: Connector-specific KV transfer parameters.
|
||||
kv_params = (None if sampling_params.extra_args is None else
|
||||
sampling_params.extra_args.get("kv_transfer_params"))
|
||||
self.kv_transfer_params: Optional[dict[str, Any]] = kv_params
|
||||
|
||||
# Sanity check
|
||||
assert len(self.mm_inputs) == len(self.mm_positions)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user