From 6601c9c5bec80ad1b8474a08f98f5f77c520687b Mon Sep 17 00:00:00 2001
From: NickLucche
Date: Fri, 10 Oct 2025 10:30:38 +0000
Subject: [PATCH] add and update tests

Signed-off-by: NickLucche
---
 .../kv_connector/unit/test_nixl_connector.py  | 134 +++++++++++++++---
 .../kv_connector/v1/nixl_connector.py         |  13 +-
 2 files changed, 123 insertions(+), 24 deletions(-)

diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py
index a1f53cb255630..cff1f08845721 100644
--- a/tests/v1/kv_connector/unit/test_nixl_connector.py
+++ b/tests/v1/kv_connector/unit/test_nixl_connector.py
@@ -291,6 +291,7 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
     def __init__(self, *args, hand_shake_latency: float = 1.8, **kwargs):
         super().__init__(*args, **kwargs)
         self._hand_shake_latency = hand_shake_latency
+        self.kv_cache_layout = "HND"
 
     def _nixl_handshake(
         self, host: str, port: int, remote_tp_size: int, expected_engine_id: str
@@ -307,21 +308,42 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
 
         assert expected_engine_id == self.REMOTE_ENGINE_ID
 
-        remote_agent_name = self.add_remote_agent(
-            NixlAgentMetadata(
-                engine_id=self.REMOTE_ENGINE_ID,
-                agent_metadata=FakeNixlWrapper.AGENT_METADATA,
-                kv_caches_base_addr=[0],
-                num_blocks=1,
-                block_lens=self.block_len_per_layer,
-                attn_backend_name=self.backend_name,
-                # `self.kv_cache_layout` is only forced to HND when vllm engine
-                # is started. We mock HND here.
-                kv_cache_layout="HND",
-            ),
-            remote_tp_size=remote_tp_size,
-        )
-        return {0: remote_agent_name}
+        # Adjust the remote block length metadata to satisfy the heterogeneous
+        # TP invariants enforced during handshake validation.
+        remote_block_lens = list(self.block_len_per_layer)
+        tp_ratio = self.kv_info.tp_ratio(remote_tp_size=remote_tp_size)
+        if remote_tp_size > self.world_size:
+            # P TP > D TP case: the remote block_len is smaller.
+            remote_block_lens = [
+                block_len // (-tp_ratio) for block_len in remote_block_lens
+            ]
+        elif remote_tp_size < self.world_size:
+            remote_block_lens = [
+                block_len * tp_ratio for block_len in remote_block_lens
+            ]
+
+        # When remote tp_size > local tp_size, handshake with multiple
+        # remote ranks.
+        num_handshakes = 1 if tp_ratio > 0 else -tp_ratio
+        remote_agents: dict[int, str] = {}
+        for remote_tp_rank in range(num_handshakes):
+            remote_agent_name = self.add_remote_agent(
+                NixlAgentMetadata(
+                    engine_id=self.REMOTE_ENGINE_ID,
+                    agent_metadata=FakeNixlWrapper.AGENT_METADATA,
+                    kv_caches_base_addr=[0],
+                    num_blocks=1,
+                    block_lens=remote_block_lens,
+                    attn_backend_name=self.backend_name,
+                    # `self.kv_cache_layout` is only forced to HND when vllm engine
+                    # is started. We mock HND here.
+                    kv_cache_layout="HND",
+                ),
+                remote_tp_rank=remote_tp_rank,
+                remote_tp_size=remote_tp_size,
+            )
+            remote_agents[remote_tp_rank] = remote_agent_name
+        return remote_agents
 
 
 class TestNixlHandshake:
@@ -352,7 +374,13 @@ class TestNixlHandshake:
             vllm_config, connector.engine_id, hand_shake_latency=0
         )
         assert isinstance(connector.connector_worker.nixl_wrapper, FakeNixlWrapper)
-        connector.connector_worker.nixl_wrapper.set_cycles_before_xfer_done(3)
+        worker = connector.connector_worker
+        worker.nixl_wrapper.set_cycles_before_xfer_done(3)
+        # Simulate a completed handshake.
+        worker.dst_xfer_side_handles = {
+            FakeNixlConnectorWorker.REMOTE_ENGINE_ID: {0: 1}
+        }
+        worker.kv_cache_layout = "HND"
         num_xfers = 4
         while True:
             # For the same request_id, initiate multiple xfers across different
@@ -464,6 +492,69 @@ class TestNixlHandshake:
                 return
         raise TimeoutError("Took too long to complete async handshake.")
 
+    @patch(
+        "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
+        FakeNixlWrapper,
+    )
+    @pytest.mark.parametrize("local_tp_size", [1, 2])
+    def test_prefill_tp_size_greater_than_decode_tp_size(
+        self, local_tp_size: int, dist_init
+    ):
+        """
+        Verify that the handshake succeeds when remote (prefill) TP is
+        greater than local (decode) TP, across remote configurations.
+        """
+
+        vllm_config = create_vllm_config()
+        vllm_config.parallel_config.tensor_parallel_size = local_tp_size
+
+        connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
+        connector.connector_worker = FakeNixlConnectorWorker(
+            vllm_config, connector.engine_id, hand_shake_latency=0
+        )
+        worker = connector.connector_worker
+
+        # Minimal local registration params used by add_remote_agent.
+        worker.slot_size_per_layer = [4096]
+        worker.block_len_per_layer = [4096 * worker.block_size]
+        worker.num_blocks = 1
+        worker.dst_num_blocks[worker.engine_id] = worker.num_blocks
+        worker.src_blocks_data = [(0, worker.block_len_per_layer[0], worker.tp_rank)]
+
+        def check_handshake(remote_tp_size: int):
+            tp_ratio = remote_tp_size // local_tp_size
+            assert set(remote_agents.keys()) == set(range(tp_ratio))
+
+            remote_engine_id = worker.REMOTE_ENGINE_ID
+            assert worker._tp_size[remote_engine_id] == remote_tp_size
+            assert -tp_ratio == worker.kv_info.tp_ratio(remote_engine_id)
+            # Ensure src_xfer_side_chunked_handles holds tp_ratio chunks.
+            assert -tp_ratio in worker.src_xfer_side_chunked_handles
+            assert len(worker.src_xfer_side_chunked_handles[-tp_ratio]) == tp_ratio
+            assert remote_engine_id in worker.dst_xfer_side_handles
+            assert set(worker.dst_xfer_side_handles[remote_engine_id].keys()) == set(
+                range(tp_ratio)
+            )
+
+        remote_agents = worker._nixl_handshake(
+            host="localhost",
+            port=1234,
+            remote_tp_size=2 * local_tp_size,
+            expected_engine_id=worker.REMOTE_ENGINE_ID,
+        )
+        check_handshake(2 * local_tp_size)
+
+        # NOTE on flexibility: a second remote with a higher number of ranks
+        # is discovered.
+        worker.REMOTE_ENGINE_ID = "remote_engine_2"
+        remote_agents = worker._nixl_handshake(
+            host="localhost",
+            port=1234,
+            remote_tp_size=6 * local_tp_size,
+            expected_engine_id=worker.REMOTE_ENGINE_ID,
+        )
+        check_handshake(6 * local_tp_size)
+
     @patch(
         "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
         FakeNixlWrapper,
     )
@@ -564,10 +655,10 @@ class TestNixlHandshake:
             kv_cache_layout=mismatched_layout,
         )
 
-        # We don't check layout for homogeneous TP and MLA for now, as the
-        # whole block is moved.
-        worker.add_remote_agent(meta, remote_tp_size=2)
+        # Layout check is enforced for both heterogeneous and homogeneous TP.
         with pytest.raises(AssertionError):
+            worker.add_remote_agent(meta, remote_tp_size=2)
+        with pytest.raises(AssertionError):
             worker.add_remote_agent(meta, remote_tp_size=1)
 
 
@@ -1057,7 +1148,8 @@ def test_shutdown_cleans_up_resources(dist_init):
     ):
         worker._recving_transfers = {"req1": [(123, time.perf_counter())]}
         worker.src_xfer_side_handle = 456
-        worker.dst_xfer_side_handles = {"engine1": 789}
+        worker.src_xfer_side_chunked_handles = {-2: [456]}
+        worker.dst_xfer_side_handles = {"engine1": {0: 789}}
         worker._remote_agents = {"engine1": {0: "agent1"}}
         worker._registered_descs = ["desc1", "desc2"]
 
@@ -1071,7 +1163,7 @@ def test_shutdown_cleans_up_resources(dist_init):
         mock_listener.join.assert_called_once_with(timeout=0)
 
         mock_rel_xfer.assert_called_once_with(123)
-        assert mock_rel_dlist.call_count == 2
+        assert mock_rel_dlist.call_count == 3
         mock_rel_dlist.assert_any_call(456)  # src handle
         mock_rel_dlist.assert_any_call(789)  # dst handle
         mock_rem_agent.assert_called_once_with("agent1")
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
index 66d692e7387da..0f71addb67191 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -1171,7 +1171,14 @@ class NixlConnectorWorker:
 
         # Handle tp_size>num_kv_heads: replicate KV cache.
         indexes_into_remote = (
-            not self.kv_info.replicates_kv_cache(engine_id) and tp_ratio < 0
+            not self.kv_info.replicates_kv_cache(engine_id) and tp_ratio > 0
+        )
+
+        logger.debug(
+            "Registering remote agent (%s, rank %s) memory regions with tp_ratio %s",
+            engine_id,
+            remote_tp_rank,
+            tp_ratio,
         )
 
         ### (Optional) Register local agent memory regions
@@ -1724,8 +1731,8 @@ class NixlConnectorWorker:
         if self._nixl_handshake_listener_t is not None:
            self._nixl_handshake_listener_t.join(timeout=0)
            self._nixl_handshake_listener_t = None
-        for handles in self._recving_transfers.values():
-            for handle, _ in handles:
+        for rcv_handles in self._recving_transfers.values():
+            for handle, _ in rcv_handles:
                 self.nixl_wrapper.release_xfer_handle(handle)
         self._recving_transfers.clear()
         if self.src_xfer_side_handle:
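--
Reviewer note (illustration only, not part of the patch): the handshake,
registration, and shutdown changes above all hinge on a signed tp_ratio
convention. The sketch below is a minimal model of that convention as the
tests exercise it; `signed_tp_ratio` and `remote_block_len` are hypothetical
names for this note, not the connector's API.

def signed_tp_ratio(local_tp_size: int, remote_tp_size: int) -> int:
    """Positive local//remote when local (D) TP >= remote (P) TP,
    otherwise the negated remote//local ratio."""
    if local_tp_size >= remote_tp_size:
        assert local_tp_size % remote_tp_size == 0
        return local_tp_size // remote_tp_size
    assert remote_tp_size % local_tp_size == 0
    return -(remote_tp_size // local_tp_size)

def remote_block_len(local_block_len: int, tp_ratio: int) -> int:
    """Remote block length implied by the local one and the signed ratio."""
    if tp_ratio > 0:
        # Fewer (or equally many) remote ranks: each remote block packs
        # tp_ratio times the KV heads of a local block.
        return local_block_len * tp_ratio
    # P TP > D TP: remote blocks are smaller, and the local rank must
    # handshake with -tp_ratio remote ranks (num_handshakes in the fake
    # worker above) to cover all of its KV heads.
    return local_block_len // -tp_ratio

# D TP=1, P TP=2: two handshakes, remote blocks at half the local size.
assert signed_tp_ratio(1, 2) == -2
assert remote_block_len(4096, -2) == 2048
# D TP=2, P TP=1: one handshake, remote blocks at twice the local size.
assert signed_tp_ratio(2, 1) == 2
assert remote_block_len(4096, 2) == 8192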