mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-08 05:42:16 +08:00
[NIXL][HeteroTP] Enable KV transfer from HND prefill to NHD decode (#26556)
Signed-off-by: Chendi Xue <chendi.xue@intel.com>
This commit is contained in:
parent
74704d4553
commit
7e6edb1469
@ -156,6 +156,16 @@ python tests/v1/kv_connector/nixl_integration/toy_proxy_server.py \
|
|||||||
NixlConnector currently does not distinguish `kv_role`; the actual prefiller/decoder roles are determined by the upper-level proxy (e.g., `toy_proxy_server.py` using `--prefiller-hosts` and `--decoder-hosts`).
|
NixlConnector currently does not distinguish `kv_role`; the actual prefiller/decoder roles are determined by the upper-level proxy (e.g., `toy_proxy_server.py` using `--prefiller-hosts` and `--decoder-hosts`).
|
||||||
Therefore, `kv_role` in `--kv-transfer-config` is effectively a placeholder and does not affect NixlConnector's behavior.
|
Therefore, `kv_role` in `--kv-transfer-config` is effectively a placeholder and does not affect NixlConnector's behavior.
|
||||||
|
|
||||||
|
## Experimental Feature
|
||||||
|
|
||||||
|
### Heterogenuous KV Layout support
|
||||||
|
|
||||||
|
Support use case: Prefill with 'HND' and decode with 'NHD' with experimental configuration
|
||||||
|
|
||||||
|
```bash
|
||||||
|
--kv-transfer-config '{..., "enable_permute_local_kv":"True"}'
|
||||||
|
```
|
||||||
|
|
||||||
## Example Scripts/Code
|
## Example Scripts/Code
|
||||||
|
|
||||||
Refer to these example scripts in the vLLM repository:
|
Refer to these example scripts in the vLLM repository:
|
||||||
|
|||||||
@ -19,11 +19,18 @@ done
|
|||||||
|
|
||||||
echo "Running accuracy tests with kv_buffer_device=$KV_BUFFER_DEVICE"
|
echo "Running accuracy tests with kv_buffer_device=$KV_BUFFER_DEVICE"
|
||||||
|
|
||||||
|
DECODER_KV_LAYOUT=${DECODER_KV_LAYOUT:-"HND"} # Default to HND, optional NHD
|
||||||
|
if [[ "$DECODER_KV_LAYOUT" == "NHD" ]]; then
|
||||||
|
KV_CONFIG_HETERO_LAYOUT=',"enable_permute_local_kv":"True"'
|
||||||
|
else
|
||||||
|
KV_CONFIG_HETERO_LAYOUT=''
|
||||||
|
fi
|
||||||
|
|
||||||
# Build the kv-transfer-config once
|
# Build the kv-transfer-config once
|
||||||
if [[ "$KV_BUFFER_DEVICE" == "cuda" ]]; then
|
if [[ "$KV_BUFFER_DEVICE" == "cuda" ]]; then
|
||||||
KV_CONFIG='{"kv_connector":"NixlConnector","kv_role":"kv_both"}'
|
KV_CONFIG='{"kv_connector":"NixlConnector","kv_role":"kv_both"'${KV_CONFIG_HETERO_LAYOUT}'}'
|
||||||
else
|
else
|
||||||
KV_CONFIG="{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"$KV_BUFFER_DEVICE\"}"
|
KV_CONFIG="{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"$KV_BUFFER_DEVICE\""${KV_CONFIG_HETERO_LAYOUT}"}"
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# Models to run
|
# Models to run
|
||||||
@ -117,6 +124,7 @@ run_tests_for_model() {
|
|||||||
|
|
||||||
# Build the command with or without model-specific args
|
# Build the command with or without model-specific args
|
||||||
BASE_CMD="CUDA_VISIBLE_DEVICES=$GPU_ID \
|
BASE_CMD="CUDA_VISIBLE_DEVICES=$GPU_ID \
|
||||||
|
VLLM_KV_CACHE_LAYOUT='HND' \
|
||||||
UCX_NET_DEVICES=all \
|
UCX_NET_DEVICES=all \
|
||||||
VLLM_NIXL_SIDE_CHANNEL_PORT=$SIDE_CHANNEL_PORT \
|
VLLM_NIXL_SIDE_CHANNEL_PORT=$SIDE_CHANNEL_PORT \
|
||||||
vllm serve $model_name \
|
vllm serve $model_name \
|
||||||
@ -157,6 +165,7 @@ run_tests_for_model() {
|
|||||||
|
|
||||||
# Build the command with or without model-specific args
|
# Build the command with or without model-specific args
|
||||||
BASE_CMD="CUDA_VISIBLE_DEVICES=$GPU_ID \
|
BASE_CMD="CUDA_VISIBLE_DEVICES=$GPU_ID \
|
||||||
|
VLLM_KV_CACHE_LAYOUT=$DECODER_KV_LAYOUT \
|
||||||
UCX_NET_DEVICES=all \
|
UCX_NET_DEVICES=all \
|
||||||
VLLM_NIXL_SIDE_CHANNEL_PORT=$SIDE_CHANNEL_PORT \
|
VLLM_NIXL_SIDE_CHANNEL_PORT=$SIDE_CHANNEL_PORT \
|
||||||
vllm serve $model_name \
|
vllm serve $model_name \
|
||||||
|
|||||||
@ -286,9 +286,12 @@ def test_prompt_less_than_block_size():
|
|||||||
class FakeNixlConnectorWorker(NixlConnectorWorker):
|
class FakeNixlConnectorWorker(NixlConnectorWorker):
|
||||||
REMOTE_ENGINE_ID = "remote_engine"
|
REMOTE_ENGINE_ID = "remote_engine"
|
||||||
|
|
||||||
def __init__(self, *args, hand_shake_latency: float = 1.8, **kwargs):
|
def __init__(
|
||||||
|
self, *args, hand_shake_latency: float = 1.8, kv_cache_layout="HND", **kwargs
|
||||||
|
):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self._hand_shake_latency = hand_shake_latency
|
self._hand_shake_latency = hand_shake_latency
|
||||||
|
self.kv_cache_layout = kv_cache_layout
|
||||||
|
|
||||||
def _nixl_handshake(
|
def _nixl_handshake(
|
||||||
self, host: str, port: int, remote_tp_size: int, expected_engine_id: str
|
self, host: str, port: int, remote_tp_size: int, expected_engine_id: str
|
||||||
@ -564,10 +567,63 @@ class TestNixlHandshake:
|
|||||||
|
|
||||||
# We don't check layout for homogeneous TP and MLA for now, as the
|
# We don't check layout for homogeneous TP and MLA for now, as the
|
||||||
# whole block is moved.
|
# whole block is moved.
|
||||||
worker.add_remote_agent(meta, remote_tp_size=2)
|
with pytest.raises(RuntimeError):
|
||||||
|
# mismatched layout is expected to fail
|
||||||
|
worker.add_remote_agent(meta, remote_tp_size=2)
|
||||||
with pytest.raises(AssertionError):
|
with pytest.raises(AssertionError):
|
||||||
worker.add_remote_agent(meta, remote_tp_size=1)
|
worker.add_remote_agent(meta, remote_tp_size=1)
|
||||||
|
|
||||||
|
@patch(
|
||||||
|
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
|
||||||
|
FakeNixlWrapper,
|
||||||
|
)
|
||||||
|
def test_handshake_succeed_on_kv_cache_layout_mismatch_with_experimental(
|
||||||
|
self, dist_init
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Verify that adding a remote agent fails if kv_cache_layout differs.
|
||||||
|
This test is only relevant for heterogeneous TP.
|
||||||
|
"""
|
||||||
|
vllm_config = create_vllm_config(enable_permute_local_kv=True)
|
||||||
|
|
||||||
|
# Mock TP world size to 2 to force heterogeneous TP when
|
||||||
|
# remote_tp_size=1
|
||||||
|
with patch(
|
||||||
|
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.get_tensor_model_parallel_world_size", # noqa: E501
|
||||||
|
return_value=2,
|
||||||
|
):
|
||||||
|
# Initialize connector and worker (with fake NIXL wrapper)
|
||||||
|
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
|
||||||
|
connector.connector_worker = FakeNixlConnectorWorker(
|
||||||
|
vllm_config,
|
||||||
|
connector.engine_id,
|
||||||
|
hand_shake_latency=0,
|
||||||
|
kv_cache_layout="NHD",
|
||||||
|
)
|
||||||
|
worker = connector.connector_worker
|
||||||
|
|
||||||
|
# Minimal local registration params used by add_remote_agent
|
||||||
|
worker.slot_size_per_layer = [2048]
|
||||||
|
worker.block_len_per_layer = [2048 * worker.block_size]
|
||||||
|
worker.num_blocks = 1
|
||||||
|
worker.dst_num_blocks[worker.engine_id] = worker.num_blocks
|
||||||
|
|
||||||
|
# Metadata with different kv_cache_layout than local worker
|
||||||
|
meta = NixlAgentMetadata(
|
||||||
|
engine_id=FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
|
||||||
|
agent_metadata=FakeNixlWrapper.AGENT_METADATA,
|
||||||
|
kv_caches_base_addr=[0],
|
||||||
|
num_blocks=1,
|
||||||
|
# prefill TP=1, decode TP=2, remote block_lens is double to local
|
||||||
|
block_lens=[i * 2 for i in worker.block_len_per_layer],
|
||||||
|
attn_backend_name=worker.backend_name,
|
||||||
|
kv_cache_layout="HND",
|
||||||
|
)
|
||||||
|
|
||||||
|
# We don't check layout for homogeneous TP and MLA for now, as the
|
||||||
|
# whole block is moved.
|
||||||
|
worker.add_remote_agent(meta, remote_tp_size=1)
|
||||||
|
|
||||||
|
|
||||||
# NOTE: resource cleanup in mp backend is a bit finicky, so the order in which
|
# NOTE: resource cleanup in mp backend is a bit finicky, so the order in which
|
||||||
# we put here is important. First run ray, it will clean up the resources, then
|
# we put here is important. First run ray, it will clean up the resources, then
|
||||||
|
|||||||
@ -83,6 +83,7 @@ def create_vllm_config(
|
|||||||
block_size: int = 16,
|
block_size: int = 16,
|
||||||
max_model_len: int = 10000,
|
max_model_len: int = 10000,
|
||||||
enable_chunked_prefill: bool = True,
|
enable_chunked_prefill: bool = True,
|
||||||
|
enable_permute_local_kv: bool = False,
|
||||||
) -> VllmConfig:
|
) -> VllmConfig:
|
||||||
"""Initialize VllmConfig For Testing."""
|
"""Initialize VllmConfig For Testing."""
|
||||||
scheduler_config = SchedulerConfig(
|
scheduler_config = SchedulerConfig(
|
||||||
@ -108,6 +109,7 @@ def create_vllm_config(
|
|||||||
kv_transfer_config = KVTransferConfig(
|
kv_transfer_config = KVTransferConfig(
|
||||||
kv_connector="NixlConnector",
|
kv_connector="NixlConnector",
|
||||||
kv_role="kv_both",
|
kv_role="kv_both",
|
||||||
|
enable_permute_local_kv=enable_permute_local_kv,
|
||||||
)
|
)
|
||||||
return VllmConfig(
|
return VllmConfig(
|
||||||
scheduler_config=scheduler_config,
|
scheduler_config=scheduler_config,
|
||||||
|
|||||||
@ -61,6 +61,9 @@ class KVTransferConfig:
|
|||||||
"""The Python module path to dynamically load the KV connector from.
|
"""The Python module path to dynamically load the KV connector from.
|
||||||
Only supported in V1."""
|
Only supported in V1."""
|
||||||
|
|
||||||
|
enable_permute_local_kv: bool = False
|
||||||
|
"""Experiment feature flag to enable HND to NHD KV Transfer"""
|
||||||
|
|
||||||
def compute_hash(self) -> str:
|
def compute_hash(self) -> str:
|
||||||
"""
|
"""
|
||||||
WARNING: Whenever a new field is added to this config,
|
WARNING: Whenever a new field is added to this config,
|
||||||
|
|||||||
@ -563,6 +563,7 @@ class NixlConnectorWorker:
|
|||||||
self.world_size = get_tensor_model_parallel_world_size()
|
self.world_size = get_tensor_model_parallel_world_size()
|
||||||
self.tp_group = get_tp_group()
|
self.tp_group = get_tp_group()
|
||||||
self.num_blocks = 0
|
self.num_blocks = 0
|
||||||
|
self.enable_permute_local_kv = False
|
||||||
|
|
||||||
# KV Caches and nixl tracking data.
|
# KV Caches and nixl tracking data.
|
||||||
self.device_type = current_platform.device_type
|
self.device_type = current_platform.device_type
|
||||||
@ -1094,6 +1095,23 @@ class NixlConnectorWorker:
|
|||||||
is_kv_replicated = self._tp_size[engine_id] // total_num_kv_heads >= 1
|
is_kv_replicated = self._tp_size[engine_id] // total_num_kv_heads >= 1
|
||||||
|
|
||||||
remote_block_len = nixl_agent_meta.block_lens[0]
|
remote_block_len = nixl_agent_meta.block_lens[0]
|
||||||
|
if nixl_agent_meta.kv_cache_layout != self.kv_cache_layout:
|
||||||
|
if (
|
||||||
|
self.vllm_config.kv_transfer_config is not None
|
||||||
|
and self.vllm_config.kv_transfer_config.enable_permute_local_kv
|
||||||
|
and nixl_agent_meta.kv_cache_layout == "HND"
|
||||||
|
):
|
||||||
|
logger.info(
|
||||||
|
"Remote is HND and local is NHD, enabled additional permute "
|
||||||
|
"on local device KV."
|
||||||
|
)
|
||||||
|
self.enable_permute_local_kv = True
|
||||||
|
else:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Heterogeneous TP expects same kv_cache_layout. "
|
||||||
|
"Or enable experimental feature to use HND to NHD support by "
|
||||||
|
"setting 'enable_permute_local_kv'=True in --kv-transfer-config."
|
||||||
|
)
|
||||||
if self.use_mla or is_kv_replicated:
|
if self.use_mla or is_kv_replicated:
|
||||||
# With replicated KV cache, only the number of blocks can differ.
|
# With replicated KV cache, only the number of blocks can differ.
|
||||||
assert self.block_len_per_layer == nixl_agent_meta.block_lens, (
|
assert self.block_len_per_layer == nixl_agent_meta.block_lens, (
|
||||||
@ -1114,7 +1132,10 @@ class NixlConnectorWorker:
|
|||||||
remote_block_size //= 2
|
remote_block_size //= 2
|
||||||
if tp_ratio > 1:
|
if tp_ratio > 1:
|
||||||
# Heterogeneous TP expects same kv_cache_layout.
|
# Heterogeneous TP expects same kv_cache_layout.
|
||||||
assert nixl_agent_meta.kv_cache_layout == self.kv_cache_layout
|
if nixl_agent_meta.kv_cache_layout == "NHD":
|
||||||
|
raise ValueError(
|
||||||
|
"Heterogeneous TP is not supported for remote with NHD."
|
||||||
|
)
|
||||||
if self.device_type == "xpu":
|
if self.device_type == "xpu":
|
||||||
raise ValueError("Heterogeneous TP is not supported on XPU")
|
raise ValueError("Heterogeneous TP is not supported on XPU")
|
||||||
|
|
||||||
@ -1226,6 +1247,41 @@ class NixlConnectorWorker:
|
|||||||
"d2h",
|
"d2h",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def permute_device_kv(self, block_ids: list[int]):
|
||||||
|
"""Transforms the layout of received KV cache blocks to the local format.
|
||||||
|
|
||||||
|
This method corrects layout mismatches from direct memory copies by
|
||||||
|
permuting the tensor dimensions.
|
||||||
|
|
||||||
|
- **Source Layout:** `[num_blocks, n_kv_head, block_size, head_dim]`
|
||||||
|
- **Target Layout:** `[num_blocks, block_size, n_kv_head, head_dim]`
|
||||||
|
|
||||||
|
Args:
|
||||||
|
block_ids: A list of block IDs to update and permute.
|
||||||
|
|
||||||
|
Implementation:
|
||||||
|
- x = blocks_to_update.reshape(src_shape) # view local kv with sender layout
|
||||||
|
- permuted_blocks = x.permute(*inv_order) # transpose n_kv_heads, block_size
|
||||||
|
- cache.index_copy_(0, indices, permuted_blocks) # copy permuted kv back
|
||||||
|
|
||||||
|
"""
|
||||||
|
split_k_and_v = not (self.use_mla or self._use_pallas or self._use_flashinfer)
|
||||||
|
inv_order = [0, 2, 1, 3]
|
||||||
|
sample_cache = list(self.device_kv_caches.values())[0][0]
|
||||||
|
target_shape = list(sample_cache.shape)
|
||||||
|
target_shape[0] = -1
|
||||||
|
src_shape = tuple(target_shape[i] for i in inv_order)
|
||||||
|
indices = torch.tensor(block_ids, device=sample_cache.device)
|
||||||
|
|
||||||
|
for _, cache_or_caches in self.device_kv_caches.items():
|
||||||
|
cache_list = cache_or_caches if split_k_and_v else [cache_or_caches]
|
||||||
|
for cache in cache_list:
|
||||||
|
blocks_to_update = cache.index_select(0, indices)
|
||||||
|
permuted_blocks = blocks_to_update.reshape(src_shape).permute(
|
||||||
|
*inv_order
|
||||||
|
)
|
||||||
|
cache.index_copy_(0, indices, permuted_blocks)
|
||||||
|
|
||||||
def get_finished(self) -> tuple[set[str], set[str]]:
|
def get_finished(self) -> tuple[set[str], set[str]]:
|
||||||
"""
|
"""
|
||||||
Get requests that are done sending or recving on this specific worker.
|
Get requests that are done sending or recving on this specific worker.
|
||||||
@ -1273,6 +1329,15 @@ class NixlConnectorWorker:
|
|||||||
del self._reqs_to_send[req_id]
|
del self._reqs_to_send[req_id]
|
||||||
done_sending.add(req_id)
|
done_sending.add(req_id)
|
||||||
|
|
||||||
|
if self.enable_permute_local_kv and len(done_recving) > 0:
|
||||||
|
block_ids = []
|
||||||
|
for req_id in done_recving:
|
||||||
|
meta = self._recving_metadata.pop(req_id)
|
||||||
|
assert meta, f"{req_id} not found in recving_metadata list"
|
||||||
|
block_ids += meta.local_block_ids
|
||||||
|
|
||||||
|
self.permute_device_kv(block_ids)
|
||||||
|
|
||||||
return done_sending, done_recving
|
return done_sending, done_recving
|
||||||
|
|
||||||
def _get_new_notifs(self) -> set[str]:
|
def _get_new_notifs(self) -> set[str]:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user