mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-27 12:57:10 +08:00
Merge remote-tracking branch 'test/nixl-ptp-gt-dtp' into woosuk/router-nixl
This commit is contained in:
commit
9decb2a5b1
@ -34,15 +34,21 @@ else
|
|||||||
fi
|
fi
|
||||||
|
|
||||||
# Models to run
|
# Models to run
|
||||||
MODELS=(
|
MODEL_NAMES=${MODEL_NAMES:-}
|
||||||
"Qwen/Qwen3-0.6B"
|
if [[ -n "$MODEL_NAMES" ]]; then
|
||||||
)
|
MODELS=("$MODEL_NAMES")
|
||||||
|
else
|
||||||
|
MODELS=(
|
||||||
|
"Qwen/Qwen3-0.6B"
|
||||||
|
)
|
||||||
|
fi
|
||||||
|
|
||||||
# Number of prefill and decode instances to create
|
# Number of prefill and decode instances to create
|
||||||
NUM_PREFILL_INSTANCES=${NUM_PREFILL_INSTANCES:-1} # Default to 1
|
NUM_PREFILL_INSTANCES=${NUM_PREFILL_INSTANCES:-1} # Default to 1
|
||||||
NUM_DECODE_INSTANCES=${NUM_DECODE_INSTANCES:-1} # Default to 1
|
NUM_DECODE_INSTANCES=${NUM_DECODE_INSTANCES:-1} # Default to 1
|
||||||
PREFILLER_TP_SIZE=${PREFILLER_TP_SIZE:-1}
|
PREFILLER_TP_SIZE=${PREFILLER_TP_SIZE:-1}
|
||||||
DECODER_TP_SIZE=${DECODER_TP_SIZE:-1}
|
DECODER_TP_SIZE=${DECODER_TP_SIZE:-1}
|
||||||
|
GPU_MEMORY_UTILIZATION=${GPU_MEMORY_UTILIZATION:-0.2}
|
||||||
|
|
||||||
# Find the git repository root directory
|
# Find the git repository root directory
|
||||||
GIT_ROOT=$(git rev-parse --show-toplevel)
|
GIT_ROOT=$(git rev-parse --show-toplevel)
|
||||||
@ -130,7 +136,7 @@ run_tests_for_model() {
|
|||||||
vllm serve $model_name \
|
vllm serve $model_name \
|
||||||
--port $PORT \
|
--port $PORT \
|
||||||
--enforce-eager \
|
--enforce-eager \
|
||||||
--gpu-memory-utilization 0.2 \
|
--gpu-memory-utilization $GPU_MEMORY_UTILIZATION \
|
||||||
--tensor-parallel-size $PREFILLER_TP_SIZE \
|
--tensor-parallel-size $PREFILLER_TP_SIZE \
|
||||||
--kv-transfer-config '$KV_CONFIG'"
|
--kv-transfer-config '$KV_CONFIG'"
|
||||||
|
|
||||||
@ -171,7 +177,7 @@ run_tests_for_model() {
|
|||||||
vllm serve $model_name \
|
vllm serve $model_name \
|
||||||
--port $PORT \
|
--port $PORT \
|
||||||
--enforce-eager \
|
--enforce-eager \
|
||||||
--gpu-memory-utilization 0.2 \
|
--gpu-memory-utilization $GPU_MEMORY_UTILIZATION \
|
||||||
--tensor-parallel-size $DECODER_TP_SIZE \
|
--tensor-parallel-size $DECODER_TP_SIZE \
|
||||||
--kv-transfer-config '$KV_CONFIG'"
|
--kv-transfer-config '$KV_CONFIG'"
|
||||||
|
|
||||||
|
|||||||
@ -12,7 +12,11 @@ FILTER = "exact_match,strict-match"
|
|||||||
RTOL = 0.03
|
RTOL = 0.03
|
||||||
|
|
||||||
# Model-specific expected values
|
# Model-specific expected values
|
||||||
EXPECTED_VALUES = {"Qwen/Qwen3-0.6B": 0.41, "deepseek-ai/deepseek-vl2-small": 0.59}
|
EXPECTED_VALUES = {
|
||||||
|
"Qwen/Qwen3-0.6B": 0.41,
|
||||||
|
"deepseek-ai/deepseek-vl2-small": 0.59,
|
||||||
|
"deepseek-ai/DeepSeek-V2-Lite-Chat": 0.65,
|
||||||
|
}
|
||||||
|
|
||||||
SIMPLE_PROMPT = (
|
SIMPLE_PROMPT = (
|
||||||
"The best part about working on vLLM is that I got to meet so many people across "
|
"The best part about working on vLLM is that I got to meet so many people across "
|
||||||
|
|||||||
43
tests/v1/kv_connector/nixl_integration/tp_config_sweep_accuracy_test.sh
Executable file
43
tests/v1/kv_connector/nixl_integration/tp_config_sweep_accuracy_test.sh
Executable file
@ -0,0 +1,43 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
set -euo pipefail
|
||||||
|
|
||||||
|
# Utility to run integration tests sequentially with varying TP configurations.
|
||||||
|
# If FLASHINFER is set, reruns all tests with VLLM_ATTENTION_BACKEND=FLASHINFER.
|
||||||
|
|
||||||
|
SCRIPT="tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh"
|
||||||
|
|
||||||
|
# Define test configurations
|
||||||
|
configs=(
|
||||||
|
"PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=2"
|
||||||
|
"PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2"
|
||||||
|
"PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=1"
|
||||||
|
"GPU_MEMORY_UTILIZATION=0.6 MODEL_NAMES=deepseek-ai/DeepSeek-V2-Lite-Chat" # MLA case
|
||||||
|
# TP greater than num heads
|
||||||
|
)
|
||||||
|
|
||||||
|
run_tests() {
|
||||||
|
local label=$1
|
||||||
|
local extra_env=$2
|
||||||
|
|
||||||
|
echo "=== Running tests (${label}) ==="
|
||||||
|
for cfg in "${configs[@]}"; do
|
||||||
|
echo "-> Running with ${cfg} ${extra_env:+and ${extra_env}}"
|
||||||
|
# Use 'env' to safely set variables without eval
|
||||||
|
if ! env ${extra_env} ${cfg} bash "${SCRIPT}"; then
|
||||||
|
echo "❌ Test failed for config: ${cfg} ${extra_env:+(${extra_env})}"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
echo "✅ All ${label} tests passed!"
|
||||||
|
}
|
||||||
|
|
||||||
|
# Run base tests
|
||||||
|
run_tests "default backend" ""
|
||||||
|
|
||||||
|
# Check if FLASHINFER is set (non-empty)
|
||||||
|
if [[ -n "${FLASHINFER:-}" ]]; then
|
||||||
|
echo "FLASHINFER is set, rerunning with VLLM_ATTENTION_BACKEND=FLASHINFER"
|
||||||
|
run_tests "FLASHINFER backend" "VLLM_ATTENTION_BACKEND=FLASHINFER"
|
||||||
|
else
|
||||||
|
echo "FLASHINFER not set, skipping FLASHINFER runs."
|
||||||
|
fi
|
||||||
@ -308,21 +308,42 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
|
|||||||
|
|
||||||
assert expected_engine_id == self.REMOTE_ENGINE_ID
|
assert expected_engine_id == self.REMOTE_ENGINE_ID
|
||||||
|
|
||||||
remote_agent_name = self.add_remote_agent(
|
# Adjust remote block length metadata to satisfy heterogeneous TP
|
||||||
NixlAgentMetadata(
|
# invariants enforced during handshake validation.
|
||||||
engine_id=self.REMOTE_ENGINE_ID,
|
remote_block_lens = list(self.block_len_per_layer)
|
||||||
agent_metadata=FakeNixlWrapper.AGENT_METADATA,
|
tp_ratio = self.kv_info.tp_ratio(remote_tp_size=remote_tp_size)
|
||||||
kv_caches_base_addr=[0],
|
if remote_tp_size > self.world_size:
|
||||||
num_blocks=1,
|
# P TP > D TP case, block_len of remote is smaller
|
||||||
block_lens=self.block_len_per_layer,
|
remote_block_lens = [
|
||||||
attn_backend_name=self.backend_name,
|
block_len // (-tp_ratio) for block_len in remote_block_lens
|
||||||
# `self.kv_cache_layout` is only forced to HND when vllm engine
|
]
|
||||||
# is started. We mock HND here.
|
elif remote_tp_size < self.world_size:
|
||||||
kv_cache_layout="HND",
|
remote_block_lens = [
|
||||||
),
|
block_len * tp_ratio for block_len in remote_block_lens
|
||||||
remote_tp_size=remote_tp_size,
|
]
|
||||||
)
|
|
||||||
return {0: remote_agent_name}
|
# When remote tp_size > local tp_size, handshake with multiple
|
||||||
|
# remote ranks.
|
||||||
|
num_hanshakes = 1 if tp_ratio > 0 else -tp_ratio
|
||||||
|
remote_agents: dict[int, str] = {}
|
||||||
|
for remote_tp_rank in range(num_hanshakes):
|
||||||
|
remote_agent_name = self.add_remote_agent(
|
||||||
|
NixlAgentMetadata(
|
||||||
|
engine_id=self.REMOTE_ENGINE_ID,
|
||||||
|
agent_metadata=FakeNixlWrapper.AGENT_METADATA,
|
||||||
|
kv_caches_base_addr=[0],
|
||||||
|
num_blocks=1,
|
||||||
|
block_lens=remote_block_lens,
|
||||||
|
attn_backend_name=self.backend_name,
|
||||||
|
# `self.kv_cache_layout` is only forced to HND when vllm engine
|
||||||
|
# is started. We mock HND here.
|
||||||
|
kv_cache_layout="HND",
|
||||||
|
),
|
||||||
|
remote_tp_rank=remote_tp_rank,
|
||||||
|
remote_tp_size=remote_tp_size,
|
||||||
|
)
|
||||||
|
remote_agents[remote_tp_rank] = remote_agent_name
|
||||||
|
return remote_agents
|
||||||
|
|
||||||
|
|
||||||
class TestNixlHandshake:
|
class TestNixlHandshake:
|
||||||
@ -353,7 +374,13 @@ class TestNixlHandshake:
|
|||||||
vllm_config, connector.engine_id, hand_shake_latency=0
|
vllm_config, connector.engine_id, hand_shake_latency=0
|
||||||
)
|
)
|
||||||
assert isinstance(connector.connector_worker.nixl_wrapper, FakeNixlWrapper)
|
assert isinstance(connector.connector_worker.nixl_wrapper, FakeNixlWrapper)
|
||||||
connector.connector_worker.nixl_wrapper.set_cycles_before_xfer_done(3)
|
worker = connector.connector_worker
|
||||||
|
worker.nixl_wrapper.set_cycles_before_xfer_done(3)
|
||||||
|
# simulate handshake
|
||||||
|
worker.dst_xfer_side_handles = {
|
||||||
|
FakeNixlConnectorWorker.REMOTE_ENGINE_ID: {0: 1}
|
||||||
|
}
|
||||||
|
worker.kv_cache_layout = "HND"
|
||||||
num_xfers = 4
|
num_xfers = 4
|
||||||
while True:
|
while True:
|
||||||
# For the same request_id, initiate multiple xfers across different
|
# For the same request_id, initiate multiple xfers across different
|
||||||
@ -465,6 +492,70 @@ class TestNixlHandshake:
|
|||||||
return
|
return
|
||||||
raise TimeoutError("Took too long to complete async handshake.")
|
raise TimeoutError("Took too long to complete async handshake.")
|
||||||
|
|
||||||
|
@patch(
|
||||||
|
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
|
||||||
|
FakeNixlWrapper,
|
||||||
|
)
|
||||||
|
@pytest.mark.parametrize("local_tp_size", [1, 2])
|
||||||
|
def test_prefill_tp_size_greater_than_decode_tp_size(
|
||||||
|
self, local_tp_size: int, dist_init
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Verify remote TP > local TP handshake succeeds with different
|
||||||
|
remote configurations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
vllm_config = create_vllm_config()
|
||||||
|
local_tp_size = 1
|
||||||
|
vllm_config.parallel_config.tensor_parallel_size = local_tp_size
|
||||||
|
|
||||||
|
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
|
||||||
|
connector.connector_worker = FakeNixlConnectorWorker(
|
||||||
|
vllm_config, connector.engine_id, hand_shake_latency=0
|
||||||
|
)
|
||||||
|
worker = connector.connector_worker
|
||||||
|
|
||||||
|
# Minimal local registration params used by add_remote_agent
|
||||||
|
worker.slot_size_per_layer = [4096]
|
||||||
|
worker.block_len_per_layer = [4096 * worker.block_size]
|
||||||
|
worker.num_blocks = 1
|
||||||
|
worker.dst_num_blocks[worker.engine_id] = worker.num_blocks
|
||||||
|
worker.src_blocks_data = [(0, worker.block_len_per_layer[0], worker.tp_rank)]
|
||||||
|
|
||||||
|
def check_handshake(remote_tp_size: int):
|
||||||
|
tp_ratio = remote_tp_size // local_tp_size
|
||||||
|
assert set(remote_agents.keys()) == set(range(tp_ratio))
|
||||||
|
|
||||||
|
remote_engine_id = worker.REMOTE_ENGINE_ID
|
||||||
|
assert worker._tp_size[remote_engine_id] == remote_tp_size
|
||||||
|
assert -tp_ratio == worker.kv_info.tp_ratio(remote_engine_id)
|
||||||
|
# ensure src_xfer_side_chunked_handles is populated with tpratio chunks
|
||||||
|
assert -tp_ratio in worker.src_xfer_side_chunked_handles
|
||||||
|
assert len(worker.src_xfer_side_chunked_handles[-tp_ratio]) == tp_ratio
|
||||||
|
assert remote_engine_id in worker.dst_xfer_side_handles
|
||||||
|
assert set(worker.dst_xfer_side_handles[remote_engine_id].keys()) == set(
|
||||||
|
range(tp_ratio)
|
||||||
|
)
|
||||||
|
|
||||||
|
remote_agents = worker._nixl_handshake(
|
||||||
|
host="localhost",
|
||||||
|
port=1234,
|
||||||
|
remote_tp_size=2,
|
||||||
|
expected_engine_id=worker.REMOTE_ENGINE_ID,
|
||||||
|
)
|
||||||
|
check_handshake(2)
|
||||||
|
|
||||||
|
# NOTE flexiblity: a second remote with higher number of ranks
|
||||||
|
# is discovered
|
||||||
|
worker.REMOTE_ENGINE_ID = "remote_engine_2"
|
||||||
|
remote_agents = worker._nixl_handshake(
|
||||||
|
host="localhost",
|
||||||
|
port=1234,
|
||||||
|
remote_tp_size=6,
|
||||||
|
expected_engine_id=worker.REMOTE_ENGINE_ID,
|
||||||
|
)
|
||||||
|
check_handshake(6)
|
||||||
|
|
||||||
@patch(
|
@patch(
|
||||||
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
|
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
|
||||||
FakeNixlWrapper,
|
FakeNixlWrapper,
|
||||||
@ -565,12 +656,9 @@ class TestNixlHandshake:
|
|||||||
kv_cache_layout=mismatched_layout,
|
kv_cache_layout=mismatched_layout,
|
||||||
)
|
)
|
||||||
|
|
||||||
# We don't check layout for homogeneous TP and MLA for now, as the
|
# Layout check done for both homogeneous and heterogeneous TP.
|
||||||
# whole block is moved.
|
|
||||||
with pytest.raises(RuntimeError):
|
|
||||||
# mismatched layout is expected to fail
|
|
||||||
worker.add_remote_agent(meta, remote_tp_size=2)
|
|
||||||
with pytest.raises(AssertionError):
|
with pytest.raises(AssertionError):
|
||||||
|
worker.add_remote_agent(meta, remote_tp_size=2)
|
||||||
worker.add_remote_agent(meta, remote_tp_size=1)
|
worker.add_remote_agent(meta, remote_tp_size=1)
|
||||||
|
|
||||||
@patch(
|
@patch(
|
||||||
@ -1180,7 +1268,8 @@ def test_shutdown_cleans_up_resources(dist_init):
|
|||||||
):
|
):
|
||||||
worker._recving_transfers = {"req1": [(123, time.perf_counter())]}
|
worker._recving_transfers = {"req1": [(123, time.perf_counter())]}
|
||||||
worker.src_xfer_side_handle = 456
|
worker.src_xfer_side_handle = 456
|
||||||
worker.dst_xfer_side_handles = {"engine1": 789}
|
worker.src_xfer_side_chunked_handles = {-2: [456]}
|
||||||
|
worker.dst_xfer_side_handles = {"engine1": {0: 789}}
|
||||||
worker._remote_agents = {"engine1": {0: "agent1"}}
|
worker._remote_agents = {"engine1": {0: "agent1"}}
|
||||||
worker._registered_descs = ["desc1", "desc2"]
|
worker._registered_descs = ["desc1", "desc2"]
|
||||||
|
|
||||||
@ -1194,7 +1283,7 @@ def test_shutdown_cleans_up_resources(dist_init):
|
|||||||
mock_listener.join.assert_called_once_with(timeout=0)
|
mock_listener.join.assert_called_once_with(timeout=0)
|
||||||
|
|
||||||
mock_rel_xfer.assert_called_once_with(123)
|
mock_rel_xfer.assert_called_once_with(123)
|
||||||
assert mock_rel_dlist.call_count == 2
|
assert mock_rel_dlist.call_count == 3
|
||||||
mock_rel_dlist.assert_any_call(456) # src handle
|
mock_rel_dlist.assert_any_call(456) # src handle
|
||||||
mock_rel_dlist.assert_any_call(789) # dst handle
|
mock_rel_dlist.assert_any_call(789) # dst handle
|
||||||
mock_rem_agent.assert_called_once_with("agent1")
|
mock_rem_agent.assert_called_once_with("agent1")
|
||||||
|
|||||||
@ -36,7 +36,6 @@ from vllm.distributed.parallel_state import (
|
|||||||
get_tensor_model_parallel_world_size,
|
get_tensor_model_parallel_world_size,
|
||||||
get_tp_group,
|
get_tp_group,
|
||||||
)
|
)
|
||||||
from vllm.distributed.utils import divide
|
|
||||||
from vllm.forward_context import ForwardContext
|
from vllm.forward_context import ForwardContext
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
@ -513,6 +512,88 @@ class NixlConnectorScheduler:
|
|||||||
class NixlConnectorWorker:
|
class NixlConnectorWorker:
|
||||||
"""Implementation of Worker side methods"""
|
"""Implementation of Worker side methods"""
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class KVInfo:
|
||||||
|
tp_size: int
|
||||||
|
tp_rank: int
|
||||||
|
remote_tp_size: dict[EngineId, int]
|
||||||
|
is_mla: bool
|
||||||
|
total_num_kv_heads: int
|
||||||
|
|
||||||
|
def tp_ratio(
|
||||||
|
self,
|
||||||
|
remote_engine_id: Optional[EngineId] = None,
|
||||||
|
remote_tp_size: Optional[int] = None,
|
||||||
|
) -> int:
|
||||||
|
"""
|
||||||
|
Calculate the tensor parallel ratio between local and remote TP.
|
||||||
|
We can think of it as the number of local TP workers-per-remote TP
|
||||||
|
workers. Local workers will read from the same remote TP worker in
|
||||||
|
groups of size `tp_ratio`. If remote tp_size > local tp_size, the
|
||||||
|
ratio is flipped (remote_size/local_size) and the returned value is
|
||||||
|
negative.
|
||||||
|
"""
|
||||||
|
if remote_tp_size is None:
|
||||||
|
assert remote_engine_id is not None
|
||||||
|
remote_tp_size = self.remote_tp_size[remote_engine_id]
|
||||||
|
if self.tp_size >= remote_tp_size:
|
||||||
|
assert self.tp_size % remote_tp_size == 0, (
|
||||||
|
f"Local tensor parallel size {self.tp_size} is not divisible "
|
||||||
|
f"by remote tensor parallel size {remote_tp_size}."
|
||||||
|
)
|
||||||
|
return self.tp_size // remote_tp_size
|
||||||
|
else:
|
||||||
|
assert remote_tp_size % self.tp_size == 0, (
|
||||||
|
f"Remote tensor parallel size {remote_tp_size} is not divisible "
|
||||||
|
f"by local tensor parallel size {self.tp_size}."
|
||||||
|
)
|
||||||
|
# P TP > D TP case, return the ratio as negative
|
||||||
|
return -remote_tp_size // self.tp_size
|
||||||
|
|
||||||
|
def is_kv_replicated(
|
||||||
|
self, engine_id: Optional[EngineId] = None, tp_size: Optional[int] = None
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Whether the KV cache is replicated across TP workers due to the
|
||||||
|
number of TP workers being greater than the number of KV heads.
|
||||||
|
"""
|
||||||
|
if tp_size is None:
|
||||||
|
assert engine_id is not None
|
||||||
|
tp_size = self.remote_tp_size[engine_id]
|
||||||
|
return tp_size // self.total_num_kv_heads >= 1
|
||||||
|
|
||||||
|
def replicates_kv_cache(
|
||||||
|
self,
|
||||||
|
remote_engine_id: Optional[EngineId] = None,
|
||||||
|
remote_tp_size: Optional[int] = None,
|
||||||
|
) -> bool:
|
||||||
|
# MLA is always replicated as the hidden dim can't be split.
|
||||||
|
return self.is_mla or self.is_kv_replicated(
|
||||||
|
remote_engine_id, remote_tp_size
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_target_remote_ranks(
|
||||||
|
self,
|
||||||
|
remote_engine_id: Optional[EngineId] = None,
|
||||||
|
remote_tp_size: Optional[int] = None,
|
||||||
|
) -> list[int]:
|
||||||
|
"""
|
||||||
|
Get the remote TP rank (on P) that the current local TP rank
|
||||||
|
(on D) will read from. When remote tp_size > local tp_size, we
|
||||||
|
read from multiple remote ranks.
|
||||||
|
"""
|
||||||
|
tp_ratio = self.tp_ratio(remote_engine_id, remote_tp_size)
|
||||||
|
if tp_ratio > 0:
|
||||||
|
return [self.tp_rank // tp_ratio]
|
||||||
|
else:
|
||||||
|
# P TP > D TP case, D reads from |tp_ratio| remote workers.
|
||||||
|
tp_ratio = -tp_ratio
|
||||||
|
if self.replicates_kv_cache(remote_engine_id, remote_tp_size):
|
||||||
|
# When cache is replicated on remote, we only need to read
|
||||||
|
# from one remote (they all have the same cache).
|
||||||
|
return [self.tp_rank * tp_ratio]
|
||||||
|
return [self.tp_rank * tp_ratio + i for i in range(tp_ratio)]
|
||||||
|
|
||||||
def __init__(self, vllm_config: VllmConfig, engine_id: str):
|
def __init__(self, vllm_config: VllmConfig, engine_id: str):
|
||||||
if NixlWrapper is None:
|
if NixlWrapper is None:
|
||||||
logger.error("NIXL is not available")
|
logger.error("NIXL is not available")
|
||||||
@ -601,8 +682,10 @@ class NixlConnectorWorker:
|
|||||||
self.copy_blocks: CopyBlocksOp | None = None
|
self.copy_blocks: CopyBlocksOp | None = None
|
||||||
|
|
||||||
# Map of engine_id -> kv_caches_base_addr. For TP case, each local
|
# Map of engine_id -> kv_caches_base_addr. For TP case, each local
|
||||||
# rank will still only pull from a single remote TP worker.
|
# rank may pull from multiple remote TP workers.
|
||||||
self.kv_caches_base_addr: dict[EngineId, list[int]] = {}
|
self.kv_caches_base_addr: defaultdict[EngineId, dict[int, list[int]]] = (
|
||||||
|
defaultdict(dict)
|
||||||
|
)
|
||||||
|
|
||||||
# Number of NIXL regions. Currently one region per cache
|
# Number of NIXL regions. Currently one region per cache
|
||||||
# (so 1 per layer for MLA, otherwise 2 per layer)
|
# (so 1 per layer for MLA, otherwise 2 per layer)
|
||||||
@ -611,8 +694,13 @@ class NixlConnectorWorker:
|
|||||||
|
|
||||||
# nixl_prepped_dlist_handle.
|
# nixl_prepped_dlist_handle.
|
||||||
self.src_xfer_side_handle: int = 0
|
self.src_xfer_side_handle: int = 0
|
||||||
|
# Populated dynamically during handshake based on remote configuration.
|
||||||
|
# Keep track of regions at different tp_ratio values. tp_ratio->handles
|
||||||
|
self.src_xfer_side_chunked_handles: dict[int, list[int]] = {}
|
||||||
# Map of engine_id -> nixl_prepped_dlist_handle (int)].
|
# Map of engine_id -> nixl_prepped_dlist_handle (int)].
|
||||||
self.dst_xfer_side_handles: dict[EngineId, int] = {}
|
self.dst_xfer_side_handles: defaultdict[EngineId, dict[int, int]] = defaultdict(
|
||||||
|
dict
|
||||||
|
)
|
||||||
|
|
||||||
# Map of engine_id -> num_blocks. All ranks in the same deployment will
|
# Map of engine_id -> num_blocks. All ranks in the same deployment will
|
||||||
# have the same number of blocks.
|
# have the same number of blocks.
|
||||||
@ -646,7 +734,6 @@ class NixlConnectorWorker:
|
|||||||
# Protects _handshake_futures and _remote_agents.
|
# Protects _handshake_futures and _remote_agents.
|
||||||
self._handshake_lock = threading.RLock()
|
self._handshake_lock = threading.RLock()
|
||||||
|
|
||||||
self.vllm_config = vllm_config
|
|
||||||
self.block_size = vllm_config.cache_config.block_size
|
self.block_size = vllm_config.cache_config.block_size
|
||||||
self.model_config = vllm_config.model_config
|
self.model_config = vllm_config.model_config
|
||||||
self.cache_config = vllm_config.cache_config
|
self.cache_config = vllm_config.cache_config
|
||||||
@ -678,6 +765,14 @@ class NixlConnectorWorker:
|
|||||||
self.consumer_notification_counts_by_req = defaultdict[ReqId, int](int)
|
self.consumer_notification_counts_by_req = defaultdict[ReqId, int](int)
|
||||||
self.xfer_stats = NixlKVConnectorStats()
|
self.xfer_stats = NixlKVConnectorStats()
|
||||||
|
|
||||||
|
self.kv_info = self.KVInfo(
|
||||||
|
tp_size=self.world_size,
|
||||||
|
tp_rank=self.tp_rank,
|
||||||
|
remote_tp_size=self._tp_size, # shared state
|
||||||
|
is_mla=self.use_mla,
|
||||||
|
total_num_kv_heads=self.model_config.get_total_num_kv_heads(),
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _nixl_handshake_listener(
|
def _nixl_handshake_listener(
|
||||||
metadata: NixlAgentMetadata,
|
metadata: NixlAgentMetadata,
|
||||||
@ -717,52 +812,53 @@ class NixlConnectorWorker:
|
|||||||
|
|
||||||
start_time = time.perf_counter()
|
start_time = time.perf_counter()
|
||||||
|
|
||||||
# NOTE(rob): we need each rank to have a unique port. This is
|
# When target instance TP > local TP, we need to perform multiple
|
||||||
# a hack to keep us moving. We will switch when moving to etcd
|
# handshakes. Do it in a single background job for simplicity.
|
||||||
# or where we have a single ZMQ socket in the scheduler.
|
# Regardless, only handshake with the remote TP rank(s) that current
|
||||||
|
# local rank will read from. Note that With homogeneous TP,
|
||||||
# Handshake only with the remote TP rank that current local rank will
|
# this happens to be the same single rank_i.
|
||||||
# pull from. With homogeneous TP it happens to be the same rank_i.
|
p_remote_ranks = self.kv_info.get_target_remote_ranks(
|
||||||
tp_ratio = self._tp_size[self.engine_id] // remote_tp_size
|
remote_tp_size=remote_tp_size
|
||||||
p_remote_rank = self.tp_rank // tp_ratio
|
|
||||||
path = make_zmq_path("tcp", host, port + p_remote_rank)
|
|
||||||
logger.debug(
|
|
||||||
"Querying metadata on path: %s at remote rank %s", path, p_remote_rank
|
|
||||||
)
|
)
|
||||||
|
remote_rank_to_agent_name = {}
|
||||||
# Send query for the request.
|
for remote_rank in p_remote_ranks:
|
||||||
with zmq_ctx(zmq.REQ, path) as sock:
|
path = make_zmq_path("tcp", host, port + remote_rank)
|
||||||
# Set receive timeout to 5 seconds to avoid hanging on dead server
|
logger.warning(
|
||||||
sock.setsockopt(zmq.RCVTIMEO, 5000) # milliseconds
|
"Querying metadata on path: %s at remote rank %s", path, remote_rank
|
||||||
sock.send(GET_META_MSG)
|
|
||||||
metadata_bytes = sock.recv()
|
|
||||||
decoder = msgspec.msgpack.Decoder(NixlAgentMetadata)
|
|
||||||
metadata = decoder.decode(metadata_bytes)
|
|
||||||
got_metadata_time = time.perf_counter()
|
|
||||||
logger.debug(
|
|
||||||
"NIXL handshake: get metadata took: %s", got_metadata_time - start_time
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Ensure engine id matches.
|
# Send query for the request.
|
||||||
if metadata.engine_id != expected_engine_id:
|
with zmq_ctx(zmq.REQ, path) as sock:
|
||||||
raise RuntimeError(
|
# Set receive timeout to 5 seconds to avoid hanging on dead server
|
||||||
f"Remote NIXL agent engine ID mismatch. "
|
sock.setsockopt(zmq.RCVTIMEO, 5000) # milliseconds
|
||||||
f"Expected {expected_engine_id},"
|
sock.send(GET_META_MSG)
|
||||||
f"received {metadata.engine_id}."
|
metadata_bytes = sock.recv()
|
||||||
|
decoder = msgspec.msgpack.Decoder(NixlAgentMetadata)
|
||||||
|
metadata = decoder.decode(metadata_bytes)
|
||||||
|
got_metadata_time = time.perf_counter()
|
||||||
|
logger.debug(
|
||||||
|
"NIXL handshake: get metadata took: %s", got_metadata_time - start_time
|
||||||
)
|
)
|
||||||
|
|
||||||
# Register Remote agent.
|
# Ensure engine id matches.
|
||||||
remote_agent_name = self.add_remote_agent(
|
if metadata.engine_id != expected_engine_id:
|
||||||
metadata, p_remote_rank, remote_tp_size
|
raise RuntimeError(
|
||||||
)
|
f"Remote NIXL agent engine ID mismatch. "
|
||||||
setup_agent_time = time.perf_counter()
|
f"Expected {expected_engine_id},"
|
||||||
logger.debug(
|
f"received {metadata.engine_id}."
|
||||||
"NIXL handshake: add agent took: %s",
|
)
|
||||||
setup_agent_time - got_metadata_time,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Remote rank -> agent name.
|
# Register Remote agent.
|
||||||
return {p_remote_rank: remote_agent_name}
|
remote_agent_name = self.add_remote_agent(
|
||||||
|
metadata, remote_rank, remote_tp_size
|
||||||
|
)
|
||||||
|
setup_agent_time = time.perf_counter()
|
||||||
|
logger.debug(
|
||||||
|
"NIXL handshake: add agent took: %s",
|
||||||
|
setup_agent_time - got_metadata_time,
|
||||||
|
)
|
||||||
|
remote_rank_to_agent_name[remote_rank] = remote_agent_name
|
||||||
|
return remote_rank_to_agent_name
|
||||||
|
|
||||||
def initialize_host_xfer_buffer(self, kv_caches: dict[str, torch.Tensor]) -> None:
|
def initialize_host_xfer_buffer(self, kv_caches: dict[str, torch.Tensor]) -> None:
|
||||||
"""
|
"""
|
||||||
@ -916,7 +1012,7 @@ class NixlConnectorWorker:
|
|||||||
assert len(self.block_len_per_layer) == len(seen_base_addresses)
|
assert len(self.block_len_per_layer) == len(seen_base_addresses)
|
||||||
assert self.num_blocks != 0
|
assert self.num_blocks != 0
|
||||||
|
|
||||||
self.kv_caches_base_addr[self.engine_id] = seen_base_addresses
|
self.kv_caches_base_addr[self.engine_id][self.tp_rank] = seen_base_addresses
|
||||||
self.num_regions = len(caches_data)
|
self.num_regions = len(caches_data)
|
||||||
self.num_layers = len(xfer_buffers.keys())
|
self.num_layers = len(xfer_buffers.keys())
|
||||||
|
|
||||||
@ -942,7 +1038,7 @@ class NixlConnectorWorker:
|
|||||||
self.num_regions *= 2
|
self.num_regions *= 2
|
||||||
|
|
||||||
# Register local/src descr for NIXL xfer.
|
# Register local/src descr for NIXL xfer.
|
||||||
blocks_data = []
|
self.src_blocks_data = []
|
||||||
for i, base_addr in enumerate(seen_base_addresses):
|
for i, base_addr in enumerate(seen_base_addresses):
|
||||||
kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=i)
|
kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=i)
|
||||||
# NOTE With heter-TP, more blocks are prepared than what are
|
# NOTE With heter-TP, more blocks are prepared than what are
|
||||||
@ -954,7 +1050,7 @@ class NixlConnectorWorker:
|
|||||||
block_offset = block_id * self.block_len_per_layer[i]
|
block_offset = block_id * self.block_len_per_layer[i]
|
||||||
addr = base_addr + block_offset
|
addr = base_addr + block_offset
|
||||||
# (addr, len, device id)
|
# (addr, len, device id)
|
||||||
blocks_data.append((addr, kv_block_len, self.tp_rank))
|
self.src_blocks_data.append((addr, kv_block_len, self.tp_rank))
|
||||||
|
|
||||||
if self._use_flashinfer:
|
if self._use_flashinfer:
|
||||||
# Separate and interleave K/V regions to maintain the same
|
# Separate and interleave K/V regions to maintain the same
|
||||||
@ -965,15 +1061,17 @@ class NixlConnectorWorker:
|
|||||||
addr = base_addr + block_offset
|
addr = base_addr + block_offset
|
||||||
# Register addresses for V cache (K registered first).
|
# Register addresses for V cache (K registered first).
|
||||||
v_addr = addr + kv_block_len
|
v_addr = addr + kv_block_len
|
||||||
blocks_data.append((v_addr, kv_block_len, self.tp_rank))
|
self.src_blocks_data.append((v_addr, kv_block_len, self.tp_rank))
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"Created %s blocks for src engine %s and rank %s",
|
"Created %s blocks for src engine %s and rank %s",
|
||||||
len(blocks_data),
|
len(self.src_blocks_data),
|
||||||
self.engine_id,
|
self.engine_id,
|
||||||
self.tp_rank,
|
self.tp_rank,
|
||||||
)
|
)
|
||||||
|
|
||||||
descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type)
|
descs = self.nixl_wrapper.get_xfer_descs(
|
||||||
|
self.src_blocks_data, self.nixl_memory_type
|
||||||
|
)
|
||||||
# NIXL_INIT_AGENT to be used for preparations of local descs.
|
# NIXL_INIT_AGENT to be used for preparations of local descs.
|
||||||
self.src_xfer_side_handle = self.nixl_wrapper.prep_xfer_dlist(
|
self.src_xfer_side_handle = self.nixl_wrapper.prep_xfer_dlist(
|
||||||
"NIXL_INIT_AGENT", descs
|
"NIXL_INIT_AGENT", descs
|
||||||
@ -981,13 +1079,11 @@ class NixlConnectorWorker:
|
|||||||
|
|
||||||
# TODO(mgoin): Hybrid memory allocator is currently disabled for
|
# TODO(mgoin): Hybrid memory allocator is currently disabled for
|
||||||
# models with local attention (Llama 4). Can remove this once enabled.
|
# models with local attention (Llama 4). Can remove this once enabled.
|
||||||
if self.vllm_config.model_config.hf_config.model_type == "llama4":
|
if self.model_config.hf_config.model_type == "llama4":
|
||||||
from transformers import Llama4TextConfig
|
from transformers import Llama4TextConfig
|
||||||
|
|
||||||
assert isinstance(
|
assert isinstance(self.model_config.hf_text_config, Llama4TextConfig)
|
||||||
self.vllm_config.model_config.hf_text_config, Llama4TextConfig
|
llama4_config = self.model_config.hf_text_config
|
||||||
)
|
|
||||||
llama4_config = self.vllm_config.model_config.hf_text_config
|
|
||||||
no_rope_layers = llama4_config.no_rope_layers
|
no_rope_layers = llama4_config.no_rope_layers
|
||||||
chunk_size = llama4_config.attention_chunk_size
|
chunk_size = llama4_config.attention_chunk_size
|
||||||
chunk_block_size = math.ceil(chunk_size / self.block_size)
|
chunk_block_size = math.ceil(chunk_size / self.block_size)
|
||||||
@ -1007,7 +1103,7 @@ class NixlConnectorWorker:
|
|||||||
metadata = NixlAgentMetadata(
|
metadata = NixlAgentMetadata(
|
||||||
engine_id=self.engine_id,
|
engine_id=self.engine_id,
|
||||||
agent_metadata=self.nixl_wrapper.get_agent_metadata(),
|
agent_metadata=self.nixl_wrapper.get_agent_metadata(),
|
||||||
kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id],
|
kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id][self.tp_rank],
|
||||||
num_blocks=self.num_blocks,
|
num_blocks=self.num_blocks,
|
||||||
block_lens=self.block_len_per_layer,
|
block_lens=self.block_len_per_layer,
|
||||||
attn_backend_name=self.backend_name,
|
attn_backend_name=self.backend_name,
|
||||||
@ -1035,10 +1131,12 @@ class NixlConnectorWorker:
|
|||||||
|
|
||||||
In particular, handle both homogeneous and heterogeneous TP. The former
|
In particular, handle both homogeneous and heterogeneous TP. The former
|
||||||
requires local rank_i to read from remote rank_i.
|
requires local rank_i to read from remote rank_i.
|
||||||
The latter, assuming D.world_size > P.world_size, requires that two or
|
The latter, in the case of D.world_size < P.world_size, requires that a
|
||||||
more local TP worker share the xfer from a single TP worker.
|
local (D) TP worker reads from multiple remote (P) TP workers.
|
||||||
|
Conversely, assuming D.world_size > P.world_size, two or more local TP
|
||||||
|
workers will read from a single remote TP worker.
|
||||||
|
|
||||||
Here's an example (non-MLA case):
|
Here's an example for the last case described above (non-MLA):
|
||||||
|
|
||||||
rank_offset p_remote_tp_rank
|
rank_offset p_remote_tp_rank
|
||||||
(kv split no)
|
(kv split no)
|
||||||
@ -1070,107 +1168,91 @@ class NixlConnectorWorker:
|
|||||||
engine_id = nixl_agent_meta.engine_id
|
engine_id = nixl_agent_meta.engine_id
|
||||||
# TODO re-evaluate refreshing for scaling/recovery
|
# TODO re-evaluate refreshing for scaling/recovery
|
||||||
if remote_tp_rank in self._remote_agents.get(engine_id, {}):
|
if remote_tp_rank in self._remote_agents.get(engine_id, {}):
|
||||||
|
logger.warning(
|
||||||
|
"Remote agent with engine_id %s and rank"
|
||||||
|
"%s already exchanged metadata, skip handshake.",
|
||||||
|
engine_id,
|
||||||
|
remote_tp_rank,
|
||||||
|
)
|
||||||
return self._remote_agents[engine_id][remote_tp_rank]
|
return self._remote_agents[engine_id][remote_tp_rank]
|
||||||
|
|
||||||
|
### Register remote agent metadata
|
||||||
if engine_id not in self._tp_size:
|
if engine_id not in self._tp_size:
|
||||||
self._tp_size[engine_id] = remote_tp_size
|
self._tp_size[engine_id] = remote_tp_size
|
||||||
else:
|
|
||||||
assert self._tp_size[engine_id] == remote_tp_size
|
|
||||||
# TODO We may eventually want to skip enforcing the same attn backend.
|
|
||||||
assert nixl_agent_meta.attn_backend_name == self.backend_name
|
|
||||||
|
|
||||||
remote_agent_name = self.nixl_wrapper.add_remote_agent(
|
remote_agent_name = self.nixl_wrapper.add_remote_agent(
|
||||||
nixl_agent_meta.agent_metadata
|
nixl_agent_meta.agent_metadata
|
||||||
)
|
)
|
||||||
|
|
||||||
# Number of D TP workers reading from a single P TP worker. This is
|
# Create dst descs and xfer side handles. TP workers have same #blocks
|
||||||
# 1 when P and D `--tensor-parallel-size` match.
|
# so we only register once per engine_id.
|
||||||
tp_ratio = divide(self._tp_size[self.engine_id], self._tp_size[engine_id])
|
if engine_id not in self.dst_num_blocks:
|
||||||
assert tp_ratio > 0, "Decode TP cannot be smaller than prefill TP"
|
|
||||||
assert not self._use_pallas or tp_ratio == 1, (
|
|
||||||
"TPU (pallas_v1) DOES NOT support heterogeneous TP yet."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Handle tp_size>num_kv_heads: replicate KV cache.
|
|
||||||
total_num_kv_heads = self.model_config.get_total_num_kv_heads()
|
|
||||||
is_kv_replicated = self._tp_size[engine_id] // total_num_kv_heads >= 1
|
|
||||||
|
|
||||||
remote_block_len = nixl_agent_meta.block_lens[0]
|
|
||||||
if nixl_agent_meta.kv_cache_layout != self.kv_cache_layout:
|
|
||||||
if (
|
|
||||||
self.vllm_config.kv_transfer_config is not None
|
|
||||||
and self.vllm_config.kv_transfer_config.enable_permute_local_kv
|
|
||||||
and nixl_agent_meta.kv_cache_layout == "HND"
|
|
||||||
):
|
|
||||||
logger.info(
|
|
||||||
"Remote is HND and local is NHD, enabled additional permute "
|
|
||||||
"on local device KV."
|
|
||||||
)
|
|
||||||
self.enable_permute_local_kv = True
|
|
||||||
else:
|
|
||||||
raise RuntimeError(
|
|
||||||
"Heterogeneous TP expects same kv_cache_layout. "
|
|
||||||
"Or enable experimental feature to use HND to NHD support by "
|
|
||||||
"setting 'enable_permute_local_kv'=True in --kv-transfer-config."
|
|
||||||
)
|
|
||||||
if self.use_mla or is_kv_replicated:
|
|
||||||
# With replicated KV cache, only the number of blocks can differ.
|
|
||||||
assert self.block_len_per_layer == nixl_agent_meta.block_lens, (
|
|
||||||
"KV cache sizes must match between P and D when replicated"
|
|
||||||
)
|
|
||||||
remote_block_size = remote_block_len // (self.slot_size_per_layer[0])
|
|
||||||
else:
|
|
||||||
# When MLA is not used, this is a list of the same block length
|
|
||||||
for block_len in nixl_agent_meta.block_lens:
|
|
||||||
assert block_len == remote_block_len, (
|
|
||||||
"All remote layers must have the same block size"
|
|
||||||
)
|
|
||||||
remote_block_size = remote_block_len // (
|
|
||||||
self.slot_size_per_layer[0] * tp_ratio
|
|
||||||
)
|
|
||||||
if self._use_flashinfer:
|
|
||||||
# With flashinfer, KV are sent in the same message.
|
|
||||||
remote_block_size //= 2
|
|
||||||
if tp_ratio > 1:
|
|
||||||
# Heterogeneous TP expects same kv_cache_layout.
|
|
||||||
if nixl_agent_meta.kv_cache_layout == "NHD":
|
|
||||||
raise ValueError(
|
|
||||||
"Heterogeneous TP is not supported for remote with NHD."
|
|
||||||
)
|
|
||||||
if self.device_type == "xpu":
|
|
||||||
raise ValueError("Heterogeneous TP is not supported on XPU")
|
|
||||||
|
|
||||||
assert remote_block_len == self.block_len_per_layer[0] * tp_ratio, (
|
|
||||||
"Remote P worker KV layer cache must be of shape [2, N, "
|
|
||||||
"local_kv_heads*tp_ratio, block_size, head_dim] and same dtype."
|
|
||||||
)
|
|
||||||
|
|
||||||
assert self.block_size == remote_block_size, (
|
|
||||||
"Remote P worker with different page/block size is not supported "
|
|
||||||
f"{self.block_size=}, {remote_block_size=}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create dst descs and xfer side handles. TP workers have same #blocks.
|
|
||||||
if engine_id in self.dst_num_blocks:
|
|
||||||
assert self.dst_num_blocks[engine_id] == nixl_agent_meta.num_blocks
|
|
||||||
else:
|
|
||||||
self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks
|
self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks
|
||||||
|
|
||||||
|
# Keep track of remote agent kv caches base addresses.
|
||||||
|
self.kv_caches_base_addr[engine_id][remote_tp_rank] = (
|
||||||
|
nixl_agent_meta.kv_caches_base_addr
|
||||||
|
)
|
||||||
|
self._validate_remote_agent_handshake(nixl_agent_meta, remote_tp_size)
|
||||||
|
|
||||||
|
# This is 1 when P and D `--tensor-parallel-size` match. Otherwise,
|
||||||
|
# this is the ratio between the two sizes.
|
||||||
|
tp_ratio = self.kv_info.tp_ratio(engine_id)
|
||||||
|
|
||||||
|
# Handle tp_size>num_kv_heads: replicate KV cache.
|
||||||
|
indexes_into_remote = (
|
||||||
|
not self.kv_info.replicates_kv_cache(engine_id) and tp_ratio > 0
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
"Registering remote agent (%s, rank %s) memory regions with tp_ratio %s",
|
||||||
|
engine_id,
|
||||||
|
remote_tp_rank,
|
||||||
|
tp_ratio,
|
||||||
|
)
|
||||||
|
|
||||||
|
### (Optional) Register local agent memory regions.
|
||||||
|
# MLA-optimization: only prepare one region.
|
||||||
|
if (
|
||||||
|
tp_ratio < 0
|
||||||
|
and not self.use_mla
|
||||||
|
and tp_ratio not in self.src_xfer_side_chunked_handles
|
||||||
|
):
|
||||||
|
# Remote tp_size > local tp_size: read from multiple remote ranks.
|
||||||
|
# Logically "split" own regions into |tp_ratio| chunks. Mind that
|
||||||
|
# we only do this once per remote tp_size (replica-friendly).
|
||||||
|
self.src_xfer_side_chunked_handles[tp_ratio] = []
|
||||||
|
for i in range(-tp_ratio):
|
||||||
|
blocks_data = []
|
||||||
|
for memory_region in self.src_blocks_data:
|
||||||
|
addr, local_block_len, own_tp_rank = memory_region
|
||||||
|
# Computing block len layer by layer allows for different
|
||||||
|
# block sizes to be used.
|
||||||
|
remote_block_len = local_block_len // (-tp_ratio)
|
||||||
|
addr = addr + i * remote_block_len
|
||||||
|
blocks_data.append((addr, remote_block_len, own_tp_rank))
|
||||||
|
descs = self.nixl_wrapper.get_xfer_descs(
|
||||||
|
blocks_data, self.nixl_memory_type
|
||||||
|
)
|
||||||
|
handle = self.nixl_wrapper.prep_xfer_dlist("NIXL_INIT_AGENT", descs)
|
||||||
|
self.src_xfer_side_chunked_handles[tp_ratio].append(handle)
|
||||||
|
|
||||||
|
### Register remote agent memory regions
|
||||||
blocks_data = []
|
blocks_data = []
|
||||||
# With homogeneous TP, D pulls the whole kv cache from corresponding
|
# With homogeneous TP, D pulls the whole kv cache from corresponding
|
||||||
# rank. With heterogeneous TP, prepare the descriptors by splitting the
|
# rank. With heterogeneous TP, prepare the descriptors by splitting the
|
||||||
# P KV cache along kv_head dim, of D worker's kv_head size (D>P).
|
# P KV cache along kv_head dim, of D worker's kv_head size (D>P).
|
||||||
# Eg. PTP1 DTP2 => P0 KV:[block0-KV_0 | block0-KV_1..].
|
# Eg. PTP1 DTP2 => P0 KV:[block0-KV_0 | block0-KV_1..].
|
||||||
self.kv_caches_base_addr[engine_id] = nixl_agent_meta.kv_caches_base_addr
|
|
||||||
|
|
||||||
assert len(nixl_agent_meta.kv_caches_base_addr) == len(self.block_len_per_layer)
|
|
||||||
# Register all remote blocks, but only the corresponding kv heads.
|
# Register all remote blocks, but only the corresponding kv heads.
|
||||||
for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr):
|
for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr):
|
||||||
kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=i)
|
# Read our whole local region size from remote.
|
||||||
|
local_block_len = self.get_backend_aware_kv_block_len(layer_idx=i)
|
||||||
|
if tp_ratio < 0 and not self.use_mla:
|
||||||
|
# Remote tp is bigger: read a chunk of local region from remote
|
||||||
|
local_block_len = local_block_len // (-tp_ratio)
|
||||||
rank_offset = (
|
rank_offset = (
|
||||||
self.tp_rank % tp_ratio * kv_block_len
|
self.tp_rank % tp_ratio * local_block_len if indexes_into_remote else 0
|
||||||
if not (self.use_mla or is_kv_replicated)
|
|
||||||
else 0
|
|
||||||
)
|
)
|
||||||
for block_id in range(nixl_agent_meta.num_blocks):
|
for block_id in range(nixl_agent_meta.num_blocks):
|
||||||
block_offset = block_id * nixl_agent_meta.block_lens[i]
|
block_offset = block_id * nixl_agent_meta.block_lens[i]
|
||||||
@ -1179,7 +1261,7 @@ class NixlConnectorWorker:
|
|||||||
# self.block_len == remote_block_len//tp_ratio bytes.
|
# self.block_len == remote_block_len//tp_ratio bytes.
|
||||||
addr = base_addr + block_offset + rank_offset
|
addr = base_addr + block_offset + rank_offset
|
||||||
# (addr, len, device id)
|
# (addr, len, device id)
|
||||||
blocks_data.append((addr, kv_block_len, remote_tp_rank))
|
blocks_data.append((addr, local_block_len, remote_tp_rank))
|
||||||
|
|
||||||
if self._use_flashinfer:
|
if self._use_flashinfer:
|
||||||
# With FlashInfer index V separately to allow head splitting.
|
# With FlashInfer index V separately to allow head splitting.
|
||||||
@ -1187,7 +1269,7 @@ class NixlConnectorWorker:
|
|||||||
block_offset = block_id * nixl_agent_meta.block_lens[i]
|
block_offset = block_id * nixl_agent_meta.block_lens[i]
|
||||||
addr = base_addr + block_offset + rank_offset
|
addr = base_addr + block_offset + rank_offset
|
||||||
v_addr = addr + nixl_agent_meta.block_lens[i] // 2
|
v_addr = addr + nixl_agent_meta.block_lens[i] // 2
|
||||||
blocks_data.append((v_addr, kv_block_len, remote_tp_rank))
|
blocks_data.append((v_addr, local_block_len, remote_tp_rank))
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"Created %s blocks for dst engine %s with remote rank %s and local rank %s",
|
"Created %s blocks for dst engine %s with remote rank %s and local rank %s",
|
||||||
@ -1199,12 +1281,87 @@ class NixlConnectorWorker:
|
|||||||
|
|
||||||
# Register with NIXL.
|
# Register with NIXL.
|
||||||
descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type)
|
descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type)
|
||||||
self.dst_xfer_side_handles[engine_id] = self.nixl_wrapper.prep_xfer_dlist(
|
self.dst_xfer_side_handles[engine_id][remote_tp_rank] = (
|
||||||
remote_agent_name, descs
|
self.nixl_wrapper.prep_xfer_dlist(remote_agent_name, descs)
|
||||||
)
|
)
|
||||||
|
|
||||||
return remote_agent_name
|
return remote_agent_name
|
||||||
|
|
||||||
|
def _validate_remote_agent_handshake(
|
||||||
|
self, nixl_agent_meta: NixlAgentMetadata, remote_tp_size: int
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Validate the remote agent handshake metadata ensuring the
|
||||||
|
invariants hold true.
|
||||||
|
"""
|
||||||
|
remote_engine_id = nixl_agent_meta.engine_id
|
||||||
|
|
||||||
|
assert self._tp_size[remote_engine_id] == remote_tp_size
|
||||||
|
# TODO We may eventually want to skip enforcing the same attn backend.
|
||||||
|
assert nixl_agent_meta.attn_backend_name == self.backend_name
|
||||||
|
assert nixl_agent_meta.kv_cache_layout == self.kv_cache_layout
|
||||||
|
|
||||||
|
tp_ratio = self.kv_info.tp_ratio(remote_engine_id)
|
||||||
|
assert not self._use_pallas or tp_ratio == 1, (
|
||||||
|
"TPU (pallas_v1) DOES NOT support heterogeneous TP yet."
|
||||||
|
)
|
||||||
|
# Num kv_heads > tp_size and P TP > D TP case, not supported
|
||||||
|
assert not (tp_ratio < 0 and self.kv_info.is_kv_replicated(remote_engine_id))
|
||||||
|
|
||||||
|
# Block len can only vary across layers when using MLA.
|
||||||
|
remote_block_len = nixl_agent_meta.block_lens[0]
|
||||||
|
if self.kv_info.replicates_kv_cache(remote_engine_id):
|
||||||
|
# With replicated KV cache, only the number of blocks can differ.
|
||||||
|
assert self.block_len_per_layer == nixl_agent_meta.block_lens, (
|
||||||
|
"KV cache sizes must match between P and D when replicated"
|
||||||
|
)
|
||||||
|
remote_block_size = remote_block_len // (self.slot_size_per_layer[0])
|
||||||
|
else:
|
||||||
|
if tp_ratio != 1 and self.device_type == "xpu":
|
||||||
|
# XPU uses NHD, hence it does not support splitting on H
|
||||||
|
raise ValueError("Heterogeneous TP is not supported on XPU")
|
||||||
|
# When MLA is not used, this is a list of the same block length
|
||||||
|
for block_len in nixl_agent_meta.block_lens:
|
||||||
|
assert block_len == remote_block_len, (
|
||||||
|
"All remote layers must have the same block size"
|
||||||
|
)
|
||||||
|
|
||||||
|
if tp_ratio > 0:
|
||||||
|
# Remote NHD/H'D*tp_ratio=N -page_size-
|
||||||
|
remote_block_size = remote_block_len // (
|
||||||
|
self.slot_size_per_layer[0] * tp_ratio
|
||||||
|
)
|
||||||
|
# Remote tp is smaller: remote block_len size is bigger
|
||||||
|
assert remote_block_len == self.block_len_per_layer[0] * tp_ratio, (
|
||||||
|
"Remote P worker KV layer cache must be of shape [2, N, "
|
||||||
|
"local_kv_heads*tp_ratio, page_size, head_dim] and same dtype."
|
||||||
|
) # noqa: E501
|
||||||
|
else:
|
||||||
|
# Remote NHD/(H'D/tp_ratio)=N -page_size-
|
||||||
|
remote_block_size = remote_block_len // (
|
||||||
|
self.slot_size_per_layer[0] // (-tp_ratio)
|
||||||
|
)
|
||||||
|
# Remote tp is bigger: remote block_len size is smaller
|
||||||
|
assert remote_block_len == self.block_len_per_layer[0] // (-tp_ratio), (
|
||||||
|
"Remote P worker KV layer cache must be of shape [2, N, "
|
||||||
|
"local_kv_heads/tp_ratio, page_size, head_dim] and same dtype."
|
||||||
|
) # noqa: E501
|
||||||
|
|
||||||
|
if self._use_flashinfer:
|
||||||
|
# With flashinfer, KV are sent in the same message.
|
||||||
|
remote_block_size //= 2
|
||||||
|
|
||||||
|
# We may allow it in the future with logical kvcache manager block_size
|
||||||
|
assert self.block_size == remote_block_size, (
|
||||||
|
"Remote P worker with different page/block size is not supported "
|
||||||
|
f"{self.block_size=}, {remote_block_size=}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# TP workers (handhshakes with same remote) have same #blocks.
|
||||||
|
assert self.dst_num_blocks[remote_engine_id] == nixl_agent_meta.num_blocks
|
||||||
|
# Same number of regions/~layers.
|
||||||
|
assert len(nixl_agent_meta.kv_caches_base_addr) == len(self.block_len_per_layer)
|
||||||
|
|
||||||
def sync_recved_kv_to_device(self, req_id: str, meta: ReqMeta):
|
def sync_recved_kv_to_device(self, req_id: str, meta: ReqMeta):
|
||||||
"""copy recved kv from host buffer to device."""
|
"""copy recved kv from host buffer to device."""
|
||||||
assert self.use_host_buffer
|
assert self.use_host_buffer
|
||||||
@ -1384,7 +1541,7 @@ class NixlConnectorWorker:
|
|||||||
"""
|
"""
|
||||||
done_req_ids: set[str] = set()
|
done_req_ids: set[str] = set()
|
||||||
for req_id, handles in list(transfers.items()):
|
for req_id, handles in list(transfers.items()):
|
||||||
in_progress = False
|
in_progress = []
|
||||||
for handle, _xfer_stime in handles:
|
for handle, _xfer_stime in handles:
|
||||||
xfer_state = self.nixl_wrapper.check_xfer_state(handle)
|
xfer_state = self.nixl_wrapper.check_xfer_state(handle)
|
||||||
if xfer_state == "DONE":
|
if xfer_state == "DONE":
|
||||||
@ -1393,7 +1550,7 @@ class NixlConnectorWorker:
|
|||||||
self.xfer_stats.record_transfer(res)
|
self.xfer_stats.record_transfer(res)
|
||||||
self.nixl_wrapper.release_xfer_handle(handle)
|
self.nixl_wrapper.release_xfer_handle(handle)
|
||||||
elif xfer_state == "PROC":
|
elif xfer_state == "PROC":
|
||||||
in_progress = True
|
in_progress.append((handle, _xfer_stime))
|
||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
# transfer failed - mark blocks as invalid
|
# transfer failed - mark blocks as invalid
|
||||||
@ -1410,8 +1567,11 @@ class NixlConnectorWorker:
|
|||||||
self.nixl_wrapper.release_xfer_handle(handle)
|
self.nixl_wrapper.release_xfer_handle(handle)
|
||||||
self.xfer_stats.record_failed_transfer()
|
self.xfer_stats.record_failed_transfer()
|
||||||
if not in_progress:
|
if not in_progress:
|
||||||
|
# Only report request as completed when all transfers are done.
|
||||||
done_req_ids.add(req_id)
|
done_req_ids.add(req_id)
|
||||||
del transfers[req_id]
|
del transfers[req_id]
|
||||||
|
else:
|
||||||
|
transfers[req_id] = in_progress
|
||||||
return done_req_ids
|
return done_req_ids
|
||||||
|
|
||||||
def start_load_kv(self, metadata: NixlConnectorMetadata):
|
def start_load_kv(self, metadata: NixlConnectorMetadata):
|
||||||
@ -1466,17 +1626,37 @@ class NixlConnectorWorker:
|
|||||||
self._reqs_to_send[req_id] = expiration_time
|
self._reqs_to_send[req_id] = expiration_time
|
||||||
|
|
||||||
def _read_blocks_for_req(self, req_id: str, meta: ReqMeta):
|
def _read_blocks_for_req(self, req_id: str, meta: ReqMeta):
|
||||||
logger.debug(
|
remote_ranks = self.kv_info.get_target_remote_ranks(meta.remote_engine_id)
|
||||||
"Remote agent %s available, calling _read_blocks for req %s",
|
tp_ratio = self.kv_info.tp_ratio(meta.remote_engine_id)
|
||||||
meta.remote_engine_id,
|
# D may have to perform multiple reads from different remote ranks.
|
||||||
req_id,
|
for i, remote_rank in enumerate(remote_ranks):
|
||||||
)
|
logger.debug(
|
||||||
self._read_blocks(
|
"Remote agent %s available, calling _read_blocks"
|
||||||
request_id=req_id,
|
" on remote rank %s for req %s",
|
||||||
dst_engine_id=meta.remote_engine_id,
|
meta.remote_engine_id,
|
||||||
local_block_ids=meta.local_block_ids,
|
remote_rank,
|
||||||
remote_block_ids=meta.remote_block_ids,
|
req_id,
|
||||||
)
|
)
|
||||||
|
if tp_ratio < 0 and not self.use_mla:
|
||||||
|
# Remote tp_size > local tp_size: we must perform multiple
|
||||||
|
# reads. Get the memory chunk onto which we will write to.
|
||||||
|
local_xfer_side_handle = self.src_xfer_side_chunked_handles[tp_ratio][i]
|
||||||
|
else:
|
||||||
|
# Single read from remote, we write to the whole memory region.
|
||||||
|
local_xfer_side_handle = self.src_xfer_side_handle
|
||||||
|
# Destination handle: remote_engine_id -> remote_rank -> handle.
|
||||||
|
remote_xfer_side_handle = self.dst_xfer_side_handles[meta.remote_engine_id][
|
||||||
|
remote_rank
|
||||||
|
]
|
||||||
|
self._read_blocks(
|
||||||
|
request_id=req_id,
|
||||||
|
dst_engine_id=meta.remote_engine_id,
|
||||||
|
local_block_ids=meta.local_block_ids,
|
||||||
|
remote_block_ids=meta.remote_block_ids,
|
||||||
|
remote_rank=remote_rank,
|
||||||
|
local_xfer_side_handle=local_xfer_side_handle,
|
||||||
|
remote_xfer_side_handle=remote_xfer_side_handle,
|
||||||
|
)
|
||||||
|
|
||||||
def _read_blocks(
|
def _read_blocks(
|
||||||
self,
|
self,
|
||||||
@ -1484,7 +1664,14 @@ class NixlConnectorWorker:
|
|||||||
remote_block_ids: list[int],
|
remote_block_ids: list[int],
|
||||||
dst_engine_id: str,
|
dst_engine_id: str,
|
||||||
request_id: str,
|
request_id: str,
|
||||||
|
remote_rank: int,
|
||||||
|
local_xfer_side_handle: int,
|
||||||
|
remote_xfer_side_handle: int,
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
Post a READ xfer request from a single local worker to a single
|
||||||
|
remote worker.
|
||||||
|
"""
|
||||||
# NOTE(rob): having the staging blocks be on the READER side is
|
# NOTE(rob): having the staging blocks be on the READER side is
|
||||||
# not going to work well (since we will have to call rearrange tensors).
|
# not going to work well (since we will have to call rearrange tensors).
|
||||||
# after we detect the txn is complete (which means we cannot make the
|
# after we detect the txn is complete (which means we cannot make the
|
||||||
@ -1497,14 +1684,14 @@ class NixlConnectorWorker:
|
|||||||
|
|
||||||
# Number of D TP workers that will read from dst P. Propagate tp_ratio
|
# Number of D TP workers that will read from dst P. Propagate tp_ratio
|
||||||
# on notification so that dst worker can wait before freeing blocks.
|
# on notification so that dst worker can wait before freeing blocks.
|
||||||
tp_ratio = self._tp_size[self.engine_id] // self._tp_size[dst_engine_id]
|
# Cap to 1 when P TP > D TP: only a single rank will read from remote.
|
||||||
|
tp_ratio = max(1, self.kv_info.tp_ratio(dst_engine_id))
|
||||||
notif_id = f"{request_id}:{tp_ratio}".encode()
|
notif_id = f"{request_id}:{tp_ratio}".encode()
|
||||||
|
|
||||||
# Full prefix cache hit: do not need to read remote blocks,
|
# Full prefix cache hit: do not need to read remote blocks,
|
||||||
# just notify P worker that we have the blocks we need.
|
# just notify P worker that we have the blocks we need.
|
||||||
num_local_blocks = len(local_block_ids)
|
num_local_blocks = len(local_block_ids)
|
||||||
if num_local_blocks == 0:
|
if num_local_blocks == 0:
|
||||||
remote_rank = self.tp_rank // tp_ratio
|
|
||||||
agent_name = self._remote_agents[dst_engine_id][remote_rank]
|
agent_name = self._remote_agents[dst_engine_id][remote_rank]
|
||||||
try:
|
try:
|
||||||
self.nixl_wrapper.send_notif(agent_name, notif_msg=notif_id)
|
self.nixl_wrapper.send_notif(agent_name, notif_msg=notif_id)
|
||||||
@ -1524,10 +1711,6 @@ class NixlConnectorWorker:
|
|||||||
if num_local_blocks < num_remote_blocks:
|
if num_local_blocks < num_remote_blocks:
|
||||||
remote_block_ids = remote_block_ids[-num_local_blocks:]
|
remote_block_ids = remote_block_ids[-num_local_blocks:]
|
||||||
|
|
||||||
# Get side handles.
|
|
||||||
local_xfer_side_handle = self.src_xfer_side_handle
|
|
||||||
remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id]
|
|
||||||
|
|
||||||
# NOTE (nicolo) With homogeneous TP, each TP worker loads KV from
|
# NOTE (nicolo) With homogeneous TP, each TP worker loads KV from
|
||||||
# corresponding rank. With heterogeneous TP, fixing D>P, the D tp
|
# corresponding rank. With heterogeneous TP, fixing D>P, the D tp
|
||||||
# workers will issue xfers to parts of the P worker remote kv caches.
|
# workers will issue xfers to parts of the P worker remote kv caches.
|
||||||
@ -1680,15 +1863,20 @@ class NixlConnectorWorker:
|
|||||||
if self._nixl_handshake_listener_t is not None:
|
if self._nixl_handshake_listener_t is not None:
|
||||||
self._nixl_handshake_listener_t.join(timeout=0)
|
self._nixl_handshake_listener_t.join(timeout=0)
|
||||||
self._nixl_handshake_listener_t = None
|
self._nixl_handshake_listener_t = None
|
||||||
for handles in self._recving_transfers.values():
|
for rcv_handles in self._recving_transfers.values():
|
||||||
for handle, _ in handles:
|
for handle, _ in rcv_handles:
|
||||||
self.nixl_wrapper.release_xfer_handle(handle)
|
self.nixl_wrapper.release_xfer_handle(handle)
|
||||||
self._recving_transfers.clear()
|
self._recving_transfers.clear()
|
||||||
if self.src_xfer_side_handle:
|
if self.src_xfer_side_handle:
|
||||||
self.nixl_wrapper.release_dlist_handle(self.src_xfer_side_handle)
|
self.nixl_wrapper.release_dlist_handle(self.src_xfer_side_handle)
|
||||||
self.src_xfer_side_handle = 0
|
self.src_xfer_side_handle = 0
|
||||||
for dst_xfer_side_handle in self.dst_xfer_side_handles.values():
|
for handles in self.src_xfer_side_chunked_handles.values():
|
||||||
self.nixl_wrapper.release_dlist_handle(dst_xfer_side_handle)
|
for handle in handles:
|
||||||
|
self.nixl_wrapper.release_dlist_handle(handle)
|
||||||
|
self.src_xfer_side_chunked_handles.clear()
|
||||||
|
for dst_xfer_side_handles in self.dst_xfer_side_handles.values():
|
||||||
|
for dst_xfer_side_handle in dst_xfer_side_handles.values():
|
||||||
|
self.nixl_wrapper.release_dlist_handle(dst_xfer_side_handle)
|
||||||
self.dst_xfer_side_handles.clear()
|
self.dst_xfer_side_handles.clear()
|
||||||
for remote_agents in self._remote_agents.values():
|
for remote_agents in self._remote_agents.values():
|
||||||
for agent_name in remote_agents.values():
|
for agent_name in remote_agents.values():
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user