From d6517be3cd06111ada0a603acaeab28dd4580641 Mon Sep 17 00:00:00 2001 From: GuanLuo <41310872+GuanLuo@users.noreply.github.com> Date: Sat, 1 Nov 2025 01:16:00 +0800 Subject: [PATCH] [Bugfix] Missing NIXL metadata for handshake initialization if instance spans multi-node (#26338) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Guan Luo Signed-off-by: GuanLuo <41310872+GuanLuo@users.noreply.github.com> Signed-off-by: Guan Luo <41310872+GuanLuo@users.noreply.github.com> Co-authored-by: Nicolò Lucchesi --- docs/features/nixl_connector_usage.md | 2 +- .../kv_connector/unit/test_nixl_connector.py | 106 ++++++++- .../kv_transfer/kv_connector/v1/base.py | 32 +++ .../kv_connector/v1/nixl_connector.py | 224 +++++++++++------- vllm/v1/engine/core.py | 23 +- vllm/v1/executor/abstract.py | 8 + vllm/v1/worker/gpu_worker.py | 21 +- 7 files changed, 321 insertions(+), 95 deletions(-) diff --git a/docs/features/nixl_connector_usage.md b/docs/features/nixl_connector_usage.md index 605398652ee0b..1ce038f4d6525 100644 --- a/docs/features/nixl_connector_usage.md +++ b/docs/features/nixl_connector_usage.md @@ -81,7 +81,7 @@ python tests/v1/kv_connector/nixl_integration/toy_proxy_server.py \ - Default: 5600 - **Required for both prefiller and decoder instances** - Each vLLM worker needs a unique port on its host; using the same port number across different hosts is fine - - For TP/DP deployments, each worker's port on a node is computed as: base_port + dp_rank * tp_size + tp_rank (e.g., with `--tensor-parallel-size=4` and base_port=5600, tp_rank 0..3 use ports 5600, 5601, 5602, 5603 on that node). + - For TP/DP deployments, each worker's port on a node is computed as: base_port + dp_rank (e.g., with `--data-parallel-size=2` and base_port=5600, dp_rank 0..1 use port 5600, 5601 on that node). - Used for the initial NIXL handshake between the prefiller and the decoder - `VLLM_NIXL_SIDE_CHANNEL_HOST`: Host for side channel communication diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index 445d115010cdf..44d8b3e331fdb 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -27,6 +27,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import ( NixlAgentMetadata, NixlConnector, NixlConnectorMetadata, + NixlConnectorScheduler, NixlConnectorWorker, NixlKVConnectorStats, ) @@ -283,6 +284,92 @@ def test_prompt_less_than_block_size(): assert len(scheduler_output.scheduled_new_reqs) == 0 +@patch( + "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", + FakeNixlWrapper, +) +def test_kv_transfer_handshake(dist_init): + """Unit test for basic NixlConnector interface functionality.""" + + # Test setup, we creates a scheduler that contains a NixlConnector + # of role SCHEDULER, and expect it to be serving NixlAgentMetadata from + # all workers of the instance. + vllm_config = create_vllm_config() + # in case the test runs on non-GPU machine + vllm_config.kv_transfer_config.kv_buffer_device = "cpu" + scheduler = create_scheduler(vllm_config) + + # Create two NixlConnector of role WORKER, one is the worker of + # the scheduler (prefill), the other is a worker of decode instance. + + # Prefill connector will register KV cache to populate proper handshake + # metadata. 
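+    # register_kv_caches() builds the worker's NixlAgentMetadata (agent blob,
+    # KV base addresses, device id, block lengths) and stores it as
+    # xfer_handshake_metadata, which get_handshake_metadata() returns below.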
+ prefill_connector = NixlConnector(vllm_config, KVConnectorRole.WORKER) + kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape( + num_blocks=2, block_size=16, num_kv_heads=4, head_size=64 + ) + shared_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16) + unique_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16) + kv_caches = { + "layer0": shared_tensor, + "layer1": unique_tensor, + "layer2": shared_tensor, + } + prefill_connector.register_kv_caches(kv_caches) + + # Simulate EngineCore initialization that would + # gather connector metadata from all workers, the scheduler connector + # expects metadata to be in dict[int, KVConnectorHandshakeMetadata], + # where the first key is the dp_rank, the second key is the tp_rank. + metadata = {0: prefill_connector.get_handshake_metadata()} + scheduler_connector = scheduler.get_kv_connector() + scheduler_connector.set_xfer_handshake_metadata(metadata) + + # Simulate a request that finishes prefill, which returns + # corresponding NixlConnectorMetadata for decode instance. + BLOCK_SIZE = vllm_config.cache_config.block_size + NUM_EXTERNAL_FULL_BLOCKS = 2 + NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5)) + + request = create_request( + request_id=1, + block_size=BLOCK_SIZE, + num_tokens=NUM_TOKENS, + do_remote_decode=True, + ) + request.status = RequestStatus.FINISHED_LENGTH_CAPPED + delay, kv_connector_metadata = scheduler.get_kv_connector().request_finished( + request, [0, 1, 2] + ) + assert delay + + # Decode connector will be able to create handshake with the prefill connector. + decode_connector = NixlConnector(vllm_config, KVConnectorRole.WORKER) + + # Here we are testing the retrieval of NIXLAgentMetadata. + # Knowing the implementation detail, we override the add_remote_agent + # to validate the metadata received is the same as the one in prefill_connector. 
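+    # _nixl_handshake() queries the scheduler's side-channel listener over ZMQ
+    # and forwards the decoded NixlAgentMetadata to add_remote_agent(), so the
+    # mock's call args expose exactly what the decode side received.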
+ with patch.object( + decode_connector.connector_worker, "add_remote_agent" + ) as mock_add_remote_agent: + mock_add_remote_agent.return_type = "remote_agent" + + decode_connector.connector_worker._nixl_handshake( + kv_connector_metadata["remote_host"], + kv_connector_metadata["remote_port"], + kv_connector_metadata["tp_size"], + kv_connector_metadata["remote_engine_id"], + ) + + received_metadata = mock_add_remote_agent.call_args.args + assert received_metadata[1] == 0 # remote_tp_rank + assert received_metadata[2] == 1 # remote_tp_size + assert metadata[0] == received_metadata[0] + + # Need to shutdown the background thread to release NIXL side channel port + scheduler_connector.shutdown() + + class FakeNixlConnectorWorker(NixlConnectorWorker): REMOTE_ENGINE_ID = "remote_engine" @@ -313,6 +400,7 @@ class FakeNixlConnectorWorker(NixlConnectorWorker): engine_id=self.REMOTE_ENGINE_ID, agent_metadata=FakeNixlWrapper.AGENT_METADATA, kv_caches_base_addr=[0], + device_id=0, num_blocks=1, block_lens=self.block_len_per_layer, attn_backend_name=self.backend_name, @@ -559,6 +647,7 @@ class TestNixlHandshake: engine_id=FakeNixlConnectorWorker.REMOTE_ENGINE_ID, agent_metadata=FakeNixlWrapper.AGENT_METADATA, kv_caches_base_addr=[0], + device_id=0, num_blocks=1, block_lens=worker.block_len_per_layer, attn_backend_name=worker.backend_name, @@ -611,6 +700,7 @@ class TestNixlHandshake: engine_id=FakeNixlConnectorWorker.REMOTE_ENGINE_ID, agent_metadata=FakeNixlWrapper.AGENT_METADATA, kv_caches_base_addr=[0], + device_id=0, num_blocks=1, # prefill TP=1, decode TP=2, remote block_lens is double to local block_lens=[i * 2 for i in worker.block_len_per_layer], @@ -1005,6 +1095,8 @@ def _run_abort_timeout_test(llm: LLM, timeout: int): _ = llm.generate([f"What is the capital of France? {padding}"], sampling_params) # Request-0 times out and is cleared! 
assert "0" not in req_to_blocks + # Need to shutdown the background thread to release NIXL side channel port + llm.llm_engine.engine_core.shutdown() def test_register_kv_caches(dist_init): @@ -1177,13 +1269,15 @@ def test_shutdown_cleans_up_resources(dist_init): """Test that shutdown() properly cleans up all resources.""" vllm_config = create_vllm_config() + scheduler = NixlConnectorScheduler( + vllm_config, vllm_config.kv_transfer_config.engine_id + ) worker = NixlConnectorWorker(vllm_config, vllm_config.kv_transfer_config.engine_id) nixl_wrapper = worker.nixl_wrapper with ( patch.object(worker, "_handshake_initiation_executor") as mock_exec, - patch.object(worker, "_nixl_handshake_listener_t") as mock_listener, - patch.object(worker, "_nixl_handshake_listener_stop_event") as mock_event, + patch.object(scheduler, "_nixl_handshake_listener_t") as mock_listener, patch.object(nixl_wrapper, "release_xfer_handle") as mock_rel_xfer, patch.object(nixl_wrapper, "release_dlist_handle") as mock_rel_dlist, patch.object(nixl_wrapper, "remove_remote_agent") as mock_rem_agent, @@ -1204,8 +1298,12 @@ def test_shutdown_cleans_up_resources(dist_init): worker.shutdown() mock_exec.shutdown.assert_called_with(wait=False) - mock_event.set.assert_called_once() - mock_listener.join.assert_called_once_with(timeout=1.0) + + # Same sequence on scheduler.shutdown() + scheduler.shutdown() + scheduler.shutdown() + scheduler.shutdown() + mock_listener.join.assert_called_once() mock_rel_xfer.assert_called_once_with(123) assert mock_rel_dlist.call_count == 2 diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py index 2ed0fe592e373..cb9f208a839f2 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -122,6 +122,15 @@ class KVConnectorRole(enum.Enum): WORKER = 1 +class KVConnectorHandshakeMetadata(ABC): # noqa: B024 + """ + Metadata used for out of band connector handshake between + P/D workers. This needs to serializeable. + """ + + pass + + class KVConnectorMetadata(ABC): # noqa: B024 """ Abstract Metadata used to communicate between the @@ -320,6 +329,18 @@ class KVConnectorBase_V1(ABC): """ return None + def get_handshake_metadata(self) -> KVConnectorHandshakeMetadata | None: + """ + Get the KVConnector handshake metadata for this connector. + This metadata is used for out-of-band connector handshake + between P/D workers. + + Returns: + KVConnectorHandshakeMetadata: the handshake metadata. + None if no handshake metadata is available. + """ + return None + # ============================== # Scheduler-side methods # ============================== @@ -477,6 +498,17 @@ class KVConnectorBase_V1(ABC): """ return None + def set_xfer_handshake_metadata( + self, metadata: dict[int, KVConnectorHandshakeMetadata] + ) -> None: + """ + Set the KV connector handshake metadata for this connector. + + Args: + metadata (KVConnectorHandshakeMetadata): the handshake metadata to set. 
+ """ + return None + @classmethod def build_prom_metrics( cls, diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index d5712bdd9feb4..4651cedbc7dfa 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -27,6 +27,7 @@ from vllm.config import VllmConfig from vllm.distributed.kv_transfer.kv_connector.v1.base import ( CopyBlocksOp, KVConnectorBase_V1, + KVConnectorHandshakeMetadata, KVConnectorMetadata, KVConnectorRole, ) @@ -93,15 +94,12 @@ _NIXL_SUPPORTED_DEVICE = { _NIXL_SUPPORTED_DEVICE.update(current_platform.get_nixl_supported_devices()) -class NixlAgentMetadata( - msgspec.Struct, - omit_defaults=True, # type: ignore[call-arg] - # required for @cached_property. - dict=True, -): +@dataclass +class NixlAgentMetadata(KVConnectorHandshakeMetadata): engine_id: str agent_metadata: bytes kv_caches_base_addr: list[int] + device_id: int num_blocks: int block_lens: list[int] attn_backend_name: str @@ -223,6 +221,18 @@ class NixlConnector(KVConnectorBase_V1): assert self.connector_scheduler is not None return self.connector_scheduler.request_finished(request, block_ids) + def set_xfer_handshake_metadata( + self, metadata: dict[int, KVConnectorHandshakeMetadata] + ) -> None: + """ + Set the KV connector handshake metadata for this connector. + + Args: + metadata (dict): the handshake metadata to set. + """ + assert self.connector_scheduler is not None + self.connector_scheduler.set_xfer_handshake_metadata(metadata) + ############################################################ # Worker Side Methods ############################################################ @@ -299,6 +309,21 @@ class NixlConnector(KVConnectorBase_V1): def shutdown(self): if self.connector_worker is not None: self.connector_worker.shutdown() + if self.connector_scheduler is not None: + self.connector_scheduler.shutdown() + + def get_handshake_metadata(self) -> KVConnectorHandshakeMetadata | None: + """ + Get the KVConnector handshake metadata for this connector. + This metadata is used for out-of-band connector handshake + between P/D workers. + + Returns: + KVConnectorHandshakeMetadata: the handshake metadata. + None if no handshake metadata is available. + """ + assert self.connector_worker is not None + return self.connector_worker.xfer_handshake_metadata class NixlConnectorScheduler: @@ -312,12 +337,16 @@ class NixlConnectorScheduler: self.side_channel_port = ( envs.VLLM_NIXL_SIDE_CHANNEL_PORT + vllm_config.parallel_config.data_parallel_rank - * vllm_config.parallel_config.tensor_parallel_size ) assert vllm_config.kv_transfer_config is not None self.use_host_buffer = vllm_config.kv_transfer_config.kv_buffer_device == "cpu" logger.info("Initializing NIXL Scheduler %s", engine_id) + # Background thread for handling new handshake requests. + self._nixl_handshake_listener_t: threading.Thread | None = None + self._encoded_xfer_handshake_metadata: dict[int, Any] = {} + self._stop_event = threading.Event() + # Requests that need to start recv/send. # New requests are added by update_state_after_alloc in # the scheduler. Used to make metadata passed to Worker. @@ -330,6 +359,89 @@ class NixlConnectorScheduler: # remote prefill or aborted. 
self._reqs_not_processed: set[ReqId] = set() + def shutdown(self): + self._stop_event.set() + if self._nixl_handshake_listener_t is not None: + self._nixl_handshake_listener_t.join() + self._nixl_handshake_listener_t = None + + def set_xfer_handshake_metadata( + self, metadata: dict[int, KVConnectorHandshakeMetadata] + ) -> None: + """ + Set the KV connector handshake metadata for this connector. + + Args: + metadata (dict): the handshake metadata to set. + """ + encoded_data: dict[int, bytes] = {} + encoder = msgspec.msgpack.Encoder() + for tp_rank, rank_metadata in metadata.items(): + if not isinstance(rank_metadata, NixlAgentMetadata): + raise ValueError( + "NixlConnectorScheduler expects NixlAgentMetadata for " + "handshake metadata." + ) + encoded_data[tp_rank] = encoder.encode(rank_metadata) + logger.debug( + "Tp rank %d: encoded NixlAgentMetadata size: %s bytes", + tp_rank, + str(len(encoded_data[tp_rank])), + ) + self._encoded_xfer_handshake_metadata = encoded_data + + # Only start the listener when we have metadata to serve. + if self._nixl_handshake_listener_t is None: + ready_event = threading.Event() + self._nixl_handshake_listener_t = threading.Thread( + target=self._nixl_handshake_listener, + args=( + encoded_data, + ready_event, + self._stop_event, + self.side_channel_port, + ), + daemon=True, + name="nixl_handshake_listener", + ) + self._nixl_handshake_listener_t.start() + ready_event.wait() # Wait for listener ZMQ socket to be ready. + + @staticmethod + def _nixl_handshake_listener( + encoded_data: dict[int, Any], + ready_event: threading.Event, + stop_event: threading.Event, + port: int, + ): + """Background thread for getting new NIXL handshakes.""" + # NOTE(rob): this is a simple implementation. We will move + # to a better approach via HTTP endpoint soon. + + # Listen for new requests for metadata. + host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST + path = make_zmq_path("tcp", host, port) + logger.debug("Starting listening on path: %s", path) + with zmq_ctx(zmq.ROUTER, path) as sock: + sock.setsockopt(zmq.RCVTIMEO, 1000) + ready_event.set() + while True: + try: + identity, _, msg = sock.recv_multipart() + except zmq.Again: + if stop_event.is_set(): + break + continue + # Decode the message which contains (GET_META_MSG, rank) + msg, target_tp_rank = msgspec.msgpack.decode(msg) + logger.debug( + "Received message for tp rank %s", + target_tp_rank, + ) + if msg != GET_META_MSG: + logger.warning("Connection listener got unexpected message %s", msg) + sock.send_multipart((identity, b"", encoded_data[target_tp_rank])) + def get_num_new_matched_tokens( self, request: "Request", num_computed_tokens: int ) -> tuple[int, bool]: @@ -537,8 +649,6 @@ class NixlConnectorScheduler: class NixlConnectorWorker: """Implementation of Worker side methods""" - _POLL_TIMEOUT = 0.1 # Handshake thread polls for stop event every 100ms - @dataclass class TpKVTopology: """ @@ -651,16 +761,6 @@ class NixlConnectorWorker: # Map of engine_id -> {rank0: agent_name0, rank1: agent_name1..}. self._remote_agents: dict[EngineId, dict[int, str]] = defaultdict(dict) - # NIXL handshake port. - # NOTE(rob): Within a DP group, each DP rank gets its own - # base port (which is sent in the KVTransferParams). - # Each TP rank listens/queries on the base_port + tp_rank. - self.side_channel_port: int = ( - envs.VLLM_NIXL_SIDE_CHANNEL_PORT - + vllm_config.parallel_config.data_parallel_rank - * vllm_config.parallel_config.tensor_parallel_size - ) - # Metadata. 
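+        # NOTE: the handshake side channel (and its port) now lives on the
+        # scheduler-side connector; this worker only exposes its
+        # xfer_handshake_metadata once register_kv_caches() has run.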
self.engine_id: EngineId = engine_id self.tp_rank = get_tensor_model_parallel_rank() @@ -706,6 +806,7 @@ class NixlConnectorWorker: # Map of engine_id -> kv_caches_base_addr. For TP case, each local # rank will still only pull from a single remote TP worker. self.kv_caches_base_addr: dict[EngineId, list[int]] = {} + self.device_id: int = 0 # Number of NIXL regions. Currently one region per cache # (so 1 per layer for MLA, otherwise 2 per layer) @@ -736,9 +837,8 @@ class NixlConnectorWorker: # requests that skipped transfer (handshake or transfer failures) self._failed_recv_reqs: set[ReqId] = set() - # Background thread for handling new handshake requests. - self._nixl_handshake_listener_t: threading.Thread | None = None - self._nixl_handshake_listener_stop_event: threading.Event | None = None + # Handshake metadata of this worker for NIXL transfers. + self.xfer_handshake_metadata: NixlAgentMetadata | None = None # Background thread for initializing new NIXL handshakes. self._handshake_initiation_executor = ThreadPoolExecutor( # NIXL is not guaranteed to be thread-safe, limit 1 worker. @@ -790,42 +890,6 @@ class NixlConnectorWorker: total_num_kv_heads=self.model_config.get_total_num_kv_heads(), ) - @staticmethod - def _nixl_handshake_listener( - metadata: NixlAgentMetadata, - ready_event: threading.Event, - stop_event: threading.Event, - base_port: int, - tp_rank: int, - ): - """Background thread for getting new NIXL handshakes.""" - # NOTE(rob): this is a simple implementation. We will move - # to a better approach via HTTP endpoint soon. - - encoder = msgspec.msgpack.Encoder() - encoded_data = encoder.encode(metadata) - size_in_bytes = len(encoded_data) - logger.debug("Size of encoded NixlAgentMetadata: %s bytes", str(size_in_bytes)) - - # Listen for new requests for metadata. - host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST - path = make_zmq_path("tcp", host, base_port + tp_rank) - logger.debug("Starting listening on path: %s", path) - with zmq_ctx(zmq.ROUTER, path) as sock: - ready_event.set() - poller = zmq.Poller() - poller.register(sock, zmq.POLLIN) - while not stop_event.is_set(): - events = dict( - poller.poll(timeout=NixlConnectorWorker._POLL_TIMEOUT * 1000) - ) - if sock not in events: - continue - identity, _, msg = sock.recv_multipart() - if msg != GET_META_MSG: - logger.warning("Connection listener got unexpected message %s", msg) - sock.send_multipart((identity, b"", encoded_data)) - def _nixl_handshake( self, host: str, @@ -844,16 +908,17 @@ class NixlConnectorWorker: # Handshake only with the remote TP rank that current local rank will # pull from. With homogeneous TP it happens to be the same rank_i. p_remote_rank = self.kv_topo.get_target_remote_rank(remote_tp_size) - path = make_zmq_path("tcp", host, port + p_remote_rank) + path = make_zmq_path("tcp", host, port) logger.debug( - "Querying metadata on path: %s at remote rank %s", path, p_remote_rank + "Querying metadata on path: %s at remote tp rank %s", path, p_remote_rank ) # Send query for the request. 
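+        # A single listener (owned by the scheduler connector) serves every TP
+        # rank of the remote DP rank, so the request encodes the target remote
+        # TP rank and the listener replies with that rank's metadata.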
with zmq_ctx(zmq.REQ, path) as sock: + msg = msgspec.msgpack.encode((GET_META_MSG, p_remote_rank)) # Set receive timeout to 5 seconds to avoid hanging on dead server sock.setsockopt(zmq.RCVTIMEO, 5000) # milliseconds - sock.send(GET_META_MSG) + sock.send(msg) metadata_bytes = sock.recv() decoder = msgspec.msgpack.Decoder(NixlAgentMetadata) metadata = decoder.decode(metadata_bytes) @@ -1042,6 +1107,10 @@ class NixlConnectorWorker: assert tensor_size_bytes == curr_tensor_size_bytes, ( "All kv cache tensors must have the same size" ) + # Need to make sure the device ID is non-negative for NIXL, + # Torch uses -1 to indicate CPU tensors while NIXL uses explicit + # memory type. + self.device_id = max(cache.get_device(), 0) caches_data.append( (base_addr, curr_tensor_size_bytes, self.device_id, "") ) @@ -1139,10 +1208,11 @@ class NixlConnectorWorker: assert len(self.block_window_per_layer) == self.num_layers # After KV Caches registered, listen for new connections. - metadata = NixlAgentMetadata( + self.xfer_handshake_metadata = NixlAgentMetadata( engine_id=self.engine_id, agent_metadata=self.nixl_wrapper.get_agent_metadata(), kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id], + device_id=self.device_id, num_blocks=self.num_blocks, block_lens=self.block_len_per_layer, attn_backend_name=self.backend_name, @@ -1150,22 +1220,6 @@ class NixlConnectorWorker: if not self.use_host_buffer else self.host_buffer_kv_cache_layout, ) - ready_event, stop_event = threading.Event(), threading.Event() - self._nixl_handshake_listener_t = threading.Thread( - target=self._nixl_handshake_listener, - args=( - metadata, - ready_event, - stop_event, - self.side_channel_port, - self.tp_rank, - ), - daemon=True, - name="nixl_handshake_listener", - ) - self._nixl_handshake_listener_t.start() - self._nixl_handshake_listener_stop_event = stop_event - ready_event.wait() # Wait for listener ZMQ socket to be ready. def add_remote_agent( self, @@ -1267,7 +1321,7 @@ class NixlConnectorWorker: # self.block_len == remote_block_len//tp_ratio bytes. addr = base_addr + block_offset + rank_offset # (addr, len, device id) - blocks_data.append((addr, kv_block_len, remote_tp_rank)) + blocks_data.append((addr, kv_block_len, nixl_agent_meta.device_id)) if self._use_flashinfer: # With FlashInfer index V separately to allow head splitting. 
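+                # The V half of each remote block starts block_lens[i] // 2
+                # bytes past the K half; it gets its own descriptor on the
+                # same remote device id.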
@@ -1275,7 +1329,9 @@ class NixlConnectorWorker: block_offset = block_id * nixl_agent_meta.block_lens[i] addr = base_addr + block_offset + rank_offset v_addr = addr + nixl_agent_meta.block_lens[i] // 2 - blocks_data.append((v_addr, kv_block_len, remote_tp_rank)) + blocks_data.append( + (v_addr, kv_block_len, nixl_agent_meta.device_id) + ) logger.debug( "Created %s blocks for dst engine %s with remote rank %s and local rank %s", @@ -1843,14 +1899,6 @@ class NixlConnectorWorker: def shutdown(self): """Shutdown the connector worker.""" self._handshake_initiation_executor.shutdown(wait=False) - if self._nixl_handshake_listener_stop_event is not None: - self._nixl_handshake_listener_stop_event.set() - self._nixl_handshake_listener_stop_event = None - if self._nixl_handshake_listener_t is not None: - # Generous timeout to allow the thread to exit - self._nixl_handshake_listener_t.join(timeout=self._POLL_TIMEOUT * 10) - assert not self._nixl_handshake_listener_t.is_alive() - self._nixl_handshake_listener_t = None for handles in self._recving_transfers.values(): for handle, _ in handles: self.nixl_wrapper.release_xfer_handle(handle) diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 6cbd986b3cd32..bfe87b718282c 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -163,6 +163,27 @@ class EngineCore: vllm_config, mm_registry ) + # If a KV connector is initialized for scheduler, we want to collect + # handshake metadata from all workers so the connector in the scheduler + # will have the full context + kv_connector = self.scheduler.get_kv_connector() + if kv_connector is not None: + # Collect and store KV connector xfer metadata from workers + # (after KV cache registration) + xfer_handshake_metadata = ( + self.model_executor.get_kv_connector_handshake_metadata() + ) + + if xfer_handshake_metadata: + # xfer_handshake_metadata is list of dicts from workers + # Each dict already has structure {tp_rank: metadata} + # Merge all worker dicts into a single dict + content: dict[int, Any] = {} + for worker_dict in xfer_handshake_metadata: + if worker_dict is not None: + content.update(worker_dict) + kv_connector.set_xfer_handshake_metadata(content) + # Setup batch queue for pipeline parallelism. # Batch queue for scheduled batches. 
This enables us to asynchronously # schedule and execute batches, and is required by pipeline parallelism @@ -178,7 +199,7 @@ class EngineCore: self.request_block_hasher: Callable[[Request], list[BlockHash]] | None = None if ( self.vllm_config.cache_config.enable_prefix_caching - or self.scheduler.get_kv_connector() is not None + or kv_connector is not None ): caching_hash_fn = get_hash_fn_by_name( vllm_config.cache_config.prefix_caching_hash_algo diff --git a/vllm/v1/executor/abstract.py b/vllm/v1/executor/abstract.py index 9fe1912c73e39..ef7840e1796f7 100644 --- a/vllm/v1/executor/abstract.py +++ b/vllm/v1/executor/abstract.py @@ -9,6 +9,9 @@ from typing import TYPE_CHECKING, Literal, TypeVar, overload from vllm.config import VllmConfig from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorHandshakeMetadata, +) from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.tasks import SupportedTask @@ -177,6 +180,11 @@ class Executor(ABC): ): raise NotImplementedError + def get_kv_connector_handshake_metadata( + self, + ) -> list[dict[int, KVConnectorHandshakeMetadata]]: + return self.collective_rpc("get_kv_connector_handshake_metadata") + @overload def execute_model( self, diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 54c5f81fc7e8e..5b11bdf5282fa 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -19,7 +19,11 @@ from vllm.distributed import ( init_distributed_environment, set_custom_all_reduce, ) -from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized +from vllm.distributed.kv_transfer import ( + ensure_kv_transfer_initialized, + get_kv_transfer_group, + has_kv_transfer_group, +) from vllm.distributed.parallel_state import ( get_pp_group, get_tp_group, @@ -348,6 +352,21 @@ class Worker(WorkerBase): return int(self.available_kv_cache_memory_bytes) + def get_kv_connector_handshake_metadata(self) -> dict | None: + """Get KV connector metadata from this worker if available.""" + + if not has_kv_transfer_group(): + return None + + connector = get_kv_transfer_group() + # Return None for connectors that don't need to exchange handshake + # metadata across workers. + if (metadata := connector.get_handshake_metadata()) is None: + return None + + tp_rank = get_tp_group().rank_in_group + return {tp_rank: metadata} + def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: return self.model_runner.get_kv_cache_spec()