[NIXL] Support P tensor-parallel-size > D tensor-parallel-size (#27274)

Signed-off-by: NickLucche <nlucches@redhat.com>
Nicolò Lucchesi 2025-12-18 04:53:30 +01:00 committed by GitHub
parent fd8afdf38d
commit bc3700e0cd
4 changed files with 556 additions and 212 deletions

View File

@ -8,9 +8,12 @@ SCRIPT="v1/kv_connector/nixl_integration/run_accuracy_test.sh"
configs=(
"GPU_MEMORY_UTILIZATION=0.6 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=2"
"GPU_MEMORY_UTILIZATION=0.6 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2"
"GPU_MEMORY_UTILIZATION=0.6 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=1"
"GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" # MLA case
"GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny"
"GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=1 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny"
"DP_EP=1 GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" # MLA+P-TP1, D-DPEP=2 (TP=1)
"DP_EP=1 GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" # MLA+P-TP2, D-DPEP=2 (TP=1)
)
run_tests() {

View File

@ -391,6 +391,8 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
super().__init__(*args, **kwargs)
self._hand_shake_latency = hand_shake_latency
self.kv_cache_layout = kv_cache_layout
# Mock the attribute normally set by register_kv_caches for tests that do not call it.
self.src_xfer_handles_by_block_size = {self.block_size: 1}
def _nixl_handshake(
self, host: str, port: int, remote_tp_size: int, expected_engine_id: str
@ -407,22 +409,43 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
assert expected_engine_id == self.REMOTE_ENGINE_ID
remote_agent_name = self.add_remote_agent(
NixlAgentMetadata(
engine_id=self.REMOTE_ENGINE_ID,
agent_metadata=FakeNixlWrapper.AGENT_METADATA,
kv_caches_base_addr=[0],
device_id=0,
num_blocks=1,
block_lens=self.block_len_per_layer,
# `self.kv_cache_layout` is only forced to HND when vllm engine
# is started. We mock HND here.
kv_cache_layout="HND",
block_size=self.block_size,
),
remote_tp_size=remote_tp_size,
)
return {0: remote_agent_name}
# Adjust remote block length metadata to satisfy heterogeneous TP
# invariants enforced during handshake validation.
remote_block_lens = list(self.block_len_per_layer)
tp_ratio = self.kv_topo.tp_ratio(remote_tp_size)
if remote_tp_size > self.world_size:
# P TP > D TP case, block_len of remote is smaller
remote_block_lens = [
block_len // (-tp_ratio) for block_len in remote_block_lens
]
elif remote_tp_size < self.world_size:
remote_block_lens = [
block_len * tp_ratio for block_len in remote_block_lens
]
# When remote tp_size > local tp_size, handshake with multiple
# remote ranks.
num_handshakes = 1 if tp_ratio > 0 else -tp_ratio
remote_agents: dict[int, str] = {}
for remote_tp_rank in range(num_handshakes):
remote_agent_name = self.add_remote_agent(
NixlAgentMetadata(
engine_id=self.REMOTE_ENGINE_ID,
agent_metadata=FakeNixlWrapper.AGENT_METADATA,
kv_caches_base_addr=[0],
device_id=remote_tp_rank,
num_blocks=1,
block_lens=remote_block_lens,
# `self.kv_cache_layout` is only forced to HND when vllm engine
# is started. We mock HND here.
kv_cache_layout="HND",
block_size=self.block_size,
),
remote_tp_rank=remote_tp_rank,
remote_tp_size=remote_tp_size,
)
remote_agents[remote_tp_rank] = remote_agent_name
return remote_agents
class TestNixlHandshake:
@ -453,7 +476,13 @@ class TestNixlHandshake:
vllm_config, connector.engine_id, hand_shake_latency=0
)
assert isinstance(connector.connector_worker.nixl_wrapper, FakeNixlWrapper)
connector.connector_worker.nixl_wrapper.set_cycles_before_xfer_done(3)
worker = connector.connector_worker
worker.nixl_wrapper.set_cycles_before_xfer_done(3)
# simulate handshake
worker.dst_xfer_side_handles = {
FakeNixlConnectorWorker.REMOTE_ENGINE_ID: {0: 1}
}
worker.kv_cache_layout = "HND"
num_xfers = 4
while True:
# For the same request_id, initiate multiple xfers across different
@ -567,6 +596,171 @@ class TestNixlHandshake:
return
raise TimeoutError("Took too long to complete async handshake.")
@patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
FakeNixlWrapper,
)
@pytest.mark.parametrize("local_tp_size", [1, 2])
def test_prefill_tp_size_greater_than_decode_tp_size(
self, local_tp_size: int, dist_init
):
"""
Verify remote TP > local TP handshake succeeds with different
remote configurations.
"""
vllm_config = create_vllm_config()
local_tp_size = 1
vllm_config.parallel_config.tensor_parallel_size = local_tp_size
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
connector.connector_worker = FakeNixlConnectorWorker(
vllm_config, connector.engine_id, hand_shake_latency=0
)
worker = connector.connector_worker
# Minimal local registration params used by add_remote_agent
worker.slot_size_per_layer = [4096]
worker.block_len_per_layer = [4096 * worker.block_size]
worker.num_blocks = 1
worker.dst_num_blocks[worker.engine_id] = worker.num_blocks
worker.src_blocks_data = [(0, worker.block_len_per_layer[0], worker.tp_rank)]
def check_handshake(remote_tp_size: int):
tp_ratio = remote_tp_size // local_tp_size
assert set(remote_agents.keys()) == set(range(tp_ratio))
remote_engine_id = worker.REMOTE_ENGINE_ID
assert worker._tp_size[remote_engine_id] == remote_tp_size
assert -tp_ratio == worker.kv_topo.tp_ratio_from_engine_id(remote_engine_id)
# Ensure src_xfer_handles_by_tp_ratio is populated with tp_ratio chunks.
assert -tp_ratio in worker.src_xfer_handles_by_tp_ratio
assert len(worker.src_xfer_handles_by_tp_ratio[-tp_ratio]) == tp_ratio
assert remote_engine_id in worker.dst_xfer_side_handles
assert set(worker.dst_xfer_side_handles[remote_engine_id].keys()) == set(
range(tp_ratio)
)
remote_agents = worker._nixl_handshake(
host="localhost",
port=1234,
remote_tp_size=2,
expected_engine_id=worker.REMOTE_ENGINE_ID,
)
check_handshake(2)
# NOTE on flexibility: a second remote with a higher number of ranks is
# discovered. This is not a scenario we actively support right now, but
# the connector allows it.
worker.REMOTE_ENGINE_ID = "remote_engine_2"
remote_agents = worker._nixl_handshake(
host="localhost",
port=1234,
remote_tp_size=6,
expected_engine_id=worker.REMOTE_ENGINE_ID,
)
check_handshake(6)
@patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
FakeNixlWrapper,
)
@pytest.mark.parametrize("local_tp_size", [1, 2])
def test_prefill_tp_size_greater_than_decode_tp_size_mla(
self, local_tp_size: int, dist_init
):
"""
Verify remote TP > local TP handshake succeeds with different
remote configurations for an MLA model.
"""
vllm_config = create_vllm_config()
d_tp_size = 1
p_tp_size = 2
# Build two separate connectors/workers to emulate P TP=2 ranks.
conn_p0 = NixlConnector(vllm_config, KVConnectorRole.WORKER)
conn_p1 = NixlConnector(vllm_config, KVConnectorRole.WORKER)
conn_p0.connector_worker = FakeNixlConnectorWorker(
vllm_config, conn_p0.engine_id, hand_shake_latency=0
)
conn_p1.connector_worker = FakeNixlConnectorWorker(
vllm_config, conn_p1.engine_id, hand_shake_latency=0
)
# Force P world size to 2 for both workers and emulate distinct tp_ranks.
# Also enable MLA path so that expected_finished_count is updated.
for rank, worker in enumerate(
(conn_p0.connector_worker, conn_p1.connector_worker)
):
worker.world_size = p_tp_size
worker.kv_topo.remote_tp_size = {worker.engine_id: p_tp_size}
worker.tp_rank = rank
worker.use_mla = True
req_id = "req-ep-dp2-p0"
now = time.perf_counter()
# Register a request on P that is waiting for consumers to read
# (both workers track it).
conn_p0.connector_worker._reqs_to_send[req_id] = now + 10.0
conn_p0.connector_worker._reqs_to_process.add(req_id)
conn_p1.connector_worker._reqs_to_send[req_id] = now + 10.0
conn_p1.connector_worker._reqs_to_process.add(req_id)
# Simulate a read notification coming from D with (tp=1, dp=2).
notif = f"{req_id}:{d_tp_size}".encode()
# D0-0->P0 notif
conn_p0.connector_worker.nixl_wrapper.get_new_notifs = lambda: {
"agent": [notif]
} # type: ignore[method-assign]
conn_p1.connector_worker.nixl_wrapper.get_new_notifs = lambda: {
"agent": [notif]
} # type: ignore[method-assign]
# Trigger notification processing via get_finished().
done_sending0, _ = conn_p0.get_finished(finished_req_ids=set())
done_sending1, _ = conn_p1.get_finished(finished_req_ids=set())
assert req_id in done_sending0 and req_id in done_sending1
# E2E aggregation: ensure the aggregated output marks the request
# as finished using the connector's expected_finished_count.
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
aggregator = KVOutputAggregator.from_connector(conn_p0, world_size=2)
out0 = ModelRunnerOutput(
req_ids=[req_id],
req_id_to_index={req_id: 0},
sampled_token_ids=[[0]],
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[None],
kv_connector_output=KVConnectorOutput(
finished_sending=done_sending0,
finished_recving=None,
),
)
out1 = ModelRunnerOutput(
req_ids=[req_id],
req_id_to_index={req_id: 0},
sampled_token_ids=[[0]],
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[None],
kv_connector_output=KVConnectorOutput(
finished_sending=done_sending1,
finished_recving=None,
),
)
aggregated = aggregator.aggregate([out0, out1], output_rank=0)
assert aggregated.kv_connector_output is not None
assert aggregated.kv_connector_output.finished_sending == {req_id}
# Producers cleaned up state for the finished request.
assert req_id not in conn_p0.connector_worker._reqs_to_send
assert req_id not in conn_p0.connector_worker._reqs_to_process
assert req_id not in conn_p1.connector_worker._reqs_to_send
assert req_id not in conn_p1.connector_worker._reqs_to_process
@patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
FakeNixlWrapper,
@ -585,6 +779,9 @@ class TestNixlHandshake:
connector.connector_worker = FakeNixlConnectorWorker(
vllm_config, connector.engine_id
)
metadata = NixlConnectorMetadata()
total_reqs = 5
for i in range(total_reqs):
@ -672,7 +869,6 @@ class TestNixlHandshake:
with pytest.raises(RuntimeError):
# mismatched layout is expected to fail
worker.add_remote_agent(meta, remote_tp_size=2)
with pytest.raises(AssertionError):
worker.add_remote_agent(meta, remote_tp_size=1)
@patch(
@ -1357,8 +1553,11 @@ def test_shutdown_cleans_up_resources(dist_init):
patch.object(nixl_wrapper, "deregister_memory") as mock_dereg,
):
worker._recving_transfers = {"req1": [123]}
worker.src_xfer_side_handle = 456
worker.dst_xfer_side_handles = {"engine1": 789}
# Mock register_kv_caches, which registers the local handle.
worker.src_xfer_handles_by_block_size = {worker.block_size: 455}
# P TP = 2 * D TP case: two local handles should be registered.
worker.src_xfer_handles_by_tp_ratio = {-2: [456, 457]}
worker.dst_xfer_side_handles = {"engine1": {0: 789}}
worker._remote_agents = {"engine1": {0: "agent1"}}
worker._registered_descs = ["desc1", "desc2"]
@ -1379,8 +1578,10 @@ def test_shutdown_cleans_up_resources(dist_init):
mock_listener.join.assert_called_once()
mock_rel_xfer.assert_called_once_with(123)
assert mock_rel_dlist.call_count == 2
mock_rel_dlist.assert_any_call(456) # src handle
assert mock_rel_dlist.call_count == 4
mock_rel_dlist.assert_any_call(455) # src handle (whole region)
mock_rel_dlist.assert_any_call(456) # src handle (1st chunk)
mock_rel_dlist.assert_any_call(457) # src handle (2nd chunk)
mock_rel_dlist.assert_any_call(789) # dst handle
mock_rem_agent.assert_called_once_with("agent1")
assert mock_dereg.call_count == 2

View File

@ -21,6 +21,8 @@ if TYPE_CHECKING:
logger = init_logger(__name__)
EngineId = str
def get_kv_connector_cache_layout():
# NOTE (NickLucche) When running disaggregated PD with NIXL, HND layout is
@ -209,12 +211,12 @@ class TpKVTopology:
"""
tp_rank: int
remote_tp_size: dict[str, int]
remote_tp_size: dict[EngineId, int]
is_mla: bool
total_num_kv_heads: int
attn_backend: type[AttentionBackend]
engine_id: str
remote_block_size: dict[str, int]
engine_id: EngineId
remote_block_size: dict[EngineId, int]
def __post_init__(self):
# Figure out whether the first dimension of the cache is K/V
@ -256,18 +258,28 @@ class TpKVTopology:
Calculate the tensor parallel ratio between local and remote TP.
We can think of it as the number of local TP workers-per-remote TP
workers. Local workers will read from the same remote TP worker in
groups of size `tp_ratio`.
groups of size `tp_ratio`. If remote tp_size > local tp_size, the
ratio is flipped (remote_size/local_size) and the returned value is
negative.
"""
assert self.tp_size % remote_tp_size == 0, (
f"Local tensor parallel size {self.tp_size} is not divisible "
f"by remote tensor parallel size {remote_tp_size}."
if self.tp_size >= remote_tp_size:
assert self.tp_size % remote_tp_size == 0, (
f"Local tensor parallel size {self.tp_size} is not divisible "
f"by remote tensor parallel size {remote_tp_size}."
)
return self.tp_size // remote_tp_size
assert remote_tp_size % self.tp_size == 0, (
f"Remote tensor parallel size {remote_tp_size} is not divisible "
f"by local tensor parallel size {self.tp_size}."
)
return self.tp_size // remote_tp_size
# P TP > D TP case, return the ratio as negative
return -remote_tp_size // self.tp_size
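
To make the signed convention above concrete, here is a minimal standalone sketch (an illustrative helper, not part of TpKVTopology), together with a few worked values:

    def signed_tp_ratio(local_tp_size: int, remote_tp_size: int) -> int:
        # Same convention as tp_ratio above: positive when local (D) TP >= remote
        # (P) TP, negative (flipped ratio) when the remote deployment is larger.
        if local_tp_size >= remote_tp_size:
            assert local_tp_size % remote_tp_size == 0
            return local_tp_size // remote_tp_size
        assert remote_tp_size % local_tp_size == 0
        return -(remote_tp_size // local_tp_size)

    assert signed_tp_ratio(4, 2) == 2   # two D workers share one P worker
    assert signed_tp_ratio(2, 2) == 1   # homogeneous TP
    assert signed_tp_ratio(2, 6) == -3  # each D worker reads from 3 P workers
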
def block_size_ratio(
self,
remote_block_size: int,
) -> float:
) -> int:
"""
Calculate the block size ratio between local and remote TP.
"""
@ -279,19 +291,19 @@ class TpKVTopology:
def tp_ratio_from_engine_id(
self,
remote_engine_id: str,
remote_engine_id: EngineId,
) -> int:
remote_tp_size = self.remote_tp_size[remote_engine_id]
return self.tp_ratio(remote_tp_size)
def block_size_ratio_from_engine_id(
self,
remote_engine_id: str,
) -> float:
remote_engine_id: EngineId,
) -> int:
remote_block_size = self.remote_block_size[remote_engine_id]
return self.block_size_ratio(remote_block_size)
def is_kv_replicated(self, engine_id: str) -> bool:
def is_kv_replicated(self, engine_id: EngineId) -> bool:
"""
Whether the KV cache is replicated across TP workers due to the
number of TP workers being greater than the number of KV heads.
@ -299,24 +311,30 @@ class TpKVTopology:
tp_size = self.remote_tp_size[engine_id]
return tp_size // self.total_num_kv_heads >= 1
def replicates_kv_cache(self, remote_engine_id: str) -> bool:
def replicates_kv_cache(self, remote_engine_id: EngineId) -> bool:
# MLA is always replicated as the hidden dim can't be split.
return self.is_mla or self.is_kv_replicated(remote_engine_id)
def get_target_remote_rank(
def get_target_remote_ranks(
self,
remote_tp_size: int,
) -> int:
) -> list[int]:
"""
Get the remote TP rank (on P) that the current local TP rank
(on D) will read from.
(on D) will read from. When remote tp_size > local tp_size, we
read from multiple remote ranks.
"""
tp_ratio = self.tp_ratio(remote_tp_size)
return self.tp_rank // tp_ratio
if tp_ratio > 0:
return [self.tp_rank // tp_ratio]
def get_target_remote_rank_from_engine_id(
# P TP > D TP case, D reads from |tp_ratio| remote workers.
tp_ratio = -tp_ratio
return [self.tp_rank * tp_ratio + i for i in range(tp_ratio)]
def get_target_remote_ranks_from_engine_id(
self,
remote_engine_id: str,
) -> int:
remote_engine_id: EngineId,
) -> list[int]:
remote_tp_size = self.remote_tp_size[remote_engine_id]
return self.get_target_remote_rank(remote_tp_size)
return self.get_target_remote_ranks(remote_tp_size)
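
As a rough sanity check of the rank mapping above, a self-contained sketch (illustrative helper, not connector API) of which remote (P) ranks a local (D) rank reads from:

    def target_remote_ranks(tp_rank: int, local_tp_size: int, remote_tp_size: int) -> list[int]:
        # Mirrors get_target_remote_ranks: a single remote rank when D TP >= P TP,
        # a contiguous group of remote ranks when P TP > D TP.
        if local_tp_size >= remote_tp_size:
            tp_ratio = local_tp_size // remote_tp_size
            return [tp_rank // tp_ratio]
        tp_ratio = remote_tp_size // local_tp_size
        return [tp_rank * tp_ratio + i for i in range(tp_ratio)]

    # D TP=2 reading from P TP=6: rank 0 -> [0, 1, 2], rank 1 -> [3, 4, 5].
    assert target_remote_ranks(0, 2, 6) == [0, 1, 2]
    assert target_remote_ranks(1, 2, 6) == [3, 4, 5]
    # D TP=4 reading from P TP=2: ranks 0 and 1 both read from P rank 0.
    assert target_remote_ranks(1, 4, 2) == [0]
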

View File

@ -23,7 +23,7 @@ from vllm import envs
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.selector import get_attn_backend
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.utils import TpKVTopology
from vllm.distributed.kv_transfer.kv_connector.utils import EngineId, TpKVTopology
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
CopyBlocksOp,
KVConnectorBase_V1,
@ -56,7 +56,6 @@ if TYPE_CHECKING:
from vllm.v1.request import Request
TransferHandle = int
EngineId = str
ReqId = str
#
@ -873,9 +872,10 @@ class NixlConnectorWorker:
self.copy_blocks: CopyBlocksOp | None = None
# Map of engine_id -> kv_caches_base_addr. For TP case, each local
# rank will still only pull from a single remote TP worker.
self.kv_caches_base_addr: dict[EngineId, list[int]] = {}
self.device_id: int = 0
# Current rank may pull from multiple remote TP workers.
# dict[EngineId, dict[int, list[int]]]: engine_id -> tp_rank -> base addr per layer
self.kv_caches_base_addr = defaultdict[EngineId, dict[int, list[int]]](dict)
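# Illustrative shape only (made-up engine id / addresses):
#   {"remote-engine-id": {0: [region0_base_addr, region1_base_addr, ...],
#                         1: [...]}}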
# Number of NIXL regions. Currently one region per cache
# (so 1 per layer for MLA, otherwise 2 per layer)
@ -883,10 +883,12 @@ class NixlConnectorWorker:
self.num_layers = 0
# nixl_prepped_dlist_handle.
self.src_xfer_side_handle: int = 0
self.src_xfer_side_handles: dict[int, int] = {}
# Map of engine_id -> nixl_prepped_dlist_handle (int)].
self.dst_xfer_side_handles: dict[EngineId, int] = {}
self.src_xfer_handles_by_block_size: dict[int, int] = {}
# Populated dynamically during handshake based on remote configuration.
# Keep track of regions at different tp_ratio values. tp_ratio->handles
self.src_xfer_handles_by_tp_ratio: dict[int, list[int]] = {}
# Map of engine_id -> {tp_rank: nixl_prepped_dlist_handle (int)}.
self.dst_xfer_side_handles = defaultdict[EngineId, dict[int, int]](dict)
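# Illustrative shapes only (made-up handles / engine id), e.g. for a local
# D TP=1 rank reading from a remote P TP=2 deployment:
#   src_xfer_handles_by_tp_ratio -> {-2: [chunk0_handle, chunk1_handle]}
#   dst_xfer_side_handles -> {"remote-engine-id": {0: dst_handle_rank0,
#                                                  1: dst_handle_rank1}}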
# Map of engine_id -> num_blocks. All ranks in the same deployment will
# have the same number of blocks.
@ -977,103 +979,108 @@ class NixlConnectorWorker:
expected_engine_id: str,
) -> dict[int, str]:
"""Do a NIXL handshake with a remote instance."""
start_time = time.perf_counter()
# NOTE(rob): we need each rank to have a unique port. This is
# a hack to keep us moving. We will switch when moving to etcd
# or where we have a single ZMQ socket in the scheduler.
# Handshake only with the remote TP rank that current local rank will
# pull from. With homogeneous TP it happens to be the same rank_i.
p_remote_rank = self.kv_topo.get_target_remote_rank(remote_tp_size)
# When target instance TP > local TP, we need to perform multiple
# handshakes. Do it in a single background job for simplicity.
# Regardless, only handshake with the remote TP rank(s) that the current
# local rank will read from. Note that with homogeneous TP,
# this happens to be the same single rank_i.
p_remote_ranks = self.kv_topo.get_target_remote_ranks(remote_tp_size)
remote_rank_to_agent_name = {}
path = make_zmq_path("tcp", host, port)
logger.debug(
"Querying metadata on path: %s at remote tp rank %s", path, p_remote_rank
)
# Send query for the request.
with zmq_ctx(zmq.REQ, path) as sock:
msg = msgspec.msgpack.encode((GET_META_MSG, p_remote_rank))
# Set receive timeout to 5 seconds to avoid hanging on dead server
sock.setsockopt(zmq.RCVTIMEO, 5000) # milliseconds
sock.send(msg)
handshake_bytes = sock.recv()
# Decode handshake payload to get compatibility hash
handshake_decoder = msgspec.msgpack.Decoder(NixlHandshakePayload)
try:
handshake_payload = handshake_decoder.decode(handshake_bytes)
except (msgspec.DecodeError, msgspec.ValidationError) as e:
raise RuntimeError(
f"Failed to decode NixlHandshakePayload. This likely indicates "
f"an incompatibility between connector version. Error: {e}"
) from e
got_metadata_time = time.perf_counter()
logger.debug(
"NIXL handshake: get metadata took: %s", got_metadata_time - start_time
)
# Check compatibility hash BEFORE decoding agent metadata
if (
self.enforce_compat_hash
and handshake_payload.compatibility_hash != self.compat_hash
):
raise RuntimeError(
f"NIXL compatibility hash mismatch. "
f"Local: {self.compat_hash}, "
f"Remote: {handshake_payload.compatibility_hash}. "
f"Prefill and decode instances have incompatible configurations. "
f"This may be due to: different vLLM versions, models, dtypes, "
f"KV cache layouts, attention backends, etc. "
f"Both instances must use identical configurations."
f"Disable this check using "
f'--kv-transfer-config \'{{"kv_connector_extra_config": '
f'{{"enforce_handshake_compat": false}}}}\''
for remote_rank in p_remote_ranks:
logger.debug(
"Querying metadata on path: %s at remote tp rank %s",
path,
remote_rank,
)
logger.info(
"NIXL compatibility check passed (hash: %s)",
handshake_payload.compatibility_hash,
)
start_time = time.perf_counter()
# Send query for the request.
msg = msgspec.msgpack.encode((GET_META_MSG, remote_rank))
# Set receive timeout to 5 seconds to avoid hanging on dead server
sock.setsockopt(zmq.RCVTIMEO, 5000) # milliseconds
sock.send(msg)
handshake_bytes = sock.recv()
# Decode agent metadata
metadata_decoder = msgspec.msgpack.Decoder(NixlAgentMetadata)
try:
metadata = metadata_decoder.decode(
handshake_payload.agent_metadata_bytes
)
except (msgspec.DecodeError, msgspec.ValidationError) as e:
# This should not happen if hash matched
raise RuntimeError(
f"Failed to decode NixlAgentMetadata. Error: {e}"
) from e
# Decode handshake payload to get compatibility hash
handshake_decoder = msgspec.msgpack.Decoder(NixlHandshakePayload)
try:
handshake_payload = handshake_decoder.decode(handshake_bytes)
except (msgspec.DecodeError, msgspec.ValidationError) as e:
raise RuntimeError(
f"Failed to decode NixlHandshakePayload. This likely indicates "
f"an incompatibility between connector version. Error: {e}"
) from e
# Ensure engine id matches.
if metadata.engine_id != expected_engine_id:
raise RuntimeError(
f"Remote NIXL agent engine ID mismatch. "
f"Expected {expected_engine_id},"
f"received {metadata.engine_id}."
got_metadata_time = time.perf_counter()
logger.debug(
"NIXL handshake: get metadata took: %s",
got_metadata_time - start_time,
)
# Register Remote agent.
assert metadata.block_size <= self.block_size, (
"nP > nD is not supported yet."
)
remote_agent_name = self.add_remote_agent(
metadata, p_remote_rank, remote_tp_size
)
# Check compatibility hash BEFORE decoding agent metadata
if (
self.enforce_compat_hash
and handshake_payload.compatibility_hash != self.compat_hash
):
raise RuntimeError(
f"NIXL compatibility hash mismatch. "
f"Local: {self.compat_hash}, "
f"Remote: {handshake_payload.compatibility_hash}. "
f"Prefill and decode instances have incompatible "
f"configurations. This may be due to: different vLLM versions,"
f" models, dtypes, KV cache layouts, attention backends, etc. "
f"Both instances must use identical configurations."
f"Disable this check using "
f'--kv-transfer-config \'{{"kv_connector_extra_config": '
f'{{"enforce_handshake_compat": false}}}}\''
)
setup_agent_time = time.perf_counter()
logger.debug(
"NIXL handshake: add agent took: %s",
setup_agent_time - got_metadata_time,
)
logger.info(
"NIXL compatibility check passed (hash: %s)",
handshake_payload.compatibility_hash,
)
# Remote rank -> agent name.
return {p_remote_rank: remote_agent_name}
# Decode agent metadata
metadata_decoder = msgspec.msgpack.Decoder(NixlAgentMetadata)
try:
metadata = metadata_decoder.decode(
handshake_payload.agent_metadata_bytes
)
except (msgspec.DecodeError, msgspec.ValidationError) as e:
# This should not happen if hash matched
raise RuntimeError(
f"Failed to decode NixlAgentMetadata. Error: {e}"
) from e
# Ensure engine id matches.
if metadata.engine_id != expected_engine_id:
raise RuntimeError(
f"Remote NIXL agent engine ID mismatch. "
f"Expected {expected_engine_id},"
f"received {metadata.engine_id}."
)
# Register Remote agent.
remote_agent_name = self.add_remote_agent(
metadata, remote_rank, remote_tp_size
)
setup_agent_time = time.perf_counter()
logger.debug(
"NIXL handshake: add agent took: %s",
setup_agent_time - got_metadata_time,
)
remote_rank_to_agent_name[remote_rank] = remote_agent_name
return remote_rank_to_agent_name
def initialize_host_xfer_buffer(self, kv_caches: dict[str, torch.Tensor]) -> None:
"""
@ -1283,7 +1290,7 @@ class NixlConnectorWorker:
assert len(self.block_len_per_layer) == len(seen_base_addresses)
assert self.num_blocks != 0
self.kv_caches_base_addr[self.engine_id] = seen_base_addresses
self.kv_caches_base_addr[self.engine_id][self.tp_rank] = seen_base_addresses
self.num_regions = len(caches_data)
self.num_layers = len(xfer_buffers.keys())
@ -1310,9 +1317,9 @@ class NixlConnectorWorker:
# Register local/src descr for NIXL xfer.
self.seen_base_addresses = seen_base_addresses
self.src_xfer_side_handle = self.register_local_xfer_handler(self.block_size)
self.src_xfer_side_handles[self.block_size] = self.src_xfer_side_handle
self.src_xfer_handles_by_block_size[self.block_size], self.src_blocks_data = (
self.register_local_xfer_handler(self.block_size)
)
# TODO(mgoin): Hybrid memory allocator is currently disabled for
# models with local attention (Llama 4). Can remove this once enabled.
@ -1340,8 +1347,8 @@ class NixlConnectorWorker:
agent_metadata = NixlAgentMetadata(
engine_id=self.engine_id,
agent_metadata=self.nixl_wrapper.get_agent_metadata(),
kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id],
device_id=self.device_id,
kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id][self.tp_rank],
num_blocks=self.num_blocks,
block_lens=self.block_len_per_layer,
kv_cache_layout=self.kv_cache_layout
@ -1359,7 +1366,7 @@ class NixlConnectorWorker:
def register_local_xfer_handler(
self,
block_size: int,
) -> int:
) -> tuple[int, list[tuple[int, int, int]]]:
"""
Function used to register the local xfer handler with either the local
block_size or the remote block_size.
@ -1407,7 +1414,7 @@ class NixlConnectorWorker:
descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type)
# NIXL_INIT_AGENT to be used for preparations of local descs.
return self.nixl_wrapper.prep_xfer_dlist("NIXL_INIT_AGENT", descs)
return self.nixl_wrapper.prep_xfer_dlist("NIXL_INIT_AGENT", descs), blocks_data
def add_remote_agent(
self,
@ -1421,10 +1428,12 @@ class NixlConnectorWorker:
In particular, handle both homogeneous and heterogeneous TP. The former
requires local rank_i to read from remote rank_i.
The latter, assuming D.world_size > P.world_size, requires that two or
more local TP worker share the xfer from a single TP worker.
The latter, in the case of D.world_size < P.world_size, requires that a
local (D) TP worker reads from multiple remote (P) TP workers.
Conversely, assuming D.world_size > P.world_size, two or more local TP
workers will read from a single remote TP worker.
Here's an example (non-MLA case):
Here's an example for the last case described above (non-MLA):
rank_offset p_remote_tp_rank
(kv split no)
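
As a complement to the example above, a hedged numeric sketch of the P TP > D TP chunking (all sizes, addresses and ranks below are made up, not taken from this diff): with D TP=2 and P TP=4, each D rank splits its local region into two chunks and reads chunk i from remote rank 2 * tp_rank + i.

    local_tp_size, remote_tp_size = 2, 4
    tp_ratio = -(remote_tp_size // local_tp_size)  # -2 under the signed convention
    local_block_len = 4096                         # bytes per local block, per layer (made up)
    base_addr = 0x1000                             # fake local base address

    chunks = []
    for i in range(-tp_ratio):
        chunk_len = local_block_len // (-tp_ratio)  # 2048: one remote rank's share
        chunks.append((base_addr + i * chunk_len, chunk_len))

    # D rank 1 reads chunk 0 from P rank 2 and chunk 1 from P rank 3.
    tp_rank = 1
    remote_ranks = [tp_rank * (-tp_ratio) + i for i in range(-tp_ratio)]
    assert chunks == [(0x1000, 2048), (0x1800, 2048)]
    assert remote_ranks == [2, 3]
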
@ -1474,9 +1483,6 @@ class NixlConnectorWorker:
nixl_agent_meta.agent_metadata
)
# Handle tp_size>num_kv_heads: replicate KV cache.
replicates_kv_cache = self.kv_topo.replicates_kv_cache(engine_id)
# Create dst descs and xfer side handles. TP workers have same #blocks
# so we only register once per engine_id.
# Example:
@ -1490,14 +1496,52 @@ class NixlConnectorWorker:
self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks
# Keep track of remote agent kv caches base addresses.
self.kv_caches_base_addr[engine_id] = nixl_agent_meta.kv_caches_base_addr
self.kv_caches_base_addr[engine_id][remote_tp_rank] = (
nixl_agent_meta.kv_caches_base_addr
)
self._validate_remote_agent_handshake(nixl_agent_meta, remote_tp_size)
# Number of D TP workers reading from a single P TP worker. This is
# 1 when P and D `--tensor-parallel-size` match.
# This is 1 when P and D `--tensor-parallel-size` match. Otherwise,
# this is the ratio between the two sizes.
tp_ratio = self.kv_topo.tp_ratio_from_engine_id(engine_id)
# Handle tp_size>num_kv_heads: replicate KV cache.
indexes_into_remote = (
not self.kv_topo.replicates_kv_cache(engine_id) and tp_ratio > 0
)
logger.debug(
"Registering remote agent (%s, rank %s) memory regions with tp_ratio %s",
engine_id,
remote_tp_rank,
tp_ratio,
)
### (Optional) Register local agent memory regions. MLA is not split.
if (
tp_ratio < 0
and not self.use_mla
and tp_ratio not in self.src_xfer_handles_by_tp_ratio
):
# Remote tp_size > local tp_size: read from multiple remote ranks.
# Logically "split" own regions into |tp_ratio| chunks. Mind that
# we only do this once per remote tp_size (replica-friendly).
self.src_xfer_handles_by_tp_ratio[tp_ratio] = []
for i in range(-tp_ratio):
blocks_data = []
for memory_region in self.src_blocks_data:
addr, local_block_len, own_tp_rank = memory_region
# Computing block len layer by layer allows for different
# block sizes to be used.
remote_block_len = local_block_len // (-tp_ratio)
addr = addr + i * remote_block_len
blocks_data.append((addr, remote_block_len, own_tp_rank))
descs = self.nixl_wrapper.get_xfer_descs(
blocks_data, self.nixl_memory_type
)
handle = self.nixl_wrapper.prep_xfer_dlist("NIXL_INIT_AGENT", descs)
self.src_xfer_handles_by_tp_ratio[tp_ratio].append(handle)
### Register remote agent memory regions
blocks_data = []
# With homogeneous TP, D pulls the whole kv cache from corresponding
@ -1507,14 +1551,19 @@ class NixlConnectorWorker:
# Register all remote blocks, but only the corresponding kv heads.
for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr):
kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=i)
remote_kv_block_len = kv_block_len // block_size_ratio
# Read our whole local region size from remote.
local_block_len = self.get_backend_aware_kv_block_len(layer_idx=i)
remote_kv_block_len = local_block_len // block_size_ratio
if block_size_ratio > 1:
# using remote kv_block_len as transfer unit
kv_block_len = remote_kv_block_len
local_block_len = remote_kv_block_len
if tp_ratio < 0 and not self.use_mla:
# Remote tp is bigger: read a chunk of local region from remote
local_block_len = local_block_len // (-tp_ratio)
rank_offset = (
self.tp_rank % tp_ratio * remote_kv_block_len
if not replicates_kv_cache
if indexes_into_remote
else 0
)
for block_id in range(nixl_agent_meta.num_blocks):
@ -1524,7 +1573,7 @@ class NixlConnectorWorker:
# self.block_len == remote_block_len//tp_ratio bytes.
addr = base_addr + block_offset + rank_offset
# (addr, len, device id)
blocks_data.append((addr, kv_block_len, nixl_agent_meta.device_id))
blocks_data.append((addr, local_block_len, nixl_agent_meta.device_id))
if self.kv_topo.is_kv_layout_blocks_first:
# With FlashInfer index V separately to allow head splitting.
@ -1533,7 +1582,7 @@ class NixlConnectorWorker:
addr = base_addr + block_offset + rank_offset
v_addr = addr + nixl_agent_meta.block_lens[i] // 2
blocks_data.append(
(v_addr, kv_block_len, nixl_agent_meta.device_id)
(v_addr, local_block_len, nixl_agent_meta.device_id)
)
logger.debug(
@ -1546,15 +1595,15 @@ class NixlConnectorWorker:
# Register with NIXL.
descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type)
self.dst_xfer_side_handles[engine_id] = self.nixl_wrapper.prep_xfer_dlist(
remote_agent_name, descs
self.dst_xfer_side_handles[engine_id][remote_tp_rank] = (
self.nixl_wrapper.prep_xfer_dlist(remote_agent_name, descs)
)
if block_size_ratio > 1:
# when prefill with smaller block_size, we need to init a
# new handler with same block_len to match
self.src_xfer_side_handles[nixl_agent_meta.block_size] = (
self.register_local_xfer_handler(nixl_agent_meta.block_size)
self.src_xfer_handles_by_block_size[nixl_agent_meta.block_size] = (
self.register_local_xfer_handler(nixl_agent_meta.block_size)[0]
)
return remote_agent_name
@ -1574,7 +1623,9 @@ class NixlConnectorWorker:
block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(
remote_engine_id
)
assert tp_ratio > 0, "Decode TP cannot be smaller than prefill TP"
# Replicated KV cache (remote tp_size >= num_kv_heads) combined with
# P TP > D TP is not supported.
assert not (tp_ratio < 0 and self.kv_topo.is_kv_replicated(remote_engine_id))
assert not self._use_pallas or tp_ratio == 1, (
"TPU (pallas_v1) DOES NOT support heterogeneous TP yet."
)
@ -1616,17 +1667,29 @@ class NixlConnectorWorker:
"All remote layers must have the same block size"
)
assert (
remote_block_len
== (self.block_len_per_layer[0] * tp_ratio) // block_size_ratio
), (
"Remote P worker KV layer cache must be of shape [2, N, "
"local_kv_heads*tp_ratio, block_size, head_dim] and same dtype."
)
if tp_ratio > 0:
# Remote tp is smaller: remote block_len size is bigger
assert (
remote_block_len
== (self.block_len_per_layer[0] * tp_ratio) // block_size_ratio
), (
"Remote P worker KV layer cache must be of shape [2, N, "
"local_kv_heads*tp_ratio, page_size, head_dim] and same dtype."
) # noqa: E501
else:
assert block_size_ratio == 1, (
"Different local/remote block sizes are not supported when"
" P TP > D TP."
)
# Remote tp is bigger: remote block_len size is smaller
assert remote_block_len == self.block_len_per_layer[0] // (-tp_ratio), (
"Remote P worker KV layer cache must be of shape [2, N, "
"local_kv_heads/tp_ratio, page_size, head_dim] and same dtype."
) # noqa: E501
# TP workers have same #blocks.
# TP workers that handshake with the same remote have the same #blocks.
assert self.dst_num_blocks[remote_engine_id] == nixl_agent_meta.num_blocks
# Same number of regions/~layers.
assert len(nixl_agent_meta.kv_caches_base_addr) == len(self.block_len_per_layer)
def sync_recved_kv_to_device(self, req_id: str, meta: ReqMeta):
@ -1710,7 +1773,7 @@ class NixlConnectorWorker:
)
cache.index_copy_(0, indices, permuted_blocks)
def blocksize_post_process(self, block_ids_per_ratio: dict[float, list[list[int]]]):
def blocksize_post_process(self, block_ids_per_ratio: dict[int, list[list[int]]]):
def _process_local_gt_remote(blocks_to_update, block_size_ratio):
n_kv_heads, block_size, head_size = blocks_to_update.shape[1:]
remote_block_size = block_size // block_size_ratio
@ -1840,7 +1903,7 @@ class NixlConnectorWorker:
notified_req_ids: set[str] = set()
for notifs in self.nixl_wrapper.get_new_notifs().values():
for notif in notifs:
req_id, tp_ratio = notif.decode("utf-8").rsplit(":", 1)
req_id, tp_size = notif.decode("utf-8").rsplit(":", 1)
if (
req_id not in self._reqs_to_send
and req_id not in self._reqs_to_process
@ -1853,9 +1916,22 @@ class NixlConnectorWorker:
)
continue
# NOTE: `tp_ratio` is the opposite when swapping local<>remote
n_consumers = int(tp_size)
tp_ratio = self.kv_topo.tp_ratio(n_consumers)
# Number of reads *per producer* to wait for.
# When remote D TP > local P TP, we expect |tp_ratio| reads per producer.
consumers_per_producer = (
-tp_ratio if n_consumers > self.world_size else 1
)
self.consumer_notification_counts_by_req[req_id] += 1
# Wait all consumers (D) to be done reading before freeing.
if self.consumer_notification_counts_by_req[req_id] == int(tp_ratio):
if (
self.consumer_notification_counts_by_req[req_id]
== consumers_per_producer
):
notified_req_ids.add(req_id)
del self.consumer_notification_counts_by_req[req_id]
self._reqs_to_process.remove(req_id)
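
A toy model of the producer-side counting implied by the notification format above (hypothetical values, not the connector's actual bookkeeping): each notification carries the consumer deployment's TP size, so a P rank in a TP=2 deployment waits for 4 // 2 = 2 reads from a D deployment with TP=4 before freeing the request.

    from collections import defaultdict

    producer_world_size = 2
    counts: defaultdict[str, int] = defaultdict(int)

    def consumer_reads_to_wait_for(consumer_tp_size: int) -> int:
        # When the consumer (D) deployment is larger, |tp_ratio| of its ranks
        # read from each producer (P) rank; otherwise a single read suffices.
        if consumer_tp_size > producer_world_size:
            return consumer_tp_size // producer_world_size
        return 1

    def on_notif(notif: bytes) -> bool:
        req_id, tp_size = notif.decode("utf-8").rsplit(":", 1)
        counts[req_id] += 1
        return counts[req_id] == consumer_reads_to_wait_for(int(tp_size))

    assert on_notif(b"req-0:4") is False   # first of two expected reads
    assert on_notif(b"req-0:4") is True    # all D consumers of this P rank are done
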
@ -1872,7 +1948,7 @@ class NixlConnectorWorker:
"""
done_req_ids: set[str] = set()
for req_id, handles in list(transfers.items()):
in_progress = False
in_progress = []
for handle in handles:
try:
xfer_state = self.nixl_wrapper.check_xfer_state(handle)
@ -1882,7 +1958,7 @@ class NixlConnectorWorker:
self.xfer_stats.record_transfer(res)
self.nixl_wrapper.release_xfer_handle(handle)
elif xfer_state == "PROC":
in_progress = True
in_progress.append(handle)
continue
else:
logger.error(
@ -1892,7 +1968,6 @@ class NixlConnectorWorker:
xfer_state,
)
self._handle_failed_transfer(req_id, handle)
in_progress = False
except Exception:
logger.exception(
"NIXL transfer exception for request %s. "
@ -1900,11 +1975,13 @@ class NixlConnectorWorker:
req_id,
)
self._handle_failed_transfer(req_id, handle)
in_progress = False
if not in_progress:
# Only report request as completed when all transfers are done.
done_req_ids.add(req_id)
del transfers[req_id]
else:
transfers[req_id] = in_progress
return done_req_ids
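
A toy illustration of the partial-completion bookkeeping above (fake transfer states, not the NIXL API): only handles still in "PROC" are kept for the next poll, and a request is reported done only once its handle list drains.

    # Keep only in-progress handles per request; the real states come from
    # nixl_wrapper.check_xfer_state, faked here as a plain dict.
    transfers = {"req-0": [1, 2, 3]}
    fake_states = {1: "DONE", 2: "PROC", 3: "DONE"}

    done_req_ids = set()
    for req_id, handles in list(transfers.items()):
        in_progress = [h for h in handles if fake_states[h] == "PROC"]
        if not in_progress:
            done_req_ids.add(req_id)
            del transfers[req_id]
        else:
            transfers[req_id] = in_progress

    assert done_req_ids == set() and transfers == {"req-0": [2]}
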
def _handle_failed_transfer(self, req_id: str, handle: int):
@ -1982,18 +2059,62 @@ class NixlConnectorWorker:
def _read_blocks_for_req(self, req_id: str, meta: ReqMeta):
assert meta.remote is not None
logger.debug(
"Remote agent %s available, calling _read_blocks for req %s",
meta.remote.engine_id,
req_id,
)
self._read_blocks(
request_id=req_id,
dst_engine_id=meta.remote.engine_id,
remote_request_id=meta.remote.request_id,
local_block_ids=meta.local_physical_block_ids,
remote_block_ids=meta.remote.block_ids,
remote_ranks = self.kv_topo.get_target_remote_ranks_from_engine_id(
meta.remote.engine_id
)
tp_ratio = self.kv_topo.tp_ratio_from_engine_id(meta.remote.engine_id)
# D may have to perform multiple reads from different remote ranks.
for i, remote_rank in enumerate(remote_ranks):
if self.use_mla and tp_ratio < 0 and i > 0:
# MLA opt: when P TP > D TP, only a single read is executed for
# the first remote rank (cache is duplicated)..
break
remote_block_size = self.kv_topo.remote_block_size[meta.remote.engine_id]
logger.debug(
"Remote agent %s available, calling _read_blocks"
" on remote rank %s with remote block size %s for req %s",
meta.remote.engine_id,
remote_rank,
remote_block_size,
req_id,
)
# Get side handles.
if tp_ratio < 0 and not self.use_mla:
assert remote_block_size == self.block_size
# Remote tp_size > local tp_size: we must perform multiple
# reads. Get the memory chunk onto which we will write to.
local_xfer_side_handle = self.src_xfer_handles_by_tp_ratio[tp_ratio][i]
else:
# Single read from remote, we write to the whole memory region.
# Also handle remote block size different from local block size.
local_xfer_side_handle = self.src_xfer_handles_by_block_size[
remote_block_size
]
# Destination handle: remote_engine_id -> remote_rank -> handle.
remote_xfer_side_handle = self.dst_xfer_side_handles[meta.remote.engine_id][
remote_rank
]
self._read_blocks(
request_id=req_id,
dst_engine_id=meta.remote.engine_id,
remote_request_id=meta.remote.request_id,
local_block_ids=meta.local_physical_block_ids,
remote_block_ids=meta.remote.block_ids,
remote_rank=remote_rank,
local_xfer_side_handle=local_xfer_side_handle,
remote_xfer_side_handle=remote_xfer_side_handle,
)
if self.use_mla and tp_ratio < 0:
# ..but we still need to notify the other remote ranks that we
# have the blocks we need so they can update the request state.
notif_id = f"{req_id}:{self.world_size}".encode()
remote_agents = self._remote_agents[meta.remote.engine_id]
for rank_to_notify, agent in remote_agents.items():
if rank_to_notify != remote_rank:
self.nixl_wrapper.send_notif(agent, notif_msg=notif_id)
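
A rough illustration of the handle selection above for the non-MLA, P TP > D TP path (handle values and engine id are made up): a D rank with local TP 1 reading from a P TP=2 deployment issues one read per remote rank, pairing local chunk i with remote rank i's destination handle.

    tp_ratio = -2
    src_xfer_handles_by_tp_ratio = {-2: [101, 102]}           # one local chunk per remote rank
    dst_xfer_side_handles = {"remote-engine": {0: 201, 1: 202}}

    reads = []
    for i, remote_rank in enumerate([0, 1]):                   # D rank 0 reads from P ranks 0 and 1
        local_handle = src_xfer_handles_by_tp_ratio[tp_ratio][i]
        remote_handle = dst_xfer_side_handles["remote-engine"][remote_rank]
        reads.append((local_handle, remote_handle))

    assert reads == [(101, 201), (102, 202)]
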
def _read_blocks(
self,
@ -2002,7 +2123,14 @@ class NixlConnectorWorker:
dst_engine_id: str,
request_id: str,
remote_request_id: str,
remote_rank: int,
local_xfer_side_handle: int,
remote_xfer_side_handle: int,
):
"""
Post a READ point-to-point xfer request from a single local worker to
a single remote worker.
"""
block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(dst_engine_id)
if block_size_ratio > 1:
local_block_ids = self.get_mapped_blocks(
@ -2031,18 +2159,14 @@ class NixlConnectorWorker:
# saturate IB with heterogeneous TP sizes. We should remove the staging
# blocks until we are ready.
# Number of D TP workers that will read from dst P. Propagate tp_ratio
# Number of D TP workers that will read from dst P. Propagate info
# on notification so that dst worker can wait before freeing blocks.
tp_ratio = self.kv_topo.tp_ratio_from_engine_id(dst_engine_id)
notif_id = f"{remote_request_id}:{tp_ratio}".encode()
notif_id = f"{remote_request_id}:{self.world_size}".encode()
# Full prefix cache hit: do not need to read remote blocks,
# just notify P worker that we have the blocks we need.
num_local_blocks = len(local_block_ids)
if num_local_blocks == 0:
remote_rank = self.kv_topo.get_target_remote_rank_from_engine_id(
dst_engine_id
)
agent_name = self._remote_agents[dst_engine_id][remote_rank]
try:
self.nixl_wrapper.send_notif(agent_name, notif_msg=notif_id)
@ -2062,13 +2186,6 @@ class NixlConnectorWorker:
if num_local_blocks < num_remote_blocks:
remote_block_ids = remote_block_ids[-num_local_blocks:]
# Get side handles.
remote_block_size = self.kv_topo.remote_block_size[dst_engine_id]
local_xfer_side_handle = self.src_xfer_side_handles.get(
remote_block_size, self.src_xfer_side_handle
)
remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id]
# NOTE (nicolo) With homogeneous TP, each TP worker loads KV from
# corresponding rank. With heterogeneous TP, fixing D>P, the D tp
# workers will issue xfers to parts of the P worker remote kv caches.
@ -2230,7 +2347,7 @@ class NixlConnectorWorker:
block_ids_np, self._physical_blocks_per_logical_kv_block, block_arange
).tolist()
def get_backend_aware_kv_block_len(self, layer_idx: int):
def get_backend_aware_kv_block_len(self, layer_idx: int) -> int:
"""
Get the block length for one K/V element (K and V have the same size).
@ -2276,11 +2393,16 @@ class NixlConnectorWorker:
for handle in handles:
self.nixl_wrapper.release_xfer_handle(handle)
self._recving_transfers.clear()
if self.src_xfer_side_handle:
self.nixl_wrapper.release_dlist_handle(self.src_xfer_side_handle)
self.src_xfer_side_handle = 0
for dst_xfer_side_handle in self.dst_xfer_side_handles.values():
self.nixl_wrapper.release_dlist_handle(dst_xfer_side_handle)
for handle in self.src_xfer_handles_by_block_size.values():
self.nixl_wrapper.release_dlist_handle(handle)
self.src_xfer_handles_by_block_size.clear()
for handles in self.src_xfer_handles_by_tp_ratio.values():
for handle in handles:
self.nixl_wrapper.release_dlist_handle(handle)
self.src_xfer_handles_by_tp_ratio.clear()
for dst_xfer_side_handles in self.dst_xfer_side_handles.values():
for dst_xfer_side_handle in dst_xfer_side_handles.values():
self.nixl_wrapper.release_dlist_handle(dst_xfer_side_handle)
self.dst_xfer_side_handles.clear()
for remote_agents in self._remote_agents.values():
for agent_name in remote_agents.values():