[NIXL] Add remote_request_id to kv_transfer_params (#29665)

Signed-off-by: Mark McLoughlin <markmc@redhat.com>
This commit is contained in:
Mark McLoughlin 2025-12-05 17:43:48 +00:00 committed by GitHub
parent dc264bcea1
commit dff0a2b394
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 21 additions and 3 deletions

View File

@ -470,6 +470,7 @@ class TestNixlHandshake:
num_xfers + 6,
],
"remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
"remote_request_id": f"prefill-{request_id}",
"remote_host": "localhost",
"remote_port": 1234,
"remote_tp_size": 1,
@ -536,6 +537,7 @@ class TestNixlHandshake:
kv_transfer_params={
"remote_block_ids": [4, 5, 6],
"remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
"remote_request_id": "prefill-id",
"remote_host": "localhost",
"remote_port": 1234,
"remote_tp_size": prefill_tp_size,
@ -591,6 +593,7 @@ class TestNixlHandshake:
kv_transfer_params={
"remote_block_ids": [4, 5, 6],
"remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
"remote_request_id": f"prefill-id-{i}",
"remote_host": "localhost",
"remote_port": 1234,
"remote_tp_size": 1,
@ -754,6 +757,7 @@ def test_kv_connector_stats(dist_init):
kv_transfer_params={
"remote_block_ids": [4, 5, 6],
"remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
"remote_request_id": f"prefill-{request_id}",
"remote_host": "localhost",
"remote_port": 1234,
"remote_tp_size": 1,
@ -1470,6 +1474,7 @@ def test_handshake_failure_returns_finished(dist_init):
kv_transfer_params={
"remote_block_ids": [4, 5, 6],
"remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
"remote_request_id": f"prefill-{request_id}",
"remote_host": "localhost",
"remote_port": 1234,
"remote_tp_size": 1,
@ -1519,6 +1524,7 @@ def test_transfer_setup_failure_returns_finished(dist_init):
kv_transfer_params={
"remote_block_ids": [10, 11, 12],
"remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
"remote_request_id": f"prefill-{request_id}",
"remote_host": "localhost",
"remote_port": 1234,
"remote_tp_size": 1,

View File

@ -194,6 +194,7 @@ def create_request(
do_remote_prefill=True,
do_remote_decode=False,
remote_engine_id="my-engine-id",
remote_request_id=f"prefill-{request_id}",
remote_block_ids=list(range(num_remote_blocks)),
remote_host="my-host",
remote_port=1234,

View File

@ -71,8 +71,9 @@ ReqId = str
#
# Version History:
# 1: Initial version with compatibility checking
# 2: Add remote_request_id to kv_transfer_params
#
-NIXL_CONNECTOR_VERSION: int = 1
+NIXL_CONNECTOR_VERSION: int = 2
GET_META_MSG = b"get_meta_msg"
@ -210,6 +211,7 @@ class ReqMeta:
remote_host: str
remote_port: int
remote_engine_id: str
remote_request_id: str
tp_size: int
@ -236,6 +238,7 @@ class NixlConnectorMetadata(KVConnectorMetadata):
local_physical_block_ids=local_block_ids,
remote_block_ids=kv_transfer_params["remote_block_ids"],
remote_engine_id=kv_transfer_params["remote_engine_id"],
remote_request_id=kv_transfer_params["remote_request_id"],
remote_host=kv_transfer_params["remote_host"],
remote_port=kv_transfer_params["remote_port"],
# P workers don't need to receive tp_size from proxy here.
@ -622,7 +625,12 @@ class NixlConnectorScheduler:
if params.get("remote_block_ids"):
if all(
p in params
-for p in ("remote_engine_id", "remote_host", "remote_port")
+for p in (
+"remote_engine_id",
+"remote_request_id",
+"remote_host",
+"remote_port",
+)
):
# If remote_blocks and num_external_tokens = 0, we have
# a full prefix cache hit on the D worker. We need to call
@ -751,6 +759,7 @@ class NixlConnectorScheduler:
do_remote_decode=False,
remote_block_ids=block_ids,
remote_engine_id=self.engine_id,
remote_request_id=request.request_id,
remote_host=self.side_channel_host,
remote_port=self.side_channel_port,
tp_size=self.vllm_config.parallel_config.tensor_parallel_size,
@ -1964,6 +1973,7 @@ class NixlConnectorWorker:
self._read_blocks(
request_id=req_id,
dst_engine_id=meta.remote_engine_id,
remote_request_id=meta.remote_request_id,
local_block_ids=meta.local_physical_block_ids,
remote_block_ids=meta.remote_block_ids,
)
@ -1974,6 +1984,7 @@ class NixlConnectorWorker:
remote_block_ids: list[int],
dst_engine_id: str,
request_id: str,
remote_request_id: str,
):
block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(dst_engine_id)
if block_size_ratio > 1:
@ -2006,7 +2017,7 @@ class NixlConnectorWorker:
# Number of D TP workers that will read from dst P. Propagate tp_ratio
# on notification so that dst worker can wait before freeing blocks.
tp_ratio = self.kv_topo.tp_ratio_from_engine_id(dst_engine_id)
-notif_id = f"{request_id}:{tp_ratio}".encode()
+notif_id = f"{remote_request_id}:{tp_ratio}".encode()
# Full prefix cache hit: do not need to read remote blocks,
# just notify P worker that we have the blocks we need.