diff --git a/tests/v1/kv_connector/unit/utils.py b/tests/v1/kv_connector/unit/utils.py index 8a7d7bdd83da6..53e2d6fda1aea 100644 --- a/tests/v1/kv_connector/unit/utils.py +++ b/tests/v1/kv_connector/unit/utils.py @@ -1,13 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Optional +from typing import Any, Optional import torch from vllm import SamplingParams from vllm.config import (CacheConfig, DeviceConfig, KVTransferConfig, ModelConfig, SchedulerConfig, VllmConfig) -from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import ( - NixlKVTransferParams) from vllm.v1.core.sched.scheduler import Scheduler from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, KVCacheGroupSpec) @@ -124,20 +122,20 @@ def create_request( ) -> Request: """Make dummy request for testing.""" + kv_transfer_params: Optional[dict[str, Any]] = None + if do_remote_decode: assert not do_remote_prefill - kv_transfer_params = NixlKVTransferParams(do_remote_prefill=False, - do_remote_decode=True) + kv_transfer_params = dict(do_remote_prefill=False, + do_remote_decode=True) elif do_remote_prefill: - kv_transfer_params = NixlKVTransferParams( - do_remote_prefill=True, - do_remote_decode=False, - remote_engine_id="my-engine-id", - remote_block_ids=list(range(num_remote_blocks)), - remote_host="my-host", - remote_port=1234) - else: - kv_transfer_params = None + kv_transfer_params = dict(do_remote_prefill=True, + do_remote_decode=False, + remote_engine_id="my-engine-id", + remote_block_ids=list( + range(num_remote_blocks)), + remote_host="my-host", + remote_port=1234) max_tokens = 1 if do_remote_decode else max_tokens sampling_params = SamplingParams(max_tokens=max_tokens) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/__init__.py b/vllm/distributed/kv_transfer/kv_connector/v1/__init__.py index 43181ab79afc9..e66aaa7f8af8e 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/__init__.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/__init__.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 from vllm.distributed.kv_transfer.kv_connector.v1.base import ( - KVConnectorBase_V1, KVConnectorRole, KVTransferParams) + KVConnectorBase_V1, KVConnectorRole) -__all__ = ["KVConnectorRole", "KVConnectorBase_V1", "KVTransferParams"] +__all__ = ["KVConnectorRole", "KVConnectorBase_V1"] diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py index 2ff61e8a400f0..03c99f20e775b 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -48,23 +48,6 @@ class KVConnectorRole(enum.Enum): WORKER = 1 -class KVTransferParams: - """ - Abstract KVTransferParams used to send KVTransfer - parameters between instances of vLLM. - - Specific instances of KVConnector customize this - method for serializing / deserializing msgs sent - via the HTTP protocol. - """ - - @staticmethod - def from_raw_dict( - raw_dict: Optional[dict[str, - Any]]) -> Optional["KVTransferParams"]: - return None - - @dataclass class KVConnectorMetadata: """ @@ -75,7 +58,6 @@ class KVConnectorMetadata: class KVConnectorBase_V1(ABC): - _KVTransferParams = KVTransferParams def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): logger.warning( @@ -213,13 +195,6 @@ class KVConnectorBase_V1(ABC): # Scheduler-side methods # ============================== - def set_kv_transfer_params(self, request: "Request"): - """Parse raw KV Transfer params.""" - assert request.kv_transfer_params is None - kv_transfer_params = self._KVTransferParams.from_raw_dict( - request.raw_kv_transfer_params) - request.kv_transfer_params = kv_transfer_params - @abstractmethod def get_num_new_matched_tokens( self, diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 6e6add0825c22..abd1ea2bea82b 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -16,7 +16,7 @@ import zmq from vllm import envs from vllm.config import VllmConfig from vllm.distributed.kv_transfer.kv_connector.v1.base import ( - KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole, KVTransferParams) + KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) from vllm.distributed.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, get_tp_group) @@ -44,56 +44,6 @@ except ImportError: NixlWrapper = None -@dataclass -class NixlKVTransferParams(KVTransferParams): - - def __init__( - self, - do_remote_prefill: bool, - do_remote_decode: bool, - remote_block_ids: Optional[list[int]] = None, - remote_host: Optional[str] = None, - remote_port: Optional[int] = None, - remote_engine_id: Optional[str] = None, - ): - self.do_remote_prefill = do_remote_prefill - self.do_remote_decode = do_remote_decode - self.remote_block_ids = remote_block_ids - self.remote_host = remote_host - self.remote_port = remote_port - self.remote_engine_id = remote_engine_id - - @staticmethod - def from_raw_dict( - raw_dict: Optional[dict[str, - Any]]) -> Optional["NixlKVTransferParams"]: - - # If no raw transfer params passed, return None. - if raw_dict is None: - return None - - # Validate the request is formatted properly. - if (("do_remote_prefill" not in raw_dict) - or ("do_remote_decode" not in raw_dict) - or ("remote_block_ids" not in raw_dict) - or ("remote_host" not in raw_dict) - or ("remote_port" not in raw_dict) - or ("remote_engine_id" not in raw_dict)): - logger.warning( - "Got invalid KVTransferParams: %s. This " - "request will not utilize KVTransfer", raw_dict) - return None - - return NixlKVTransferParams( - do_remote_prefill=raw_dict["do_remote_prefill"], - do_remote_decode=raw_dict["do_remote_decode"], - remote_block_ids=raw_dict["remote_block_ids"], - remote_host=raw_dict["remote_host"], - remote_port=raw_dict["remote_port"], - remote_engine_id=raw_dict["remote_engine_id"], - ) - - class NixlAgentMetadata( msgspec.Struct, omit_defaults=True, # type: ignore[call-arg] @@ -123,25 +73,18 @@ class NixlConnectorMetadata(KVConnectorMetadata): self, request_id: str, local_block_ids: list[int], - kv_transfer_params: NixlKVTransferParams, + kv_transfer_params: dict[str, Any], ): - assert request_id not in self.requests - assert kv_transfer_params.remote_block_ids is not None - assert kv_transfer_params.remote_engine_id is not None - assert kv_transfer_params.remote_host is not None - assert kv_transfer_params.remote_port is not None - self.requests[request_id] = ReqMeta( local_block_ids=local_block_ids, - remote_block_ids=kv_transfer_params.remote_block_ids, - remote_engine_id=kv_transfer_params.remote_engine_id, - remote_host=kv_transfer_params.remote_host, - remote_port=kv_transfer_params.remote_port, + remote_block_ids=kv_transfer_params["remote_block_ids"], + remote_engine_id=kv_transfer_params["remote_engine_id"], + remote_host=kv_transfer_params["remote_host"], + remote_port=kv_transfer_params["remote_port"], ) class NixlConnector(KVConnectorBase_V1): - _KVTransferParams: type[NixlKVTransferParams] = NixlKVTransferParams def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole): assert vllm_config.kv_transfer_config is not None @@ -253,52 +196,52 @@ class NixlConnectorScheduler: asynchronously (between scheduler steps). """ + params = request.kv_transfer_params logger.debug( "NIXLConnector get_num_new_matched_tokens: " "num_computed_tokens=%s, kv_transfer_params=%s", - num_computed_tokens, request.kv_transfer_params) + num_computed_tokens, params) - # No KVTransfer for this request. - if request.kv_transfer_params is None: - return 0, False - assert isinstance(request.kv_transfer_params, NixlKVTransferParams) - - # Remote prefill: get all prompt blocks from remote. - if request.kv_transfer_params.do_remote_prefill: + if params is not None and params.get("do_remote_prefill"): + # Remote prefill: get all prompt blocks from remote. assert num_computed_tokens % self.block_size == 0 rounded_num_prompt_tokens = round_down( len(request.prompt_token_ids), self.block_size) count = max(rounded_num_prompt_tokens - num_computed_tokens, 0) return count, count > 0 + # No remote prefill for this request. return 0, False def update_state_after_alloc(self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int): + params = request.kv_transfer_params logger.debug( "NIXLConnector update_state_after_alloc: " "num_external_tokens=%s, kv_transfer_params=%s", - num_external_tokens, request.kv_transfer_params) + num_external_tokens, params) - if request.kv_transfer_params is None: - return - - assert isinstance(request.kv_transfer_params, NixlKVTransferParams) - if request.kv_transfer_params.do_remote_prefill: + if params is not None and params.get("do_remote_prefill"): # NOTE(rob): if prompt < block_size, no remote blocks # since the remote only sends fully computed blocks, so # skip recving for this request. num_external_tokens # should be 0 if there are no remote blocks. - if request.kv_transfer_params.remote_block_ids: - # Get unhashed blocks to pull from remote. - self._reqs_need_recv[request.request_id] = ( - request, blocks.get_unhashed_block_ids()) + if params.get("remote_block_ids"): + if all(p in params for p in ("remote_engine_id", "remote_host", + "remote_port")): + # Get unhashed blocks to pull from remote. + self._reqs_need_recv[request.request_id] = ( + request, blocks.get_unhashed_block_ids()) + else: + logger.warning( + "Got invalid KVTransferParams: %s. This " + "request will not utilize KVTransfer", params) else: assert num_external_tokens == 0 # Only trigger 1 KV transfer per request. - request.kv_transfer_params.do_remote_prefill = False + params["do_remote_prefill"] = False def build_connector_meta( self, @@ -308,7 +251,7 @@ class NixlConnectorScheduler: # Loop through scheduled reqs and convert to ReqMeta. for req_id, (req, block_ids) in self._reqs_need_recv.items(): - assert isinstance(req.kv_transfer_params, NixlKVTransferParams) + assert req.kv_transfer_params is not None meta.add_new_req( request_id=req_id, local_block_ids=block_ids, @@ -330,34 +273,30 @@ class NixlConnectorScheduler: should be freed now or will be sent asynchronously and freed later. """ + params = request.kv_transfer_params logger.debug( - "NIXLConnector request_finished, " - "request_status=%s, kv_transfer_params=%s", request.status, - request.kv_transfer_params) + "NIXLConnector request_finished, request_status=%s, " + "kv_transfer_params=%s", request.status, params) - if request.kv_transfer_params is None: - return False, None - assert isinstance(request.kv_transfer_params, NixlKVTransferParams) - - if ((not request.kv_transfer_params.do_remote_decode) - or (request.status != RequestStatus.FINISHED_LENGTH_CAPPED)): + if (params is None or not params.get("do_remote_decode") + or request.status != RequestStatus.FINISHED_LENGTH_CAPPED): return False, None # Get computed blocks. all_full = request.num_computed_tokens % self.block_size == 0 - computed_block_ids = (block_ids if all_full else block_ids[:-1]) + computed_block_ids = block_ids if all_full else block_ids[:-1] # If prompt < block_size, no xfer so free blocks immediately. delay_free_blocks = len(computed_block_ids) > 0 - return delay_free_blocks, NixlKVTransferParams( + return delay_free_blocks, dict( do_remote_prefill=True, do_remote_decode=False, remote_block_ids=computed_block_ids, remote_engine_id=self.engine_id, remote_host=envs.VLLM_NIXL_SIDE_CHANNEL_HOST, remote_port=envs.VLLM_NIXL_SIDE_CHANNEL_PORT, - ).__dict__ + ) class NixlConnectorWorker: diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 9f051b73c263d..f338e4ba14400 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -12,8 +12,7 @@ from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch from vllm.distributed.kv_transfer.kv_connector.factory import ( KVConnectorFactory) from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1, - KVConnectorRole, - KVTransferParams) + KVConnectorRole) from vllm.logger import init_logger from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager, @@ -931,8 +930,13 @@ class Scheduler(SchedulerInterface): return self.connector def _connector_finished( - self, request: Request) -> tuple[bool, Optional[KVTransferParams]]: - """Invoke the KV connector request_finished() method if applicable.""" + self, request: Request) -> tuple[bool, Optional[dict[str, Any]]]: + """ + Invoke the KV connector request_finished() method if applicable. + + Returns optional kv transfer parameters to be included with the + request outputs. + """ if self.connector is None: return False, None block_ids = self.kv_cache_manager.get_block_ids(request.request_id) diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index edc79ae20b9f2..0cf2383af1c9b 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -182,14 +182,10 @@ class EngineCore: # Start grammar compilation asynchronously self.structured_output_manager.grammar_init(req) - if req.raw_kv_transfer_params is not None: - if (kv_connector := self.scheduler.get_kv_connector()): - # Parse raw KV transfer params via connector. - kv_connector.set_kv_transfer_params(req) - else: - logger.warning( - "Got KVTransferParams, but no KVConnector found. " - "Disabling KVTransfer for this request.") + if req.kv_transfer_params is not None and ( + not self.scheduler.get_kv_connector()): + logger.warning("Got kv_transfer_params, but no KVConnector found. " + "Disabling KVTransfer for this request.") self.scheduler.add_request(req) diff --git a/vllm/v1/request.py b/vllm/v1/request.py index fc6b738546f40..d2843b65ab59c 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -3,7 +3,6 @@ import enum from typing import TYPE_CHECKING, Any, Optional, Union -from vllm.distributed.kv_transfer.kv_connector.v1 import KVTransferParams from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange from vllm.sampling_params import SamplingParams from vllm.utils import is_list_of @@ -62,14 +61,10 @@ class Request: self.num_encoder_inputs = len(self.mm_inputs) self.has_encoder_inputs = self.num_encoder_inputs > 0 - # P/D: KV transfer parameters (raw and parsed). - raw_params = (None if sampling_params.extra_args is None - else sampling_params.extra_args.get( - "kv_transfer_params", None)) - self.raw_kv_transfer_params: Optional[dict[str, Any]] = raw_params - # Each connector parses the raw dictionary and sets this - # attr the first time that the request is processed. - self.kv_transfer_params: Optional[KVTransferParams] = None + # P/D: Connector-specific KV transfer parameters. + kv_params = (None if sampling_params.extra_args is None else + sampling_params.extra_args.get("kv_transfer_params")) + self.kv_transfer_params: Optional[dict[str, Any]] = kv_params # Sanity check assert len(self.mm_inputs) == len(self.mm_positions)