diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py
index c6739832355f..3860d7c85724 100644
--- a/tests/v1/kv_connector/unit/test_nixl_connector.py
+++ b/tests/v1/kv_connector/unit/test_nixl_connector.py
@@ -419,6 +419,52 @@ class TestNixlHandshake:
                 return
         raise TimeoutError("Took too long to complete async handshake.")
 
+    @patch(
+        "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
+        FakeNixlWrapper)
+    def test_handshake_fails_on_kv_cache_layout_mismatch(self, dist_init):
+        """
+        Verify that adding a remote agent fails if kv_cache_layout differs.
+        This test is only relevant for heterogeneous TP.
+        """
+        vllm_config = create_vllm_config()
+
+        # Mock TP world size to 2 to force heterogeneous TP when
+        # remote_tp_size=1
+        with patch(
+                "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.get_tensor_model_parallel_world_size",  # noqa: E501
+                return_value=2):
+            # Initialize connector and worker (with fake NIXL wrapper)
+            connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
+            connector.connector_worker = FakeNixlConnectorWorker(
+                vllm_config, connector.engine_id, hand_shake_latency=0)
+            worker = connector.connector_worker
+
+            # Minimal local registration params used by add_remote_agent
+            worker.slot_size_bytes = 4096
+            worker.block_len = worker.slot_size_bytes * worker.block_size
+            worker.num_blocks = 1
+            worker.dst_num_blocks[worker.engine_id] = worker.num_blocks
+
+            # Metadata with different kv_cache_layout than local worker
+            mismatched_layout = "HND" if worker.kv_cache_layout != "HND" \
+                else "NHD"
+            meta = NixlAgentMetadata(
+                engine_id=FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
+                agent_metadata=FakeNixlWrapper.AGENT_METADATA,
+                kv_caches_base_addr=[0],
+                num_blocks=1,
+                block_len=worker.block_len,
+                attn_backend_name=worker.backend_name,
+                kv_cache_layout=mismatched_layout,
+            )
+
+            # We don't check layout for homogeneous TP and MLA for now, as the
+            # whole block is moved.
+            worker.add_remote_agent(meta, remote_tp_size=2)
+            with pytest.raises(AssertionError):
+                worker.add_remote_agent(meta, remote_tp_size=1)
+
 
 # NOTE: resource cleanup in mp backend is a bit finicky, so the order in which
 # we put here is important. First run ray, it will clean up the resources, then
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
index a6eeb278532e..4f51229ffbd2 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -30,6 +30,7 @@ from vllm.forward_context import ForwardContext
 from vllm.logger import init_logger
 from vllm.platforms import _Backend, current_platform
 from vllm.utils import make_zmq_path, make_zmq_socket
+from vllm.v1.attention.backends.utils import get_kv_cache_layout
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.request import RequestStatus
 
@@ -73,6 +74,7 @@ class NixlAgentMetadata(
     num_blocks: int
     block_len: int
     attn_backend_name: str
+    kv_cache_layout: str
 
 
 @dataclass
@@ -538,7 +540,9 @@ class NixlConnectorWorker:
         attn_backend = backend_name_to_enum(self.backend_name)
         self._use_flashinfer = attn_backend == _Backend.FLASHINFER_VLLM_V1
         self._use_pallas_v1 = attn_backend == _Backend.PALLAS_VLLM_V1
+        self.kv_cache_layout = get_kv_cache_layout()
         logger.debug("Detected attention backend %s", self.backend_name)
+        logger.debug("Detected kv cache layout %s", self.kv_cache_layout)
 
         self._tp_size: dict[EngineId, int] = {self.engine_id: self.world_size}
         # With heterogeneous TP, P must wait for all assigned D TP workers to
@@ -839,7 +843,8 @@ class NixlConnectorWorker:
             kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id],
             num_blocks=self.num_blocks,
             block_len=self.block_len,
-            attn_backend_name=self.backend_name)
+            attn_backend_name=self.backend_name,
+            kv_cache_layout=self.kv_cache_layout)
         ready_event = threading.Event()
         self._nixl_handshake_listener_t = threading.Thread(
             target=self._nixl_handshake_listener,
@@ -900,8 +905,7 @@ class NixlConnectorWorker:
             self._tp_size[engine_id] = remote_tp_size
         else:
             assert self._tp_size[engine_id] == remote_tp_size
-        # We may eventually enable this after asserting equality in cache
-        # layout and close outputs.
+        # TODO We may eventually want to skip enforcing the same attn backend.
         assert nixl_agent_meta.attn_backend_name == self.backend_name
 
         remote_agent_name = self.nixl_wrapper.add_remote_agent(
@@ -930,6 +934,9 @@ class NixlConnectorWorker:
         if self._use_flashinfer:
             # Account for joint KV in FlashInfer.
             remote_block_size //= 2
+        if tp_ratio > 1:
+            # Heterogeneous TP expects same kv_cache_layout.
+            assert nixl_agent_meta.kv_cache_layout == self.kv_cache_layout
 
         assert nixl_agent_meta.block_len == self.block_len * tp_ratio, (
             "Remote P worker KV layer cache must be of shape [2, N, "
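
Why the new assertion is gated on tp_ratio > 1: with homogeneous TP (and MLA), whole KV blocks are copied verbatim, so the transfer never interprets the bytes and the layout is irrelevant. When the decode side runs more TP workers than the prefill side, each decode worker reads a head-wise sub-slice of the remote block, and those slice offsets are only valid if both sides agree on HND vs. NHD ordering. The sketch below distills that guard; check_kv_layout_compat and its parameters are hypothetical names for illustration, not vLLM's actual API.

# Minimal sketch of the layout guard enforced during the NIXL handshake,
# assuming tp_ratio = local_tp // remote_tp as computed in add_remote_agent.
# check_kv_layout_compat is an illustrative helper, not vLLM code.
def check_kv_layout_compat(local_tp: int, remote_tp: int,
                           local_layout: str, remote_layout: str) -> None:
    tp_ratio = local_tp // remote_tp
    if tp_ratio > 1:
        # Heterogeneous TP: each D worker addresses a sub-slice of the
        # P worker's block, so HND/NHD must match on both sides.
        assert remote_layout == local_layout, (
            f"KV cache layout mismatch: remote={remote_layout}, "
            f"local={local_layout}")

# Mirrors the unit test above: the homogeneous case passes even with a
# mismatched layout, while the heterogeneous case raises.
check_kv_layout_compat(local_tp=2, remote_tp=2,
                       local_layout="NHD", remote_layout="HND")  # ok
# check_kv_layout_compat(local_tp=2, remote_tp=1,
#                        local_layout="NHD", remote_layout="HND")  # raises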