mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 00:15:51 +08:00
[NIXL][HeteroTP] Enable KV transfer from HND prefill to NHD decode (#26556)
Signed-off-by: Chendi Xue <chendi.xue@intel.com>
This commit is contained in:
parent
74704d4553
commit
7e6edb1469
@ -156,6 +156,16 @@ python tests/v1/kv_connector/nixl_integration/toy_proxy_server.py \
|
||||
NixlConnector currently does not distinguish `kv_role`; the actual prefiller/decoder roles are determined by the upper-level proxy (e.g., `toy_proxy_server.py` using `--prefiller-hosts` and `--decoder-hosts`).
|
||||
Therefore, `kv_role` in `--kv-transfer-config` is effectively a placeholder and does not affect NixlConnector's behavior.
|
||||
|
||||
## Experimental Feature
|
||||
|
||||
### Heterogenuous KV Layout support
|
||||
|
||||
Support use case: Prefill with 'HND' and decode with 'NHD' with experimental configuration
|
||||
|
||||
```bash
|
||||
--kv-transfer-config '{..., "enable_permute_local_kv":"True"}'
|
||||
```
|
||||
|
||||
## Example Scripts/Code
|
||||
|
||||
Refer to these example scripts in the vLLM repository:
|
||||
|
||||
@ -19,11 +19,18 @@ done
|
||||
|
||||
echo "Running accuracy tests with kv_buffer_device=$KV_BUFFER_DEVICE"
|
||||
|
||||
DECODER_KV_LAYOUT=${DECODER_KV_LAYOUT:-"HND"} # Default to HND, optional NHD
|
||||
if [[ "$DECODER_KV_LAYOUT" == "NHD" ]]; then
|
||||
KV_CONFIG_HETERO_LAYOUT=',"enable_permute_local_kv":"True"'
|
||||
else
|
||||
KV_CONFIG_HETERO_LAYOUT=''
|
||||
fi
|
||||
|
||||
# Build the kv-transfer-config once
|
||||
if [[ "$KV_BUFFER_DEVICE" == "cuda" ]]; then
|
||||
KV_CONFIG='{"kv_connector":"NixlConnector","kv_role":"kv_both"}'
|
||||
KV_CONFIG='{"kv_connector":"NixlConnector","kv_role":"kv_both"'${KV_CONFIG_HETERO_LAYOUT}'}'
|
||||
else
|
||||
KV_CONFIG="{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"$KV_BUFFER_DEVICE\"}"
|
||||
KV_CONFIG="{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"$KV_BUFFER_DEVICE\""${KV_CONFIG_HETERO_LAYOUT}"}"
|
||||
fi
|
||||
|
||||
# Models to run
|
||||
@ -117,6 +124,7 @@ run_tests_for_model() {
|
||||
|
||||
# Build the command with or without model-specific args
|
||||
BASE_CMD="CUDA_VISIBLE_DEVICES=$GPU_ID \
|
||||
VLLM_KV_CACHE_LAYOUT='HND' \
|
||||
UCX_NET_DEVICES=all \
|
||||
VLLM_NIXL_SIDE_CHANNEL_PORT=$SIDE_CHANNEL_PORT \
|
||||
vllm serve $model_name \
|
||||
@ -157,6 +165,7 @@ run_tests_for_model() {
|
||||
|
||||
# Build the command with or without model-specific args
|
||||
BASE_CMD="CUDA_VISIBLE_DEVICES=$GPU_ID \
|
||||
VLLM_KV_CACHE_LAYOUT=$DECODER_KV_LAYOUT \
|
||||
UCX_NET_DEVICES=all \
|
||||
VLLM_NIXL_SIDE_CHANNEL_PORT=$SIDE_CHANNEL_PORT \
|
||||
vllm serve $model_name \
|
||||
|
||||
@ -286,9 +286,12 @@ def test_prompt_less_than_block_size():
|
||||
class FakeNixlConnectorWorker(NixlConnectorWorker):
|
||||
REMOTE_ENGINE_ID = "remote_engine"
|
||||
|
||||
def __init__(self, *args, hand_shake_latency: float = 1.8, **kwargs):
|
||||
def __init__(
|
||||
self, *args, hand_shake_latency: float = 1.8, kv_cache_layout="HND", **kwargs
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._hand_shake_latency = hand_shake_latency
|
||||
self.kv_cache_layout = kv_cache_layout
|
||||
|
||||
def _nixl_handshake(
|
||||
self, host: str, port: int, remote_tp_size: int, expected_engine_id: str
|
||||
@ -564,10 +567,63 @@ class TestNixlHandshake:
|
||||
|
||||
# We don't check layout for homogeneous TP and MLA for now, as the
|
||||
# whole block is moved.
|
||||
worker.add_remote_agent(meta, remote_tp_size=2)
|
||||
with pytest.raises(RuntimeError):
|
||||
# mismatched layout is expected to fail
|
||||
worker.add_remote_agent(meta, remote_tp_size=2)
|
||||
with pytest.raises(AssertionError):
|
||||
worker.add_remote_agent(meta, remote_tp_size=1)
|
||||
|
||||
@patch(
|
||||
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
|
||||
FakeNixlWrapper,
|
||||
)
|
||||
def test_handshake_succeed_on_kv_cache_layout_mismatch_with_experimental(
|
||||
self, dist_init
|
||||
):
|
||||
"""
|
||||
Verify that adding a remote agent fails if kv_cache_layout differs.
|
||||
This test is only relevant for heterogeneous TP.
|
||||
"""
|
||||
vllm_config = create_vllm_config(enable_permute_local_kv=True)
|
||||
|
||||
# Mock TP world size to 2 to force heterogeneous TP when
|
||||
# remote_tp_size=1
|
||||
with patch(
|
||||
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.get_tensor_model_parallel_world_size", # noqa: E501
|
||||
return_value=2,
|
||||
):
|
||||
# Initialize connector and worker (with fake NIXL wrapper)
|
||||
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
|
||||
connector.connector_worker = FakeNixlConnectorWorker(
|
||||
vllm_config,
|
||||
connector.engine_id,
|
||||
hand_shake_latency=0,
|
||||
kv_cache_layout="NHD",
|
||||
)
|
||||
worker = connector.connector_worker
|
||||
|
||||
# Minimal local registration params used by add_remote_agent
|
||||
worker.slot_size_per_layer = [2048]
|
||||
worker.block_len_per_layer = [2048 * worker.block_size]
|
||||
worker.num_blocks = 1
|
||||
worker.dst_num_blocks[worker.engine_id] = worker.num_blocks
|
||||
|
||||
# Metadata with different kv_cache_layout than local worker
|
||||
meta = NixlAgentMetadata(
|
||||
engine_id=FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
|
||||
agent_metadata=FakeNixlWrapper.AGENT_METADATA,
|
||||
kv_caches_base_addr=[0],
|
||||
num_blocks=1,
|
||||
# prefill TP=1, decode TP=2, remote block_lens is double to local
|
||||
block_lens=[i * 2 for i in worker.block_len_per_layer],
|
||||
attn_backend_name=worker.backend_name,
|
||||
kv_cache_layout="HND",
|
||||
)
|
||||
|
||||
# We don't check layout for homogeneous TP and MLA for now, as the
|
||||
# whole block is moved.
|
||||
worker.add_remote_agent(meta, remote_tp_size=1)
|
||||
|
||||
|
||||
# NOTE: resource cleanup in mp backend is a bit finicky, so the order in which
|
||||
# we put here is important. First run ray, it will clean up the resources, then
|
||||
|
||||
@ -83,6 +83,7 @@ def create_vllm_config(
|
||||
block_size: int = 16,
|
||||
max_model_len: int = 10000,
|
||||
enable_chunked_prefill: bool = True,
|
||||
enable_permute_local_kv: bool = False,
|
||||
) -> VllmConfig:
|
||||
"""Initialize VllmConfig For Testing."""
|
||||
scheduler_config = SchedulerConfig(
|
||||
@ -108,6 +109,7 @@ def create_vllm_config(
|
||||
kv_transfer_config = KVTransferConfig(
|
||||
kv_connector="NixlConnector",
|
||||
kv_role="kv_both",
|
||||
enable_permute_local_kv=enable_permute_local_kv,
|
||||
)
|
||||
return VllmConfig(
|
||||
scheduler_config=scheduler_config,
|
||||
|
||||
@ -61,6 +61,9 @@ class KVTransferConfig:
|
||||
"""The Python module path to dynamically load the KV connector from.
|
||||
Only supported in V1."""
|
||||
|
||||
enable_permute_local_kv: bool = False
|
||||
"""Experiment feature flag to enable HND to NHD KV Transfer"""
|
||||
|
||||
def compute_hash(self) -> str:
|
||||
"""
|
||||
WARNING: Whenever a new field is added to this config,
|
||||
|
||||
@ -563,6 +563,7 @@ class NixlConnectorWorker:
|
||||
self.world_size = get_tensor_model_parallel_world_size()
|
||||
self.tp_group = get_tp_group()
|
||||
self.num_blocks = 0
|
||||
self.enable_permute_local_kv = False
|
||||
|
||||
# KV Caches and nixl tracking data.
|
||||
self.device_type = current_platform.device_type
|
||||
@ -1094,6 +1095,23 @@ class NixlConnectorWorker:
|
||||
is_kv_replicated = self._tp_size[engine_id] // total_num_kv_heads >= 1
|
||||
|
||||
remote_block_len = nixl_agent_meta.block_lens[0]
|
||||
if nixl_agent_meta.kv_cache_layout != self.kv_cache_layout:
|
||||
if (
|
||||
self.vllm_config.kv_transfer_config is not None
|
||||
and self.vllm_config.kv_transfer_config.enable_permute_local_kv
|
||||
and nixl_agent_meta.kv_cache_layout == "HND"
|
||||
):
|
||||
logger.info(
|
||||
"Remote is HND and local is NHD, enabled additional permute "
|
||||
"on local device KV."
|
||||
)
|
||||
self.enable_permute_local_kv = True
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"Heterogeneous TP expects same kv_cache_layout. "
|
||||
"Or enable experimental feature to use HND to NHD support by "
|
||||
"setting 'enable_permute_local_kv'=True in --kv-transfer-config."
|
||||
)
|
||||
if self.use_mla or is_kv_replicated:
|
||||
# With replicated KV cache, only the number of blocks can differ.
|
||||
assert self.block_len_per_layer == nixl_agent_meta.block_lens, (
|
||||
@ -1114,7 +1132,10 @@ class NixlConnectorWorker:
|
||||
remote_block_size //= 2
|
||||
if tp_ratio > 1:
|
||||
# Heterogeneous TP expects same kv_cache_layout.
|
||||
assert nixl_agent_meta.kv_cache_layout == self.kv_cache_layout
|
||||
if nixl_agent_meta.kv_cache_layout == "NHD":
|
||||
raise ValueError(
|
||||
"Heterogeneous TP is not supported for remote with NHD."
|
||||
)
|
||||
if self.device_type == "xpu":
|
||||
raise ValueError("Heterogeneous TP is not supported on XPU")
|
||||
|
||||
@ -1226,6 +1247,41 @@ class NixlConnectorWorker:
|
||||
"d2h",
|
||||
)
|
||||
|
||||
def permute_device_kv(self, block_ids: list[int]):
|
||||
"""Transforms the layout of received KV cache blocks to the local format.
|
||||
|
||||
This method corrects layout mismatches from direct memory copies by
|
||||
permuting the tensor dimensions.
|
||||
|
||||
- **Source Layout:** `[num_blocks, n_kv_head, block_size, head_dim]`
|
||||
- **Target Layout:** `[num_blocks, block_size, n_kv_head, head_dim]`
|
||||
|
||||
Args:
|
||||
block_ids: A list of block IDs to update and permute.
|
||||
|
||||
Implementation:
|
||||
- x = blocks_to_update.reshape(src_shape) # view local kv with sender layout
|
||||
- permuted_blocks = x.permute(*inv_order) # transpose n_kv_heads, block_size
|
||||
- cache.index_copy_(0, indices, permuted_blocks) # copy permuted kv back
|
||||
|
||||
"""
|
||||
split_k_and_v = not (self.use_mla or self._use_pallas or self._use_flashinfer)
|
||||
inv_order = [0, 2, 1, 3]
|
||||
sample_cache = list(self.device_kv_caches.values())[0][0]
|
||||
target_shape = list(sample_cache.shape)
|
||||
target_shape[0] = -1
|
||||
src_shape = tuple(target_shape[i] for i in inv_order)
|
||||
indices = torch.tensor(block_ids, device=sample_cache.device)
|
||||
|
||||
for _, cache_or_caches in self.device_kv_caches.items():
|
||||
cache_list = cache_or_caches if split_k_and_v else [cache_or_caches]
|
||||
for cache in cache_list:
|
||||
blocks_to_update = cache.index_select(0, indices)
|
||||
permuted_blocks = blocks_to_update.reshape(src_shape).permute(
|
||||
*inv_order
|
||||
)
|
||||
cache.index_copy_(0, indices, permuted_blocks)
|
||||
|
||||
def get_finished(self) -> tuple[set[str], set[str]]:
|
||||
"""
|
||||
Get requests that are done sending or recving on this specific worker.
|
||||
@ -1273,6 +1329,15 @@ class NixlConnectorWorker:
|
||||
del self._reqs_to_send[req_id]
|
||||
done_sending.add(req_id)
|
||||
|
||||
if self.enable_permute_local_kv and len(done_recving) > 0:
|
||||
block_ids = []
|
||||
for req_id in done_recving:
|
||||
meta = self._recving_metadata.pop(req_id)
|
||||
assert meta, f"{req_id} not found in recving_metadata list"
|
||||
block_ids += meta.local_block_ids
|
||||
|
||||
self.permute_device_kv(block_ids)
|
||||
|
||||
return done_sending, done_recving
|
||||
|
||||
def _get_new_notifs(self) -> set[str]:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user