diff --git a/tests/v1/kv_connector/unit/test_multi_connector.py b/tests/v1/kv_connector/unit/test_multi_connector.py index 72848c1a706e7..e82691cd05e25 100644 --- a/tests/v1/kv_connector/unit/test_multi_connector.py +++ b/tests/v1/kv_connector/unit/test_multi_connector.py @@ -76,6 +76,9 @@ class TestSharedStorageConnector(SharedStorageConnector): return attr +# This relies on the "fork" multiprocessing method being used. +# It's the default, but vLLM may fall back to "spawn" if, for example, CUDA +# is already initialized. KVConnectorFactory.register_connector("TestSharedStorageConnector", TestSharedStorageConnector.__module__, TestSharedStorageConnector.__name__) diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index e18c4975a3221..c4f558b7acdb0 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -166,8 +166,8 @@ class FakeNixlConnectorWorker(NixlConnectorWorker): super().__init__(*args, **kwargs) self._hand_shake_latency = hand_shake_latency - def _nixl_handshake(self, host: str, port: int, - remote_tp_size: int) -> dict[int, str]: + def _nixl_handshake(self, host: str, port: int, remote_tp_size: int, + expected_engine_id: str) -> dict[int, str]: # Mimic slow _nixl_handshake, as well as bypass zmq communication. 
time.sleep(self._hand_shake_latency) # These should've been done in register_kv_caches(), called by @@ -177,6 +177,8 @@ class FakeNixlConnectorWorker(NixlConnectorWorker): self.num_blocks = 1 self.dst_num_blocks[self.engine_id] = self.num_blocks + assert expected_engine_id == self.REMOTE_ENGINE_ID + remote_agent_name = self.add_remote_agent( NixlAgentMetadata( engine_id=self.REMOTE_ENGINE_ID, diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py index be3c233994199..a2eaa0040191e 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py @@ -47,7 +47,10 @@ class MultiConnector(KVConnectorBase_V1): assert ktcs is not None for ktc in ktcs: temp_config = copy.copy(vllm_config) - temp_config.kv_transfer_config = KVTransferConfig(**ktc) + engine_id = ktc.get("engine_id", + vllm_config.kv_transfer_config.engine_id) + temp_config.kv_transfer_config = KVTransferConfig( + **{**ktc, "engine_id": engine_id}) self._connectors.append( KVConnectorFactory.create_connector_v1(temp_config, role)) @@ -187,7 +190,7 @@ class MultiConnector(KVConnectorBase_V1): async_saves += 1 if txfer_params is not None: if kv_txfer_params is not None: - #TODO we can probably change this to merge the dicts here, + # TODO we can probably change this to merge the dicts here, # checking for key clashes. 
raise RuntimeError( "Only one connector can produce KV transfer params") diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 67adb3e8a3c90..d2d3e88eabce6 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -488,8 +488,13 @@ class NixlConnectorWorker: "Connection listener got unexpected message %s", msg) sock.send_multipart((identity, b"", encoded_data)) - def _nixl_handshake(self, host: str, port: int, - remote_tp_size: int) -> dict[int, str]: + def _nixl_handshake( + self, + host: str, + port: int, + remote_tp_size: int, + expected_engine_id: str, + ) -> dict[int, str]: """Do a NIXL handshake with a remote instance.""" start_time = time.perf_counter() @@ -498,26 +503,6 @@ class NixlConnectorWorker: # a hack to keep us moving. We will switch when moving to etcd # or where we have a single ZMQ socket in the scheduler. - def handshake(path: str, rank: int) -> str: - # Send query for the request. - with zmq_ctx(zmq.REQ, path) as sock: - sock.send(GET_META_MSG) - metadata_bytes = sock.recv() - decoder = msgspec.msgpack.Decoder(NixlAgentMetadata) - metadata = decoder.decode(metadata_bytes) - got_metadata_time = time.perf_counter() - - # Register Remote agent. - remote_agent_name = self.add_remote_agent( - metadata, rank, remote_tp_size) - setup_agent_time = time.perf_counter() - - logger.debug("NIXL handshake: get metadata took: %s", - got_metadata_time - start_time) - logger.debug("NIXL handshake: add agent took: %s", - setup_agent_time - got_metadata_time) - return remote_agent_name - # Handshake only with the remote TP rank that current local rank will # pull from. With homogeneous TP it happens to be the same rank_i. 
tp_ratio = self._tp_size[self.engine_id] // remote_tp_size @@ -525,8 +510,32 @@ class NixlConnectorWorker: path = make_zmq_path("tcp", host, port + p_remote_rank) logger.debug("Querying metadata on path: %s at remote rank %s", path, p_remote_rank) + + # Send query for the request. + with zmq_ctx(zmq.REQ, path) as sock: + sock.send(GET_META_MSG) + metadata_bytes = sock.recv() + decoder = msgspec.msgpack.Decoder(NixlAgentMetadata) + metadata = decoder.decode(metadata_bytes) + got_metadata_time = time.perf_counter() + logger.debug("NIXL handshake: get metadata took: %s", + got_metadata_time - start_time) + + # Ensure engine id matches. + if metadata.engine_id != expected_engine_id: + raise RuntimeError(f"Remote NIXL agent engine ID mismatch. " + f"Expected {expected_engine_id}, " + f"received {metadata.engine_id}.") + + # Register Remote agent. + remote_agent_name = self.add_remote_agent(metadata, p_remote_rank, + remote_tp_size) + setup_agent_time = time.perf_counter() + logger.debug("NIXL handshake: add agent took: %s", + setup_agent_time - got_metadata_time) + # Remote rank -> agent name. 
- return {p_remote_rank: handshake(path, p_remote_rank)} + return {p_remote_rank: remote_agent_name} def _background_nixl_handshake(self, req_id: str, remote_engine_id: EngineId, meta: ReqMeta): @@ -535,7 +544,7 @@ class NixlConnectorWorker: if fut is None: fut = self._handshake_initiation_executor.submit( self._nixl_handshake, meta.remote_host, meta.remote_port, - meta.tp_size) + meta.tp_size, remote_engine_id) self._handshake_futures[remote_engine_id] = fut def done_callback(f: Future[dict[int, str]], eid=remote_engine_id): @@ -738,10 +747,10 @@ class NixlConnectorWorker: if remote_tp_rank in self._remote_agents.get(engine_id, {}): return self._remote_agents[engine_id][remote_tp_rank] - if engine_id in self._tp_size: - assert self._tp_size[engine_id] == remote_tp_size - else: + if engine_id not in self._tp_size: self._tp_size[engine_id] = remote_tp_size + else: + assert self._tp_size[engine_id] == remote_tp_size # We may eventually enable this after asserting equality in cache # layout and close outputs. assert nixl_agent_meta.attn_backend_name == self.backend_name