diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index 65db16f48c2c9..5045ae0eef33f 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -470,6 +470,7 @@ class TestNixlHandshake: num_xfers + 6, ], "remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID, + "remote_request_id": f"prefill-{request_id}", "remote_host": "localhost", "remote_port": 1234, "remote_tp_size": 1, @@ -536,6 +537,7 @@ class TestNixlHandshake: kv_transfer_params={ "remote_block_ids": [4, 5, 6], "remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID, + "remote_request_id": "prefill-id", "remote_host": "localhost", "remote_port": 1234, "remote_tp_size": prefill_tp_size, @@ -591,6 +593,7 @@ class TestNixlHandshake: kv_transfer_params={ "remote_block_ids": [4, 5, 6], "remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID, + "remote_request_id": f"prefill-id-{i}", "remote_host": "localhost", "remote_port": 1234, "remote_tp_size": 1, @@ -754,6 +757,7 @@ def test_kv_connector_stats(dist_init): kv_transfer_params={ "remote_block_ids": [4, 5, 6], "remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID, + "remote_request_id": f"prefill-{request_id}", "remote_host": "localhost", "remote_port": 1234, "remote_tp_size": 1, @@ -1470,6 +1474,7 @@ def test_handshake_failure_returns_finished(dist_init): kv_transfer_params={ "remote_block_ids": [4, 5, 6], "remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID, + "remote_request_id": f"prefill-{request_id}", "remote_host": "localhost", "remote_port": 1234, "remote_tp_size": 1, @@ -1519,6 +1524,7 @@ def test_transfer_setup_failure_returns_finished(dist_init): kv_transfer_params={ "remote_block_ids": [10, 11, 12], "remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID, + "remote_request_id": f"prefill-{request_id}", "remote_host": "localhost", "remote_port": 1234, "remote_tp_size": 1, diff --git a/tests/v1/kv_connector/unit/utils.py b/tests/v1/kv_connector/unit/utils.py index cea41c3ab18a6..58f1a7282352b 100644 --- a/tests/v1/kv_connector/unit/utils.py +++ b/tests/v1/kv_connector/unit/utils.py @@ -194,6 +194,7 @@ def create_request( do_remote_prefill=True, do_remote_decode=False, remote_engine_id="my-engine-id", + remote_request_id=f"prefill-{request_id}", remote_block_ids=list(range(num_remote_blocks)), remote_host="my-host", remote_port=1234, diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 649e54adaba49..7aa12e9993c76 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -71,8 +71,9 @@ ReqId = str # # Version History: # 1: Initial version with compatibility checking +# 2: Add remote_request_id to kv_transfer_params # -NIXL_CONNECTOR_VERSION: int = 1 +NIXL_CONNECTOR_VERSION: int = 2 GET_META_MSG = b"get_meta_msg" @@ -210,6 +211,7 @@ class ReqMeta: remote_host: str remote_port: int remote_engine_id: str + remote_request_id: str tp_size: int @@ -236,6 +238,7 @@ class NixlConnectorMetadata(KVConnectorMetadata): local_physical_block_ids=local_block_ids, remote_block_ids=kv_transfer_params["remote_block_ids"], remote_engine_id=kv_transfer_params["remote_engine_id"], + remote_request_id=kv_transfer_params["remote_request_id"], remote_host=kv_transfer_params["remote_host"], remote_port=kv_transfer_params["remote_port"], # P workers don't need to receive tp_size from proxy here. @@ -622,7 +625,12 @@ class NixlConnectorScheduler: if params.get("remote_block_ids"): if all( p in params - for p in ("remote_engine_id", "remote_host", "remote_port") + for p in ( + "remote_engine_id", + "remote_request_id", + "remote_host", + "remote_port", + ) ): # If remote_blocks and num_external_tokens = 0, we have # a full prefix cache hit on the D worker. We need to call @@ -751,6 +759,7 @@ class NixlConnectorScheduler: do_remote_decode=False, remote_block_ids=block_ids, remote_engine_id=self.engine_id, + remote_request_id=request.request_id, remote_host=self.side_channel_host, remote_port=self.side_channel_port, tp_size=self.vllm_config.parallel_config.tensor_parallel_size, @@ -1964,6 +1973,7 @@ class NixlConnectorWorker: self._read_blocks( request_id=req_id, dst_engine_id=meta.remote_engine_id, + remote_request_id=meta.remote_request_id, local_block_ids=meta.local_physical_block_ids, remote_block_ids=meta.remote_block_ids, ) @@ -1974,6 +1984,7 @@ class NixlConnectorWorker: remote_block_ids: list[int], dst_engine_id: str, request_id: str, + remote_request_id: str, ): block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(dst_engine_id) if block_size_ratio > 1: @@ -2006,7 +2017,7 @@ class NixlConnectorWorker: # Number of D TP workers that will read from dst P. Propagate tp_ratio # on notification so that dst worker can wait before freeing blocks. tp_ratio = self.kv_topo.tp_ratio_from_engine_id(dst_engine_id) - notif_id = f"{request_id}:{tp_ratio}".encode() + notif_id = f"{remote_request_id}:{tp_ratio}".encode() # Full prefix cache hit: do not need to read remote blocks, # just notify P worker that we have the blocks we need.