[NIXL][HeteroTP] Enable KV transfer from HND prefill to NHD decode (#26556)

Signed-off-by: Chendi Xue <chendi.xue@intel.com>
Chendi.Xue 2025-10-14 04:46:05 -05:00 committed by GitHub
parent 74704d4553
commit 7e6edb1469
6 changed files with 150 additions and 5 deletions


@@ -156,6 +156,16 @@ python tests/v1/kv_connector/nixl_integration/toy_proxy_server.py \
NixlConnector currently does not distinguish `kv_role`; the actual prefiller/decoder roles are determined by the upper-level proxy (e.g., `toy_proxy_server.py` using `--prefiller-hosts` and `--decoder-hosts`).
Therefore, `kv_role` in `--kv-transfer-config` is effectively a placeholder and does not affect NixlConnector's behavior.
## Experimental Feature
### Heterogeneous KV Layout Support
Supported use case: prefill with `HND` layout and decode with `NHD` layout, enabled through the experimental configuration:
```bash
--kv-transfer-config '{..., "enable_permute_local_kv":"True"}'
```
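For illustration, a minimal Python sketch of the equivalent programmatic configuration (the import path and offline usage here are assumptions; it mirrors the `KVTransferConfig` construction used in the test utilities of this change):

```python
# Minimal sketch, assuming KVTransferConfig is importable from vllm.config.
from vllm.config import KVTransferConfig

kv_transfer_config = KVTransferConfig(
    kv_connector="NixlConnector",
    kv_role="kv_both",  # effectively a placeholder; roles are decided by the proxy
    enable_permute_local_kv=True,  # experimental HND -> NHD support
)
```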
## Example Scripts/Code
Refer to these example scripts in the vLLM repository:


@@ -19,11 +19,18 @@ done
echo "Running accuracy tests with kv_buffer_device=$KV_BUFFER_DEVICE"
DECODER_KV_LAYOUT=${DECODER_KV_LAYOUT:-"HND"} # Default to HND, optional NHD
if [[ "$DECODER_KV_LAYOUT" == "NHD" ]]; then
KV_CONFIG_HETERO_LAYOUT=',"enable_permute_local_kv":"True"'
else
KV_CONFIG_HETERO_LAYOUT=''
fi
# Build the kv-transfer-config once
if [[ "$KV_BUFFER_DEVICE" == "cuda" ]]; then
KV_CONFIG='{"kv_connector":"NixlConnector","kv_role":"kv_both"}'
KV_CONFIG='{"kv_connector":"NixlConnector","kv_role":"kv_both"'${KV_CONFIG_HETERO_LAYOUT}'}'
else
KV_CONFIG="{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"$KV_BUFFER_DEVICE\"}"
KV_CONFIG="{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"$KV_BUFFER_DEVICE\""${KV_CONFIG_HETERO_LAYOUT}"}"
fi
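# For reference, with DECODER_KV_LAYOUT=NHD the cuda branch above yields:
#   {"kv_connector":"NixlConnector","kv_role":"kv_both","enable_permute_local_kv":"True"}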
# Models to run
@@ -117,6 +124,7 @@ run_tests_for_model() {
# Build the command with or without model-specific args
BASE_CMD="CUDA_VISIBLE_DEVICES=$GPU_ID \
VLLM_KV_CACHE_LAYOUT='HND' \
UCX_NET_DEVICES=all \
VLLM_NIXL_SIDE_CHANNEL_PORT=$SIDE_CHANNEL_PORT \
vllm serve $model_name \
@@ -157,6 +165,7 @@ run_tests_for_model() {
# Build the command with or without model-specific args
BASE_CMD="CUDA_VISIBLE_DEVICES=$GPU_ID \
VLLM_KV_CACHE_LAYOUT=$DECODER_KV_LAYOUT \
UCX_NET_DEVICES=all \
VLLM_NIXL_SIDE_CHANNEL_PORT=$SIDE_CHANNEL_PORT \
vllm serve $model_name \


@@ -286,9 +286,12 @@ def test_prompt_less_than_block_size():
class FakeNixlConnectorWorker(NixlConnectorWorker):
REMOTE_ENGINE_ID = "remote_engine"
def __init__(self, *args, hand_shake_latency: float = 1.8, **kwargs):
def __init__(
self, *args, hand_shake_latency: float = 1.8, kv_cache_layout="HND", **kwargs
):
super().__init__(*args, **kwargs)
self._hand_shake_latency = hand_shake_latency
self.kv_cache_layout = kv_cache_layout
def _nixl_handshake(
self, host: str, port: int, remote_tp_size: int, expected_engine_id: str
@@ -564,10 +567,63 @@ class TestNixlHandshake:
# We don't check layout for homogeneous TP and MLA for now, as the
# whole block is moved.
worker.add_remote_agent(meta, remote_tp_size=2)
with pytest.raises(RuntimeError):
# mismatched layout is expected to fail
worker.add_remote_agent(meta, remote_tp_size=2)
with pytest.raises(AssertionError):
worker.add_remote_agent(meta, remote_tp_size=1)
@patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
FakeNixlWrapper,
)
def test_handshake_succeed_on_kv_cache_layout_mismatch_with_experimental(
self, dist_init
):
"""
Verify that adding a remote agent succeeds despite a kv_cache_layout
mismatch when the experimental enable_permute_local_kv flag is enabled.
This test is only relevant for heterogeneous TP.
"""
vllm_config = create_vllm_config(enable_permute_local_kv=True)
# Mock TP world size to 2 to force heterogeneous TP when
# remote_tp_size=1
with patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.get_tensor_model_parallel_world_size", # noqa: E501
return_value=2,
):
# Initialize connector and worker (with fake NIXL wrapper)
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
connector.connector_worker = FakeNixlConnectorWorker(
vllm_config,
connector.engine_id,
hand_shake_latency=0,
kv_cache_layout="NHD",
)
worker = connector.connector_worker
# Minimal local registration params used by add_remote_agent
worker.slot_size_per_layer = [2048]
worker.block_len_per_layer = [2048 * worker.block_size]
worker.num_blocks = 1
worker.dst_num_blocks[worker.engine_id] = worker.num_blocks
# Metadata with different kv_cache_layout than local worker
meta = NixlAgentMetadata(
engine_id=FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
agent_metadata=FakeNixlWrapper.AGENT_METADATA,
kv_caches_base_addr=[0],
num_blocks=1,
# prefill TP=1, decode TP=2; remote block_lens are double the local ones
block_lens=[i * 2 for i in worker.block_len_per_layer],
attn_backend_name=worker.backend_name,
kv_cache_layout="HND",
)
# With enable_permute_local_kv set, the layout mismatch (remote HND,
# local NHD) is accepted instead of raising.
worker.add_remote_agent(meta, remote_tp_size=1)
# NOTE: resource cleanup in mp backend is a bit finicky, so the order in which
# we put here is important. First run ray, it will clean up the resources, then


@@ -83,6 +83,7 @@ def create_vllm_config(
block_size: int = 16,
max_model_len: int = 10000,
enable_chunked_prefill: bool = True,
enable_permute_local_kv: bool = False,
) -> VllmConfig:
"""Initialize VllmConfig For Testing."""
scheduler_config = SchedulerConfig(
@@ -108,6 +109,7 @@
kv_transfer_config = KVTransferConfig(
kv_connector="NixlConnector",
kv_role="kv_both",
enable_permute_local_kv=enable_permute_local_kv,
)
return VllmConfig(
scheduler_config=scheduler_config,


@@ -61,6 +61,9 @@ class KVTransferConfig:
"""The Python module path to dynamically load the KV connector from.
Only supported in V1."""
enable_permute_local_kv: bool = False
"""Experiment feature flag to enable HND to NHD KV Transfer"""
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,


@@ -563,6 +563,7 @@ class NixlConnectorWorker:
self.world_size = get_tensor_model_parallel_world_size()
self.tp_group = get_tp_group()
self.num_blocks = 0
self.enable_permute_local_kv = False
# KV Caches and nixl tracking data.
self.device_type = current_platform.device_type
@@ -1094,6 +1095,23 @@ class NixlConnectorWorker:
is_kv_replicated = self._tp_size[engine_id] // total_num_kv_heads >= 1
remote_block_len = nixl_agent_meta.block_lens[0]
if nixl_agent_meta.kv_cache_layout != self.kv_cache_layout:
if (
self.vllm_config.kv_transfer_config is not None
and self.vllm_config.kv_transfer_config.enable_permute_local_kv
and nixl_agent_meta.kv_cache_layout == "HND"
):
logger.info(
"Remote is HND and local is NHD, enabled additional permute "
"on local device KV."
)
self.enable_permute_local_kv = True
else:
raise RuntimeError(
"Heterogeneous TP expects same kv_cache_layout. "
"Or enable experimental feature to use HND to NHD support by "
"setting 'enable_permute_local_kv'=True in --kv-transfer-config."
)
if self.use_mla or is_kv_replicated:
# With replicated KV cache, only the number of blocks can differ.
assert self.block_len_per_layer == nixl_agent_meta.block_lens, (
@@ -1114,7 +1132,10 @@
remote_block_size //= 2
if tp_ratio > 1:
# Heterogeneous TP expects same kv_cache_layout.
assert nixl_agent_meta.kv_cache_layout == self.kv_cache_layout
if nixl_agent_meta.kv_cache_layout == "NHD":
raise ValueError(
"Heterogeneous TP is not supported for remote with NHD."
)
if self.device_type == "xpu":
raise ValueError("Heterogeneous TP is not supported on XPU")
@@ -1226,6 +1247,41 @@
"d2h",
)
def permute_device_kv(self, block_ids: list[int]):
"""Transforms the layout of received KV cache blocks to the local format.
This method corrects layout mismatches from direct memory copies by
permuting the tensor dimensions.
- **Source Layout:** `[num_blocks, n_kv_head, block_size, head_dim]`
- **Target Layout:** `[num_blocks, block_size, n_kv_head, head_dim]`
Args:
block_ids: A list of block IDs to update and permute.
Implementation:
- x = blocks_to_update.reshape(src_shape) # view local kv with sender layout
- permuted_blocks = x.permute(*inv_order) # transpose n_kv_heads, block_size
- cache.index_copy_(0, indices, permuted_blocks) # copy permuted kv back
"""
split_k_and_v = not (self.use_mla or self._use_pallas or self._use_flashinfer)
inv_order = [0, 2, 1, 3]
sample_cache = list(self.device_kv_caches.values())[0][0]
target_shape = list(sample_cache.shape)
target_shape[0] = -1
src_shape = tuple(target_shape[i] for i in inv_order)
indices = torch.tensor(block_ids, device=sample_cache.device)
for _, cache_or_caches in self.device_kv_caches.items():
cache_list = cache_or_caches if split_k_and_v else [cache_or_caches]
for cache in cache_list:
blocks_to_update = cache.index_select(0, indices)
permuted_blocks = blocks_to_update.reshape(src_shape).permute(
*inv_order
)
cache.index_copy_(0, indices, permuted_blocks)
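For clarity, a self-contained sketch of the same reshape/permute round trip on a dummy cache tensor (the single-tensor cache and made-up shapes are simplifications for illustration; the real method iterates over the per-layer K/V caches):

```python
# Standalone illustration of the permute above; dummy shapes and data only.
import torch

n_blocks, block_size, n_kv_heads, head_dim = 4, 16, 2, 8
inv_order = [0, 2, 1, 3]

# Local cache uses NHD within a block: [num_blocks, block_size, n_kv_head, head_dim].
cache = torch.zeros(n_blocks, block_size, n_kv_heads, head_dim)

# Pretend blocks 1 and 3 were filled by a raw copy from an HND sender, i.e.
# their bytes are actually laid out as [n_kv_head, block_size, head_dim].
remote_blocks = torch.arange(2 * n_kv_heads * block_size * head_dim, dtype=torch.float32)
remote_blocks = remote_blocks.reshape(2, n_kv_heads, block_size, head_dim)
block_ids = [1, 3]
indices = torch.tensor(block_ids)
cache.index_copy_(0, indices, remote_blocks.reshape(2, block_size, n_kv_heads, head_dim))

# Fix the layout exactly as permute_device_kv does: view the raw bytes with
# the sender's (HND) shape, swap heads and tokens, and copy back.
target_shape = list(cache.shape)
target_shape[0] = -1
src_shape = tuple(target_shape[i] for i in inv_order)
permuted = cache.index_select(0, indices).reshape(src_shape).permute(*inv_order)
cache.index_copy_(0, indices, permuted)

# Each selected block now matches the NHD view of the sender's data.
assert torch.equal(cache[1], remote_blocks[0].permute(1, 0, 2))
```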
def get_finished(self) -> tuple[set[str], set[str]]:
"""
Get requests that are done sending or recving on this specific worker.
@@ -1273,6 +1329,15 @@ del self._reqs_to_send[req_id]
del self._reqs_to_send[req_id]
done_sending.add(req_id)
if self.enable_permute_local_kv and len(done_recving) > 0:
block_ids = []
for req_id in done_recving:
meta = self._recving_metadata.pop(req_id)
assert meta, f"{req_id} not found in recving_metadata list"
block_ids += meta.local_block_ids
self.permute_device_kv(block_ids)
return done_sending, done_recving
def _get_new_notifs(self) -> set[str]: