[NIXL][HeteroTP] Enable KV transfer from HND prefill to NHD decode (#26556)

Signed-off-by: Chendi Xue <chendi.xue@intel.com>
This commit is contained in:
Chendi.Xue 2025-10-14 04:46:05 -05:00 committed by GitHub
parent 74704d4553
commit 7e6edb1469
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 150 additions and 5 deletions

View File

@ -156,6 +156,16 @@ python tests/v1/kv_connector/nixl_integration/toy_proxy_server.py \
NixlConnector currently does not distinguish `kv_role`; the actual prefiller/decoder roles are determined by the upper-level proxy (e.g., `toy_proxy_server.py` using `--prefiller-hosts` and `--decoder-hosts`). NixlConnector currently does not distinguish `kv_role`; the actual prefiller/decoder roles are determined by the upper-level proxy (e.g., `toy_proxy_server.py` using `--prefiller-hosts` and `--decoder-hosts`).
Therefore, `kv_role` in `--kv-transfer-config` is effectively a placeholder and does not affect NixlConnector's behavior. Therefore, `kv_role` in `--kv-transfer-config` is effectively a placeholder and does not affect NixlConnector's behavior.
## Experimental Feature
### Heterogeneous KV Layout support
Supported use case: prefill with the 'HND' KV cache layout and decode with the 'NHD' layout, enabled through the following experimental configuration:
```bash
--kv-transfer-config '{..., "enable_permute_local_kv":"True"}'
```
## Example Scripts/Code ## Example Scripts/Code
Refer to these example scripts in the vLLM repository: Refer to these example scripts in the vLLM repository:

View File

@ -19,11 +19,18 @@ done
echo "Running accuracy tests with kv_buffer_device=$KV_BUFFER_DEVICE" echo "Running accuracy tests with kv_buffer_device=$KV_BUFFER_DEVICE"
DECODER_KV_LAYOUT=${DECODER_KV_LAYOUT:-"HND"} # Default to HND, optional NHD
if [[ "$DECODER_KV_LAYOUT" == "NHD" ]]; then
KV_CONFIG_HETERO_LAYOUT=',"enable_permute_local_kv":"True"'
else
KV_CONFIG_HETERO_LAYOUT=''
fi
# Build the kv-transfer-config once # Build the kv-transfer-config once
if [[ "$KV_BUFFER_DEVICE" == "cuda" ]]; then if [[ "$KV_BUFFER_DEVICE" == "cuda" ]]; then
KV_CONFIG='{"kv_connector":"NixlConnector","kv_role":"kv_both"}' KV_CONFIG='{"kv_connector":"NixlConnector","kv_role":"kv_both"'${KV_CONFIG_HETERO_LAYOUT}'}'
else else
KV_CONFIG="{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"$KV_BUFFER_DEVICE\"}" KV_CONFIG="{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"$KV_BUFFER_DEVICE\""${KV_CONFIG_HETERO_LAYOUT}"}"
fi fi
# Models to run # Models to run
@ -117,6 +124,7 @@ run_tests_for_model() {
# Build the command with or without model-specific args # Build the command with or without model-specific args
BASE_CMD="CUDA_VISIBLE_DEVICES=$GPU_ID \ BASE_CMD="CUDA_VISIBLE_DEVICES=$GPU_ID \
VLLM_KV_CACHE_LAYOUT='HND' \
UCX_NET_DEVICES=all \ UCX_NET_DEVICES=all \
VLLM_NIXL_SIDE_CHANNEL_PORT=$SIDE_CHANNEL_PORT \ VLLM_NIXL_SIDE_CHANNEL_PORT=$SIDE_CHANNEL_PORT \
vllm serve $model_name \ vllm serve $model_name \
@ -157,6 +165,7 @@ run_tests_for_model() {
# Build the command with or without model-specific args # Build the command with or without model-specific args
BASE_CMD="CUDA_VISIBLE_DEVICES=$GPU_ID \ BASE_CMD="CUDA_VISIBLE_DEVICES=$GPU_ID \
VLLM_KV_CACHE_LAYOUT=$DECODER_KV_LAYOUT \
UCX_NET_DEVICES=all \ UCX_NET_DEVICES=all \
VLLM_NIXL_SIDE_CHANNEL_PORT=$SIDE_CHANNEL_PORT \ VLLM_NIXL_SIDE_CHANNEL_PORT=$SIDE_CHANNEL_PORT \
vllm serve $model_name \ vllm serve $model_name \

View File

@ -286,9 +286,12 @@ def test_prompt_less_than_block_size():
class FakeNixlConnectorWorker(NixlConnectorWorker): class FakeNixlConnectorWorker(NixlConnectorWorker):
REMOTE_ENGINE_ID = "remote_engine" REMOTE_ENGINE_ID = "remote_engine"
def __init__(self, *args, hand_shake_latency: float = 1.8, **kwargs): def __init__(
self, *args, hand_shake_latency: float = 1.8, kv_cache_layout="HND", **kwargs
):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self._hand_shake_latency = hand_shake_latency self._hand_shake_latency = hand_shake_latency
self.kv_cache_layout = kv_cache_layout
def _nixl_handshake( def _nixl_handshake(
self, host: str, port: int, remote_tp_size: int, expected_engine_id: str self, host: str, port: int, remote_tp_size: int, expected_engine_id: str
@ -564,10 +567,63 @@ class TestNixlHandshake:
# We don't check layout for homogeneous TP and MLA for now, as the # We don't check layout for homogeneous TP and MLA for now, as the
# whole block is moved. # whole block is moved.
worker.add_remote_agent(meta, remote_tp_size=2) with pytest.raises(RuntimeError):
# mismatched layout is expected to fail
worker.add_remote_agent(meta, remote_tp_size=2)
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
worker.add_remote_agent(meta, remote_tp_size=1) worker.add_remote_agent(meta, remote_tp_size=1)
@patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
FakeNixlWrapper,
)
def test_handshake_succeed_on_kv_cache_layout_mismatch_with_experimental(
self, dist_init
):
"""
Verify that adding a remote agent succeeds when kv_cache_layout differs
(remote "HND", local "NHD") and enable_permute_local_kv is set.
This test is only relevant for heterogeneous TP.
"""
vllm_config = create_vllm_config(enable_permute_local_kv=True)
# Mock TP world size to 2 to force heterogeneous TP when
# remote_tp_size=1
with patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.get_tensor_model_parallel_world_size", # noqa: E501
return_value=2,
):
# Initialize connector and worker (with fake NIXL wrapper)
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
connector.connector_worker = FakeNixlConnectorWorker(
vllm_config,
connector.engine_id,
hand_shake_latency=0,
kv_cache_layout="NHD",
)
worker = connector.connector_worker
# Minimal local registration params used by add_remote_agent
worker.slot_size_per_layer = [2048]
worker.block_len_per_layer = [2048 * worker.block_size]
worker.num_blocks = 1
worker.dst_num_blocks[worker.engine_id] = worker.num_blocks
# Metadata with different kv_cache_layout than local worker
meta = NixlAgentMetadata(
engine_id=FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
agent_metadata=FakeNixlWrapper.AGENT_METADATA,
kv_caches_base_addr=[0],
num_blocks=1,
# prefill TP=1, decode TP=2, remote block_lens is double to local
block_lens=[i * 2 for i in worker.block_len_per_layer],
attn_backend_name=worker.backend_name,
kv_cache_layout="HND",
)
# No pytest.raises wrapper here: with enable_permute_local_kv set, the
# layout mismatch (remote "HND" vs. local "NHD") is expected to be accepted.
worker.add_remote_agent(meta, remote_tp_size=1)
# NOTE: resource cleanup in mp backend is a bit finicky, so the order in which # NOTE: resource cleanup in mp backend is a bit finicky, so the order in which
# we put here is important. First run ray, it will clean up the resources, then # we put here is important. First run ray, it will clean up the resources, then

View File

@ -83,6 +83,7 @@ def create_vllm_config(
block_size: int = 16, block_size: int = 16,
max_model_len: int = 10000, max_model_len: int = 10000,
enable_chunked_prefill: bool = True, enable_chunked_prefill: bool = True,
enable_permute_local_kv: bool = False,
) -> VllmConfig: ) -> VllmConfig:
"""Initialize VllmConfig For Testing.""" """Initialize VllmConfig For Testing."""
scheduler_config = SchedulerConfig( scheduler_config = SchedulerConfig(
@ -108,6 +109,7 @@ def create_vllm_config(
kv_transfer_config = KVTransferConfig( kv_transfer_config = KVTransferConfig(
kv_connector="NixlConnector", kv_connector="NixlConnector",
kv_role="kv_both", kv_role="kv_both",
enable_permute_local_kv=enable_permute_local_kv,
) )
return VllmConfig( return VllmConfig(
scheduler_config=scheduler_config, scheduler_config=scheduler_config,

View File

@ -61,6 +61,9 @@ class KVTransferConfig:
"""The Python module path to dynamically load the KV connector from. """The Python module path to dynamically load the KV connector from.
Only supported in V1.""" Only supported in V1."""
enable_permute_local_kv: bool = False
"""Experiment feature flag to enable HND to NHD KV Transfer"""
def compute_hash(self) -> str: def compute_hash(self) -> str:
""" """
WARNING: Whenever a new field is added to this config, WARNING: Whenever a new field is added to this config,

View File

@ -563,6 +563,7 @@ class NixlConnectorWorker:
self.world_size = get_tensor_model_parallel_world_size() self.world_size = get_tensor_model_parallel_world_size()
self.tp_group = get_tp_group() self.tp_group = get_tp_group()
self.num_blocks = 0 self.num_blocks = 0
self.enable_permute_local_kv = False
# KV Caches and nixl tracking data. # KV Caches and nixl tracking data.
self.device_type = current_platform.device_type self.device_type = current_platform.device_type
@ -1094,6 +1095,23 @@ class NixlConnectorWorker:
is_kv_replicated = self._tp_size[engine_id] // total_num_kv_heads >= 1 is_kv_replicated = self._tp_size[engine_id] // total_num_kv_heads >= 1
remote_block_len = nixl_agent_meta.block_lens[0] remote_block_len = nixl_agent_meta.block_lens[0]
if nixl_agent_meta.kv_cache_layout != self.kv_cache_layout:
if (
self.vllm_config.kv_transfer_config is not None
and self.vllm_config.kv_transfer_config.enable_permute_local_kv
and nixl_agent_meta.kv_cache_layout == "HND"
):
logger.info(
"Remote is HND and local is NHD, enabled additional permute "
"on local device KV."
)
self.enable_permute_local_kv = True
else:
raise RuntimeError(
"Heterogeneous TP expects same kv_cache_layout. "
"Or enable experimental feature to use HND to NHD support by "
"setting 'enable_permute_local_kv'=True in --kv-transfer-config."
)
if self.use_mla or is_kv_replicated: if self.use_mla or is_kv_replicated:
# With replicated KV cache, only the number of blocks can differ. # With replicated KV cache, only the number of blocks can differ.
assert self.block_len_per_layer == nixl_agent_meta.block_lens, ( assert self.block_len_per_layer == nixl_agent_meta.block_lens, (
@ -1114,7 +1132,10 @@ class NixlConnectorWorker:
remote_block_size //= 2 remote_block_size //= 2
if tp_ratio > 1: if tp_ratio > 1:
# Heterogeneous TP expects same kv_cache_layout. # Heterogeneous TP expects same kv_cache_layout.
assert nixl_agent_meta.kv_cache_layout == self.kv_cache_layout if nixl_agent_meta.kv_cache_layout == "NHD":
raise ValueError(
"Heterogeneous TP is not supported for remote with NHD."
)
if self.device_type == "xpu": if self.device_type == "xpu":
raise ValueError("Heterogeneous TP is not supported on XPU") raise ValueError("Heterogeneous TP is not supported on XPU")
@ -1226,6 +1247,41 @@ class NixlConnectorWorker:
"d2h", "d2h",
) )
def permute_device_kv(self, block_ids: list[int]):
    """Permute received KV cache blocks in place into the local layout.

    The received blocks hold data written with the sender's layout; this
    method reinterprets each selected block with that layout and swaps the
    `n_kv_head` and `block_size` dimensions so it matches the local one.

    - **Source Layout:** `[num_blocks, n_kv_head, block_size, head_dim]`
    - **Target Layout:** `[num_blocks, block_size, n_kv_head, head_dim]`

    Args:
        block_ids: A list of block IDs to update and permute. An empty list
            is a no-op.

    Implementation:
        - x = blocks_to_update.reshape(src_shape)  # view local kv with sender layout
        - permuted_blocks = x.permute(*inv_order)  # transpose n_kv_heads, block_size
        - cache.index_copy_(0, indices, permuted_blocks)  # copy permuted kv back
    """
    if not block_ids:
        # Nothing to permute. Also avoids torch.tensor([]) defaulting to a
        # floating-point dtype, which index_select/index_copy_ reject.
        return

    # When K and V are stored separately, each entry of device_kv_caches is
    # iterated to reach the individual K and V caches.
    split_k_and_v = not (self.use_mla or self._use_pallas or self._use_flashinfer)
    inv_order = (0, 2, 1, 3)

    # The first sub-tensor of the first entry serves as the shape/device
    # template (assumed 4-D in the target layout above - TODO confirm for
    # the non-split backends).
    sample_cache = next(iter(self.device_kv_caches.values()))[0]
    target_shape = list(sample_cache.shape)
    target_shape[0] = -1  # number of selected blocks is inferred by reshape
    src_shape = tuple(target_shape[i] for i in inv_order)
    # Explicit integer dtype: index tensors must be int32/int64.
    indices = torch.tensor(
        block_ids, dtype=torch.long, device=sample_cache.device
    )

    for cache_or_caches in self.device_kv_caches.values():
        cache_list = cache_or_caches if split_k_and_v else [cache_or_caches]
        for cache in cache_list:
            blocks_to_update = cache.index_select(0, indices)
            permuted_blocks = blocks_to_update.reshape(src_shape).permute(
                *inv_order
            )
            cache.index_copy_(0, indices, permuted_blocks)
def get_finished(self) -> tuple[set[str], set[str]]: def get_finished(self) -> tuple[set[str], set[str]]:
""" """
Get requests that are done sending or recving on this specific worker. Get requests that are done sending or recving on this specific worker.
@ -1273,6 +1329,15 @@ class NixlConnectorWorker:
del self._reqs_to_send[req_id] del self._reqs_to_send[req_id]
done_sending.add(req_id) done_sending.add(req_id)
if self.enable_permute_local_kv and len(done_recving) > 0:
block_ids = []
for req_id in done_recving:
meta = self._recving_metadata.pop(req_id)
assert meta, f"{req_id} not found in recving_metadata list"
block_ids += meta.local_block_ids
self.permute_device_kv(block_ids)
return done_sending, done_recving return done_sending, done_recving
def _get_new_notifs(self) -> set[str]: def _get_new_notifs(self) -> set[str]: