diff --git a/docs/features/nixl_connector_usage.md b/docs/features/nixl_connector_usage.md
index 795b0c77d610e..bfc0e0d86c6ae 100644
--- a/docs/features/nixl_connector_usage.md
+++ b/docs/features/nixl_connector_usage.md
@@ -156,6 +156,16 @@ python tests/v1/kv_connector/nixl_integration/toy_proxy_server.py \
 
 NixlConnector currently does not distinguish `kv_role`; the actual prefiller/decoder roles are determined by the upper-level proxy (e.g., `toy_proxy_server.py` using `--prefiller-hosts` and `--decoder-hosts`). Therefore, `kv_role` in `--kv-transfer-config` is effectively a placeholder and does not affect NixlConnector's behavior.
 
+## Experimental Feature
+
+### Heterogeneous KV Layout Support
+
+Supported use case: the prefill instance uses the `HND` KV cache layout and the decode instance uses `NHD`, enabled with the experimental configuration below:
+
+```bash
+--kv-transfer-config '{..., "enable_permute_local_kv":"True"}'
+```
+
 ## Example Scripts/Code
 
 Refer to these example scripts in the vLLM repository:
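(Not part of the patch.) The two layouts named above differ only in the order of the per-block head and token dimensions. A minimal, illustrative sketch of that difference, with made-up sizes:

```python
import torch

# Illustrative sizes only.
num_blocks, n_kv_heads, block_size, head_dim = 4, 8, 16, 64

# "HND" block layout: [num_blocks, n_kv_heads, block_size, head_dim]
hnd_blocks = torch.randn(num_blocks, n_kv_heads, block_size, head_dim)

# "NHD" block layout: [num_blocks, block_size, n_kv_heads, head_dim].
# Swapping the head and token dimensions converts one layout into the other,
# which is the mismatch the experimental flag compensates for.
nhd_blocks = hnd_blocks.permute(0, 2, 1, 3).contiguous()

assert nhd_blocks.shape == (num_blocks, block_size, n_kv_heads, head_dim)
```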
diff --git a/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh b/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh
index 3bf722900df37..ed6154462bb2b 100755
--- a/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh
+++ b/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh
@@ -19,11 +19,18 @@ done
 
 echo "Running accuracy tests with kv_buffer_device=$KV_BUFFER_DEVICE"
 
+DECODER_KV_LAYOUT=${DECODER_KV_LAYOUT:-"HND"} # Default to HND, optional NHD
+if [[ "$DECODER_KV_LAYOUT" == "NHD" ]]; then
+  KV_CONFIG_HETERO_LAYOUT=',"enable_permute_local_kv":"True"'
+else
+  KV_CONFIG_HETERO_LAYOUT=''
+fi
+
 # Build the kv-transfer-config once
 if [[ "$KV_BUFFER_DEVICE" == "cuda" ]]; then
-  KV_CONFIG='{"kv_connector":"NixlConnector","kv_role":"kv_both"}'
+  KV_CONFIG='{"kv_connector":"NixlConnector","kv_role":"kv_both"'${KV_CONFIG_HETERO_LAYOUT}'}'
 else
-  KV_CONFIG="{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"$KV_BUFFER_DEVICE\"}"
+  KV_CONFIG="{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"$KV_BUFFER_DEVICE\""${KV_CONFIG_HETERO_LAYOUT}"}"
 fi
 
 # Models to run
@@ -117,6 +124,7 @@ run_tests_for_model() {
 
     # Build the command with or without model-specific args
     BASE_CMD="CUDA_VISIBLE_DEVICES=$GPU_ID \
+        VLLM_KV_CACHE_LAYOUT='HND' \
        UCX_NET_DEVICES=all \
        VLLM_NIXL_SIDE_CHANNEL_PORT=$SIDE_CHANNEL_PORT \
        vllm serve $model_name \
@@ -157,6 +165,7 @@ run_tests_for_model() {
 
    # Build the command with or without model-specific args
    BASE_CMD="CUDA_VISIBLE_DEVICES=$GPU_ID \
+        VLLM_KV_CACHE_LAYOUT=$DECODER_KV_LAYOUT \
        UCX_NET_DEVICES=all \
        VLLM_NIXL_SIDE_CHANNEL_PORT=$SIDE_CHANNEL_PORT \
        vllm serve $model_name \
diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py
index 71f5d4b2b0fd9..a911ddc56b023 100644
--- a/tests/v1/kv_connector/unit/test_nixl_connector.py
+++ b/tests/v1/kv_connector/unit/test_nixl_connector.py
@@ -286,9 +286,12 @@ def test_prompt_less_than_block_size():
 class FakeNixlConnectorWorker(NixlConnectorWorker):
     REMOTE_ENGINE_ID = "remote_engine"
 
-    def __init__(self, *args, hand_shake_latency: float = 1.8, **kwargs):
+    def __init__(
+        self, *args, hand_shake_latency: float = 1.8, kv_cache_layout="HND", **kwargs
+    ):
         super().__init__(*args, **kwargs)
         self._hand_shake_latency = hand_shake_latency
+        self.kv_cache_layout = kv_cache_layout
 
     def _nixl_handshake(
         self, host: str, port: int, remote_tp_size: int, expected_engine_id: str
@@ -564,10 +567,63 @@ class TestNixlHandshake:
         # We don't check layout for homogeneous TP and MLA for now, as the
         # whole block is moved.
-        worker.add_remote_agent(meta, remote_tp_size=2)
+        with pytest.raises(RuntimeError):
+            # mismatched layout is expected to fail
+            worker.add_remote_agent(meta, remote_tp_size=2)
         with pytest.raises(AssertionError):
             worker.add_remote_agent(meta, remote_tp_size=1)
 
+    @patch(
+        "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
+        FakeNixlWrapper,
+    )
+    def test_handshake_succeed_on_kv_cache_layout_mismatch_with_experimental(
+        self, dist_init
+    ):
+        """
+        Verify that adding a remote agent succeeds on a kv_cache_layout
+        mismatch when enable_permute_local_kv is set (heterogeneous TP only).
+        """
+        vllm_config = create_vllm_config(enable_permute_local_kv=True)
+
+        # Mock TP world size to 2 to force heterogeneous TP when
+        # remote_tp_size=1
+        with patch(
+            "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.get_tensor_model_parallel_world_size",  # noqa: E501
+            return_value=2,
+        ):
+            # Initialize connector and worker (with fake NIXL wrapper)
+            connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
+            connector.connector_worker = FakeNixlConnectorWorker(
+                vllm_config,
+                connector.engine_id,
+                hand_shake_latency=0,
+                kv_cache_layout="NHD",
+            )
+            worker = connector.connector_worker
+
+            # Minimal local registration params used by add_remote_agent
+            worker.slot_size_per_layer = [2048]
+            worker.block_len_per_layer = [2048 * worker.block_size]
+            worker.num_blocks = 1
+            worker.dst_num_blocks[worker.engine_id] = worker.num_blocks
+
+            # Metadata with different kv_cache_layout than local worker
+            meta = NixlAgentMetadata(
+                engine_id=FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
+                agent_metadata=FakeNixlWrapper.AGENT_METADATA,
+                kv_caches_base_addr=[0],
+                num_blocks=1,
+                # prefill TP=1, decode TP=2: remote block_lens are twice local
+                block_lens=[i * 2 for i in worker.block_len_per_layer],
+                attn_backend_name=worker.backend_name,
+                kv_cache_layout="HND",
+            )
+
+            # With enable_permute_local_kv set, the layout mismatch is
+            # tolerated and the handshake succeeds.
+            worker.add_remote_agent(meta, remote_tp_size=1)
+
 
 # NOTE: resource cleanup in mp backend is a bit finicky, so the order in which
 # we put here is important. First run ray, it will clean up the resources, then
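(Not part of the patch.) A standalone sketch of the decision the two handshake tests above exercise; it mirrors the check added to `add_remote_agent` further down, but the function name and structure here are illustrative only:

```python
def should_permute_local_kv(
    local_layout: str, remote_layout: str, enable_permute_local_kv: bool
) -> bool:
    """Return True if the local side must permute received KV blocks.

    Matching layouts need no fix-up; an HND remote is accepted only when the
    experimental flag is set (and triggers a local permute after receive);
    any other mismatch is rejected, matching the RuntimeError the first test
    expects and the success path the second test exercises.
    """
    if remote_layout == local_layout:
        return False
    if enable_permute_local_kv and remote_layout == "HND":
        return True
    raise RuntimeError("Mismatched kv_cache_layout between local and remote")
```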
diff --git a/tests/v1/kv_connector/unit/utils.py b/tests/v1/kv_connector/unit/utils.py
index b07fd0536a436..e7f505d55e7a4 100644
--- a/tests/v1/kv_connector/unit/utils.py
+++ b/tests/v1/kv_connector/unit/utils.py
@@ -83,6 +83,7 @@ def create_vllm_config(
     block_size: int = 16,
     max_model_len: int = 10000,
     enable_chunked_prefill: bool = True,
+    enable_permute_local_kv: bool = False,
 ) -> VllmConfig:
     """Initialize VllmConfig For Testing."""
     scheduler_config = SchedulerConfig(
@@ -108,6 +109,7 @@ def create_vllm_config(
     kv_transfer_config = KVTransferConfig(
         kv_connector="NixlConnector",
         kv_role="kv_both",
+        enable_permute_local_kv=enable_permute_local_kv,
     )
     return VllmConfig(
         scheduler_config=scheduler_config,
diff --git a/vllm/config/kv_transfer.py b/vllm/config/kv_transfer.py
index d7a9d5808319e..eafd0e015a88d 100644
--- a/vllm/config/kv_transfer.py
+++ b/vllm/config/kv_transfer.py
@@ -61,6 +61,9 @@ class KVTransferConfig:
     """The Python module path to dynamically load the KV connector from.
     Only supported in V1."""
 
+    enable_permute_local_kv: bool = False
+    """Experimental feature flag to enable HND-to-NHD KV transfer."""
+
     def compute_hash(self) -> str:
         """
         WARNING: Whenever a new field is added to this config,
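(Not part of the patch.) For completeness, a sketch of the programmatic equivalent of the CLI flag shown in the documentation, using the field added to `KVTransferConfig` above; constructing the config directly like this mirrors what `create_vllm_config` does in the test utilities:

```python
from vllm.config.kv_transfer import KVTransferConfig

# Decode-side (NHD) configuration, roughly equivalent to passing
#   --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both",
#                          "enable_permute_local_kv":"True"}'
# on the command line.
kv_transfer_config = KVTransferConfig(
    kv_connector="NixlConnector",
    kv_role="kv_both",
    enable_permute_local_kv=True,
)
```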
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
index 8c4c82f76ff29..490f209373db3 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -563,6 +563,7 @@ class NixlConnectorWorker:
         self.world_size = get_tensor_model_parallel_world_size()
         self.tp_group = get_tp_group()
         self.num_blocks = 0
+        self.enable_permute_local_kv = False
 
         # KV Caches and nixl tracking data.
         self.device_type = current_platform.device_type
@@ -1094,6 +1095,23 @@ class NixlConnectorWorker:
         is_kv_replicated = self._tp_size[engine_id] // total_num_kv_heads >= 1
 
         remote_block_len = nixl_agent_meta.block_lens[0]
+        if nixl_agent_meta.kv_cache_layout != self.kv_cache_layout:
+            if (
+                self.vllm_config.kv_transfer_config is not None
+                and self.vllm_config.kv_transfer_config.enable_permute_local_kv
+                and nixl_agent_meta.kv_cache_layout == "HND"
+            ):
+                logger.info(
+                    "Remote KV layout is HND and local is NHD; enabling an "
+                    "additional permute of the local device KV."
+                )
+                self.enable_permute_local_kv = True
+            else:
+                raise RuntimeError(
+                    "Mismatched kv_cache_layout between local and remote. "
+                    "To use the experimental HND-to-NHD support, set "
+                    "'enable_permute_local_kv': True in --kv-transfer-config."
+                )
         if self.use_mla or is_kv_replicated:
             # With replicated KV cache, only the number of blocks can differ.
             assert self.block_len_per_layer == nixl_agent_meta.block_lens, (
@@ -1114,7 +1132,10 @@ class NixlConnectorWorker:
             remote_block_size //= 2
         if tp_ratio > 1:
             # Heterogeneous TP expects same kv_cache_layout.
-            assert nixl_agent_meta.kv_cache_layout == self.kv_cache_layout
+            if nixl_agent_meta.kv_cache_layout == "NHD":
+                raise ValueError(
+                    "Heterogeneous TP is not supported for remote with NHD."
+                )
             if self.device_type == "xpu":
                 raise ValueError("Heterogeneous TP is not supported on XPU")
 
@@ -1226,6 +1247,41 @@ class NixlConnectorWorker:
                 "d2h",
             )
 
+    def permute_device_kv(self, block_ids: list[int]):
+        """Transforms the layout of received KV cache blocks to the local format.
+
+        This method corrects layout mismatches from direct memory copies by
+        permuting the tensor dimensions.
+
+        - **Source Layout:** `[num_blocks, n_kv_head, block_size, head_dim]`
+        - **Target Layout:** `[num_blocks, block_size, n_kv_head, head_dim]`
+
+        Args:
+            block_ids: A list of block IDs to update and permute.
+
+        Implementation:
+        - x = blocks_to_update.reshape(src_shape) # view local kv with sender layout
+        - permuted_blocks = x.permute(*inv_order) # transpose n_kv_heads, block_size
+        - cache.index_copy_(0, indices, permuted_blocks) # copy permuted kv back
+
+        """
+        split_k_and_v = not (self.use_mla or self._use_pallas or self._use_flashinfer)
+        inv_order = [0, 2, 1, 3]
+        sample_cache = list(self.device_kv_caches.values())[0][0]
+        target_shape = list(sample_cache.shape)
+        target_shape[0] = -1
+        src_shape = tuple(target_shape[i] for i in inv_order)
+        indices = torch.tensor(block_ids, device=sample_cache.device)
+
+        for _, cache_or_caches in self.device_kv_caches.items():
+            cache_list = cache_or_caches if split_k_and_v else [cache_or_caches]
+            for cache in cache_list:
+                blocks_to_update = cache.index_select(0, indices)
+                permuted_blocks = blocks_to_update.reshape(src_shape).permute(
+                    *inv_order
+                )
+                cache.index_copy_(0, indices, permuted_blocks)
+
     def get_finished(self) -> tuple[set[str], set[str]]:
         """
         Get requests that are done sending or recving on this specific worker.
@@ -1273,6 +1329,15 @@ class NixlConnectorWorker:
             del self._reqs_to_send[req_id]
             done_sending.add(req_id)
 
+        if self.enable_permute_local_kv and len(done_recving) > 0:
+            block_ids = []
+            for req_id in done_recving:
+                meta = self._recving_metadata.pop(req_id)
+                assert meta, f"{req_id} not found in recving_metadata list"
+                block_ids += meta.local_block_ids
+
+            self.permute_device_kv(block_ids)
+
         return done_sending, done_recving
 
     def _get_new_notifs(self) -> set[str]:
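(Not part of the patch.) A self-contained sketch of the reshape-then-permute fix-up that `permute_device_kv` applies after the raw NIXL copy; tensor sizes are illustrative:

```python
import torch

# Illustrative sizes.
num_blocks, block_size, n_kv_heads, head_dim = 2, 4, 3, 5

# Blocks as the HND prefiller holds them:
# [num_blocks, n_kv_heads, block_size, head_dim].
remote_hnd = torch.randn(num_blocks, n_kv_heads, block_size, head_dim)

# The NIXL transfer is a raw per-block copy, so the NHD-shaped local cache
# ends up holding HND-ordered elements. A reshape of contiguous data models
# that byte-for-byte copy.
local_nhd = remote_hnd.reshape(num_blocks, block_size, n_kv_heads, head_dim).clone()

# Fix-up, as in permute_device_kv: view the received blocks with the sender's
# shape, then swap the head and token dimensions back into the local layout.
inv_order = [0, 2, 1, 3]
src_shape = (num_blocks, n_kv_heads, block_size, head_dim)
fixed = local_nhd.reshape(src_shape).permute(*inv_order)

# The result matches what an NHD-native sender would have produced.
assert torch.equal(fixed, remote_hnd.permute(0, 2, 1, 3))
```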