[NIXL] Support P tensor-parallel-size > D tensor-parallel-size (#27274)

Signed-off-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
Nicolò Lucchesi 2025-12-18 04:53:30 +01:00 committed by GitHub
parent fd8afdf38d
commit bc3700e0cd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 556 additions and 212 deletions

View File

@ -8,9 +8,12 @@ SCRIPT="v1/kv_connector/nixl_integration/run_accuracy_test.sh"
configs=( configs=(
"GPU_MEMORY_UTILIZATION=0.6 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=2" "GPU_MEMORY_UTILIZATION=0.6 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=2"
"GPU_MEMORY_UTILIZATION=0.6 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2" "GPU_MEMORY_UTILIZATION=0.6 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2"
"GPU_MEMORY_UTILIZATION=0.6 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=1"
"GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" # MLA case "GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" # MLA case
"GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" "GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny"
"GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=1 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny"
"DP_EP=1 GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" # MLA+P-TP1, D-DPEP=2 (TP=1) "DP_EP=1 GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" # MLA+P-TP1, D-DPEP=2 (TP=1)
"DP_EP=1 GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" # MLA+P-TP2, D-DPEP=2 (TP=1)
) )
run_tests() { run_tests() {

View File

@ -391,6 +391,8 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self._hand_shake_latency = hand_shake_latency self._hand_shake_latency = hand_shake_latency
self.kv_cache_layout = kv_cache_layout self.kv_cache_layout = kv_cache_layout
# Mock register_kv_caches attribute needed for tests that do not call it.
self.src_xfer_handles_by_block_size = {self.block_size: 1}
def _nixl_handshake( def _nixl_handshake(
self, host: str, port: int, remote_tp_size: int, expected_engine_id: str self, host: str, port: int, remote_tp_size: int, expected_engine_id: str
@ -407,22 +409,43 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
assert expected_engine_id == self.REMOTE_ENGINE_ID assert expected_engine_id == self.REMOTE_ENGINE_ID
remote_agent_name = self.add_remote_agent( # Adjust remote block length metadata to satisfy heterogeneous TP
NixlAgentMetadata( # invariants enforced during handshake validation.
engine_id=self.REMOTE_ENGINE_ID, remote_block_lens = list(self.block_len_per_layer)
agent_metadata=FakeNixlWrapper.AGENT_METADATA, tp_ratio = self.kv_topo.tp_ratio(remote_tp_size)
kv_caches_base_addr=[0], if remote_tp_size > self.world_size:
device_id=0, # P TP > D TP case, block_len of remote is smaller
num_blocks=1, remote_block_lens = [
block_lens=self.block_len_per_layer, block_len // (-tp_ratio) for block_len in remote_block_lens
# `self.kv_cache_layout` is only forced to HND when vllm engine ]
# is started. We mock HND here. elif remote_tp_size < self.world_size:
kv_cache_layout="HND", remote_block_lens = [
block_size=self.block_size, block_len * tp_ratio for block_len in remote_block_lens
), ]
remote_tp_size=remote_tp_size,
) # When remote tp_size > local tp_size, handshake with multiple
return {0: remote_agent_name} # remote ranks.
num_hanshakes = 1 if tp_ratio > 0 else -tp_ratio
remote_agents: dict[int, str] = {}
for remote_tp_rank in range(num_hanshakes):
remote_agent_name = self.add_remote_agent(
NixlAgentMetadata(
engine_id=self.REMOTE_ENGINE_ID,
agent_metadata=FakeNixlWrapper.AGENT_METADATA,
kv_caches_base_addr=[0],
device_id=remote_tp_rank,
num_blocks=1,
block_lens=remote_block_lens,
# `self.kv_cache_layout` is only forced to HND when vllm engine
# is started. We mock HND here.
kv_cache_layout="HND",
block_size=self.block_size,
),
remote_tp_rank=remote_tp_rank,
remote_tp_size=remote_tp_size,
)
remote_agents[remote_tp_rank] = remote_agent_name
return remote_agents
class TestNixlHandshake: class TestNixlHandshake:
@ -453,7 +476,13 @@ class TestNixlHandshake:
vllm_config, connector.engine_id, hand_shake_latency=0 vllm_config, connector.engine_id, hand_shake_latency=0
) )
assert isinstance(connector.connector_worker.nixl_wrapper, FakeNixlWrapper) assert isinstance(connector.connector_worker.nixl_wrapper, FakeNixlWrapper)
connector.connector_worker.nixl_wrapper.set_cycles_before_xfer_done(3) worker = connector.connector_worker
worker.nixl_wrapper.set_cycles_before_xfer_done(3)
# simulate handshake
worker.dst_xfer_side_handles = {
FakeNixlConnectorWorker.REMOTE_ENGINE_ID: {0: 1}
}
worker.kv_cache_layout = "HND"
num_xfers = 4 num_xfers = 4
while True: while True:
# For the same request_id, initiate multiple xfers across different # For the same request_id, initiate multiple xfers across different
@ -567,6 +596,171 @@ class TestNixlHandshake:
return return
raise TimeoutError("Took too long to complete async handshake.") raise TimeoutError("Took too long to complete async handshake.")
@patch(
    "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
    FakeNixlWrapper,
)
# NOTE(review): this used to be parametrized over [1, 2], but the body
# immediately overwrote the value with 1, so both cases ran the exact same
# scenario. Only the value that is really exercised is kept; re-add 2 once the
# local TP=2 setup is actually wired into the fake worker.
@pytest.mark.parametrize("local_tp_size", [1])
def test_prefill_tp_size_greater_than_decode_tp_size(
    self, local_tp_size: int, dist_init
):
    """
    Verify that a handshake with a remote whose TP size is larger than the
    local TP size succeeds, for two different remote configurations.

    After each handshake the test asserts that:
      * one remote agent is returned per remote rank (`tp_ratio` of them),
      * the remote TP size is recorded for the remote engine,
      * `kv_topo` reports the ratio as a negative number,
      * `src_xfer_handles_by_tp_ratio` holds `tp_ratio` local handles,
      * `dst_xfer_side_handles` holds one handle per remote rank.
    """
    vllm_config = create_vllm_config()
    vllm_config.parallel_config.tensor_parallel_size = local_tp_size

    connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
    connector.connector_worker = FakeNixlConnectorWorker(
        vllm_config, connector.engine_id, hand_shake_latency=0
    )
    worker = connector.connector_worker

    # Minimal local registration params used by add_remote_agent
    worker.slot_size_per_layer = [4096]
    worker.block_len_per_layer = [4096 * worker.block_size]
    worker.num_blocks = 1
    worker.dst_num_blocks[worker.engine_id] = worker.num_blocks
    worker.src_blocks_data = [(0, worker.block_len_per_layer[0], worker.tp_rank)]

    def check_handshake(remote_agents: dict[int, str], remote_tp_size: int) -> None:
        # The handshake result is passed in explicitly instead of being read
        # from the enclosing scope, so each check is tied to the call it
        # is meant to validate.
        tp_ratio = remote_tp_size // local_tp_size
        assert set(remote_agents.keys()) == set(range(tp_ratio))

        remote_engine_id = worker.REMOTE_ENGINE_ID
        assert worker._tp_size[remote_engine_id] == remote_tp_size
        # Remote TP > local TP is reported as a negative ratio.
        assert -tp_ratio == worker.kv_topo.tp_ratio_from_engine_id(remote_engine_id)
        # Ensure src_xfer_handles_by_tp_ratio is populated with tp_ratio chunks
        assert -tp_ratio in worker.src_xfer_handles_by_tp_ratio
        assert len(worker.src_xfer_handles_by_tp_ratio[-tp_ratio]) == tp_ratio
        assert remote_engine_id in worker.dst_xfer_side_handles
        assert set(worker.dst_xfer_side_handles[remote_engine_id].keys()) == set(
            range(tp_ratio)
        )

    remote_agents = worker._nixl_handshake(
        host="localhost",
        port=1234,
        remote_tp_size=2,
        expected_engine_id=worker.REMOTE_ENGINE_ID,
    )
    check_handshake(remote_agents, 2)

    # NOTE flexibility: a second remote with higher number of ranks is
    # discovered. This is not a scenario we actively support right now, but
    # the connector allows it.
    worker.REMOTE_ENGINE_ID = "remote_engine_2"
    remote_agents = worker._nixl_handshake(
        host="localhost",
        port=1234,
        remote_tp_size=6,
        expected_engine_id=worker.REMOTE_ENGINE_ID,
    )
    check_handshake(remote_agents, 6)
@patch(
    "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
    FakeNixlWrapper,
)
# NOTE(review): this used to be parametrized over [1, 2] while the argument
# was never read, so the identical scenario ran twice. The parameter now feeds
# the decode-side TP size, which the original hard-coded to 1 -- presumably
# "local" means the D side here, as in the non-MLA test; confirm.
@pytest.mark.parametrize("local_tp_size", [1])
def test_prefill_tp_size_greater_than_decode_tp_size_mla(
    self, local_tp_size: int, dist_init
):
    """
    Verify that, with MLA enabled and P TP=2, a read notification tagged
    with the D TP size marks the request as finished sending on both P
    workers.

    The test emulates the two P ranks with two connectors, feeds each the
    same notification, and then asserts that:
      * `get_finished()` reports the request on both workers,
      * the aggregated output marks the request as finished sending,
      * both workers dropped the request from their pending-send state.
    """
    vllm_config = create_vllm_config()
    d_tp_size = local_tp_size
    p_tp_size = 2

    # Build two separate connectors/workers to emulate P TP=2 ranks.
    connectors = []
    for _ in range(p_tp_size):
        conn = NixlConnector(vllm_config, KVConnectorRole.WORKER)
        conn.connector_worker = FakeNixlConnectorWorker(
            vllm_config, conn.engine_id, hand_shake_latency=0
        )
        connectors.append(conn)
    conn_p0, conn_p1 = connectors

    # Force P world size to 2 for both workers and emulate distinct tp_ranks.
    # Also enable MLA path so that expected_finished_count is updated.
    for rank, conn in enumerate(connectors):
        worker = conn.connector_worker
        worker.world_size = p_tp_size
        worker.kv_topo.remote_tp_size = {worker.engine_id: p_tp_size}
        worker.tp_rank = rank
        worker.use_mla = True

    req_id = "req-ep-dp2-p0"
    now = time.perf_counter()
    # Register a request on P that is waiting for consumers to read
    # (both workers track it).
    for conn in connectors:
        conn.connector_worker._reqs_to_send[req_id] = now + 10.0
        conn.connector_worker._reqs_to_process.add(req_id)

    # Simulate a read notification coming from D with (tp=1, dp=2).
    notif = f"{req_id}:{d_tp_size}".encode()
    # Both P workers receive the same notification from D.
    for conn in connectors:
        conn.connector_worker.nixl_wrapper.get_new_notifs = lambda: {
            "agent": [notif]
        }  # type: ignore[method-assign]

    # Trigger notification processing via get_finished().
    done_sending0, _ = conn_p0.get_finished(finished_req_ids=set())
    done_sending1, _ = conn_p1.get_finished(finished_req_ids=set())
    assert req_id in done_sending0 and req_id in done_sending1

    # E2E aggregation: ensure the aggregated output marks the request
    # as finished using the connector's expected_finished_count.
    from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput

    aggregator = KVOutputAggregator.from_connector(conn_p0, world_size=2)

    def make_runner_output(finished_sending):
        # One per-rank runner output that only carries the finished-sending set.
        return ModelRunnerOutput(
            req_ids=[req_id],
            req_id_to_index={req_id: 0},
            sampled_token_ids=[[0]],
            logprobs=None,
            prompt_logprobs_dict={},
            pooler_output=[None],
            kv_connector_output=KVConnectorOutput(
                finished_sending=finished_sending,
                finished_recving=None,
            ),
        )

    out0 = make_runner_output(done_sending0)
    out1 = make_runner_output(done_sending1)
    aggregated = aggregator.aggregate([out0, out1], output_rank=0)
    assert aggregated.kv_connector_output is not None
    assert aggregated.kv_connector_output.finished_sending == {req_id}

    # Producers cleaned up state for the finished request.
    for conn in connectors:
        assert req_id not in conn.connector_worker._reqs_to_send
        assert req_id not in conn.connector_worker._reqs_to_process
@patch( @patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
FakeNixlWrapper, FakeNixlWrapper,
@ -585,6 +779,9 @@ class TestNixlHandshake:
connector.connector_worker = FakeNixlConnectorWorker( connector.connector_worker = FakeNixlConnectorWorker(
vllm_config, connector.engine_id vllm_config, connector.engine_id
) )
# Register (mocked) local xfer handler
# worker = connector.connector_worker
# worker.src_xfer_handles_by_block_size = {worker.block_size: 1}
metadata = NixlConnectorMetadata() metadata = NixlConnectorMetadata()
total_reqs = 5 total_reqs = 5
for i in range(total_reqs): for i in range(total_reqs):
@ -672,7 +869,6 @@ class TestNixlHandshake:
with pytest.raises(RuntimeError): with pytest.raises(RuntimeError):
# mismatched layout is expected to fail # mismatched layout is expected to fail
worker.add_remote_agent(meta, remote_tp_size=2) worker.add_remote_agent(meta, remote_tp_size=2)
with pytest.raises(AssertionError):
worker.add_remote_agent(meta, remote_tp_size=1) worker.add_remote_agent(meta, remote_tp_size=1)
@patch( @patch(
@ -1357,8 +1553,11 @@ def test_shutdown_cleans_up_resources(dist_init):
patch.object(nixl_wrapper, "deregister_memory") as mock_dereg, patch.object(nixl_wrapper, "deregister_memory") as mock_dereg,
): ):
worker._recving_transfers = {"req1": [123]} worker._recving_transfers = {"req1": [123]}
worker.src_xfer_side_handle = 456 # Mock register_kv_cache which registers local handle
worker.dst_xfer_side_handles = {"engine1": 789} worker.src_xfer_handles_by_block_size = {worker.block_size: 455}
# P TP = 2 * D TP case, we should register 2 local handles
worker.src_xfer_handles_by_tp_ratio = {-2: [456, 457]}
worker.dst_xfer_side_handles = {"engine1": {0: 789}}
worker._remote_agents = {"engine1": {0: "agent1"}} worker._remote_agents = {"engine1": {0: "agent1"}}
worker._registered_descs = ["desc1", "desc2"] worker._registered_descs = ["desc1", "desc2"]
@ -1379,8 +1578,10 @@ def test_shutdown_cleans_up_resources(dist_init):
mock_listener.join.assert_called_once() mock_listener.join.assert_called_once()
mock_rel_xfer.assert_called_once_with(123) mock_rel_xfer.assert_called_once_with(123)
assert mock_rel_dlist.call_count == 2 assert mock_rel_dlist.call_count == 4
mock_rel_dlist.assert_any_call(456) # src handle mock_rel_dlist.assert_any_call(455) # src handle (whole region)
mock_rel_dlist.assert_any_call(456) # src handle (1st chunk)
mock_rel_dlist.assert_any_call(457) # src handle (2nd chunk)
mock_rel_dlist.assert_any_call(789) # dst handle mock_rel_dlist.assert_any_call(789) # dst handle
mock_rem_agent.assert_called_once_with("agent1") mock_rem_agent.assert_called_once_with("agent1")
assert mock_dereg.call_count == 2 assert mock_dereg.call_count == 2

View File

@ -21,6 +21,8 @@ if TYPE_CHECKING:
logger = init_logger(__name__) logger = init_logger(__name__)
EngineId = str
def get_kv_connector_cache_layout(): def get_kv_connector_cache_layout():
# NOTE (NickLucche) When running disaggregated PD with NIXL, HND layout is # NOTE (NickLucche) When running disaggregated PD with NIXL, HND layout is
@ -209,12 +211,12 @@ class TpKVTopology:
""" """
tp_rank: int tp_rank: int
remote_tp_size: dict[str, int] remote_tp_size: dict[EngineId, int]
is_mla: bool is_mla: bool
total_num_kv_heads: int total_num_kv_heads: int
attn_backend: type[AttentionBackend] attn_backend: type[AttentionBackend]
engine_id: str engine_id: EngineId
remote_block_size: dict[str, int] remote_block_size: dict[EngineId, int]
def __post_init__(self): def __post_init__(self):
# Figure out whether the first dimension of the cache is K/V # Figure out whether the first dimension of the cache is K/V
@ -256,18 +258,28 @@ class TpKVTopology:
Calculate the tensor parallel ratio between local and remote TP. Calculate the tensor parallel ratio between local and remote TP.
We can think of it as the number of local TP workers-per-remote TP We can think of it as the number of local TP workers-per-remote TP
workers. Local workers will read from the same remote TP worker in workers. Local workers will read from the same remote TP worker in
groups of size `tp_ratio`. groups of size `tp_ratio`. If remote tp_size > local tp_size, the
ratio is flipped (remote_size/local_size) and the returned value is
negative.
""" """
assert self.tp_size % remote_tp_size == 0, ( if self.tp_size >= remote_tp_size:
f"Local tensor parallel size {self.tp_size} is not divisible " assert self.tp_size % remote_tp_size == 0, (
f"by remote tensor parallel size {remote_tp_size}." f"Local tensor parallel size {self.tp_size} is not divisible "
f"by remote tensor parallel size {remote_tp_size}."
)
return self.tp_size // remote_tp_size
assert remote_tp_size % self.tp_size == 0, (
f"Remote tensor parallel size {remote_tp_size} is not divisible "
f"by local tensor parallel size {self.tp_size}."
) )
return self.tp_size // remote_tp_size # P TP > D TP case, return the ratio as negative
return -remote_tp_size // self.tp_size
def block_size_ratio( def block_size_ratio(
self, self,
remote_block_size: int, remote_block_size: int,
) -> float: ) -> int:
""" """
Calculate the block size ratio between local and remote TP. Calculate the block size ratio between local and remote TP.
""" """
@ -279,19 +291,19 @@ class TpKVTopology:
def tp_ratio_from_engine_id( def tp_ratio_from_engine_id(
self, self,
remote_engine_id: str, remote_engine_id: EngineId,
) -> int: ) -> int:
remote_tp_size = self.remote_tp_size[remote_engine_id] remote_tp_size = self.remote_tp_size[remote_engine_id]
return self.tp_ratio(remote_tp_size) return self.tp_ratio(remote_tp_size)
def block_size_ratio_from_engine_id( def block_size_ratio_from_engine_id(
self, self,
remote_engine_id: str, remote_engine_id: EngineId,
) -> float: ) -> int:
remote_block_size = self.remote_block_size[remote_engine_id] remote_block_size = self.remote_block_size[remote_engine_id]
return self.block_size_ratio(remote_block_size) return self.block_size_ratio(remote_block_size)
def is_kv_replicated(self, engine_id: str) -> bool: def is_kv_replicated(self, engine_id: EngineId) -> bool:
""" """
Whether the KV cache is replicated across TP workers due to the Whether the KV cache is replicated across TP workers due to the
number of TP workers being greater than the number of KV heads. number of TP workers being greater than the number of KV heads.
@ -299,24 +311,30 @@ class TpKVTopology:
tp_size = self.remote_tp_size[engine_id] tp_size = self.remote_tp_size[engine_id]
return tp_size // self.total_num_kv_heads >= 1 return tp_size // self.total_num_kv_heads >= 1
def replicates_kv_cache(self, remote_engine_id: str) -> bool: def replicates_kv_cache(self, remote_engine_id: EngineId) -> bool:
# MLA is always replicated as the hidden dim can't be split. # MLA is always replicated as the hidden dim can't be split.
return self.is_mla or self.is_kv_replicated(remote_engine_id) return self.is_mla or self.is_kv_replicated(remote_engine_id)
def get_target_remote_rank( def get_target_remote_ranks(
self, self,
remote_tp_size: int, remote_tp_size: int,
) -> int: ) -> list[int]:
""" """
Get the remote TP rank (on P) that the current local TP rank Get the remote TP rank (on P) that the current local TP rank
(on D) will read from. (on D) will read from. When remote tp_size > local tp_size, we
read from multiple remote ranks.
""" """
tp_ratio = self.tp_ratio(remote_tp_size) tp_ratio = self.tp_ratio(remote_tp_size)
return self.tp_rank // tp_ratio if tp_ratio > 0:
return [self.tp_rank // tp_ratio]
def get_target_remote_rank_from_engine_id( # P TP > D TP case, D reads from |tp_ratio| remote workers.
tp_ratio = -tp_ratio
return [self.tp_rank * tp_ratio + i for i in range(tp_ratio)]
def get_target_remote_ranks_from_engine_id(
self, self,
remote_engine_id: str, remote_engine_id: EngineId,
) -> int: ) -> list[int]:
remote_tp_size = self.remote_tp_size[remote_engine_id] remote_tp_size = self.remote_tp_size[remote_engine_id]
return self.get_target_remote_rank(remote_tp_size) return self.get_target_remote_ranks(remote_tp_size)

View File

@ -23,7 +23,7 @@ from vllm import envs
from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.selector import get_attn_backend from vllm.attention.selector import get_attn_backend
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.utils import TpKVTopology from vllm.distributed.kv_transfer.kv_connector.utils import EngineId, TpKVTopology
from vllm.distributed.kv_transfer.kv_connector.v1.base import ( from vllm.distributed.kv_transfer.kv_connector.v1.base import (
CopyBlocksOp, CopyBlocksOp,
KVConnectorBase_V1, KVConnectorBase_V1,
@ -56,7 +56,6 @@ if TYPE_CHECKING:
from vllm.v1.request import Request from vllm.v1.request import Request
TransferHandle = int TransferHandle = int
EngineId = str
ReqId = str ReqId = str
# #
@ -873,9 +872,10 @@ class NixlConnectorWorker:
self.copy_blocks: CopyBlocksOp | None = None self.copy_blocks: CopyBlocksOp | None = None
# Map of engine_id -> kv_caches_base_addr. For TP case, each local # Map of engine_id -> kv_caches_base_addr. For TP case, each local
# rank will still only pull from a single remote TP worker.
self.kv_caches_base_addr: dict[EngineId, list[int]] = {}
self.device_id: int = 0 self.device_id: int = 0
# Current rank may pull from multiple remote TP workers.
# EngineId, dict[int, list[int]] -> engine_id, tp_rank, base_addr_for_layer
self.kv_caches_base_addr = defaultdict[EngineId, dict[int, list[int]]](dict)
# Number of NIXL regions. Currently one region per cache # Number of NIXL regions. Currently one region per cache
# (so 1 per layer for MLA, otherwise 2 per layer) # (so 1 per layer for MLA, otherwise 2 per layer)
@ -883,10 +883,12 @@ class NixlConnectorWorker:
self.num_layers = 0 self.num_layers = 0
# nixl_prepped_dlist_handle. # nixl_prepped_dlist_handle.
self.src_xfer_side_handle: int = 0 self.src_xfer_handles_by_block_size: dict[int, int] = {}
self.src_xfer_side_handles: dict[int, int] = {} # Populated dynamically during handshake based on remote configuration.
# Map of engine_id -> nixl_prepped_dlist_handle (int)]. # Keep track of regions at different tp_ratio values. tp_ratio->handles
self.dst_xfer_side_handles: dict[EngineId, int] = {} self.src_xfer_handles_by_tp_ratio: dict[int, list[int]] = {}
# Map of engine_id -> {tp_rank: nixl_prepped_dlist_handle (int)}.
self.dst_xfer_side_handles = defaultdict[EngineId, dict[int, int]](dict)
# Map of engine_id -> num_blocks. All ranks in the same deployment will # Map of engine_id -> num_blocks. All ranks in the same deployment will
# have the same number of blocks. # have the same number of blocks.
@ -977,103 +979,108 @@ class NixlConnectorWorker:
expected_engine_id: str, expected_engine_id: str,
) -> dict[int, str]: ) -> dict[int, str]:
"""Do a NIXL handshake with a remote instance.""" """Do a NIXL handshake with a remote instance."""
# When target instance TP > local TP, we need to perform multiple
start_time = time.perf_counter() # handshakes. Do it in a single background job for simplicity.
# Regardless, only handshake with the remote TP rank(s) that current
# NOTE(rob): we need each rank to have a unique port. This is # local rank will read from. Note that with homogeneous TP,
# a hack to keep us moving. We will switch when moving to etcd # this happens to be the same single rank_i.
# or where we have a single ZMQ socket in the scheduler. p_remote_ranks = self.kv_topo.get_target_remote_ranks(remote_tp_size)
remote_rank_to_agent_name = {}
# Handshake only with the remote TP rank that current local rank will
# pull from. With homogeneous TP it happens to be the same rank_i.
p_remote_rank = self.kv_topo.get_target_remote_rank(remote_tp_size)
path = make_zmq_path("tcp", host, port) path = make_zmq_path("tcp", host, port)
logger.debug(
"Querying metadata on path: %s at remote tp rank %s", path, p_remote_rank
)
# Send query for the request.
with zmq_ctx(zmq.REQ, path) as sock: with zmq_ctx(zmq.REQ, path) as sock:
msg = msgspec.msgpack.encode((GET_META_MSG, p_remote_rank)) for remote_rank in p_remote_ranks:
# Set receive timeout to 5 seconds to avoid hanging on dead server logger.debug(
sock.setsockopt(zmq.RCVTIMEO, 5000) # milliseconds "Querying metadata on path: %s at remote tp rank %s",
sock.send(msg) path,
handshake_bytes = sock.recv() remote_rank,
# Decode handshake payload to get compatibility hash
handshake_decoder = msgspec.msgpack.Decoder(NixlHandshakePayload)
try:
handshake_payload = handshake_decoder.decode(handshake_bytes)
except (msgspec.DecodeError, msgspec.ValidationError) as e:
raise RuntimeError(
f"Failed to decode NixlHandshakePayload. This likely indicates "
f"an incompatibility between connector version. Error: {e}"
) from e
got_metadata_time = time.perf_counter()
logger.debug(
"NIXL handshake: get metadata took: %s", got_metadata_time - start_time
)
# Check compatibility hash BEFORE decoding agent metadata
if (
self.enforce_compat_hash
and handshake_payload.compatibility_hash != self.compat_hash
):
raise RuntimeError(
f"NIXL compatibility hash mismatch. "
f"Local: {self.compat_hash}, "
f"Remote: {handshake_payload.compatibility_hash}. "
f"Prefill and decode instances have incompatible configurations. "
f"This may be due to: different vLLM versions, models, dtypes, "
f"KV cache layouts, attention backends, etc. "
f"Both instances must use identical configurations."
f"Disable this check using "
f'--kv-transfer-config \'{{"kv_connector_extra_config": '
f'{{"enforce_handshake_compat": false}}}}\''
) )
logger.info( start_time = time.perf_counter()
"NIXL compatibility check passed (hash: %s)", # Send query for the request.
handshake_payload.compatibility_hash, msg = msgspec.msgpack.encode((GET_META_MSG, remote_rank))
) # Set receive timeout to 5 seconds to avoid hanging on dead server
sock.setsockopt(zmq.RCVTIMEO, 5000) # milliseconds
sock.send(msg)
handshake_bytes = sock.recv()
# Decode agent metadata # Decode handshake payload to get compatibility hash
metadata_decoder = msgspec.msgpack.Decoder(NixlAgentMetadata) handshake_decoder = msgspec.msgpack.Decoder(NixlHandshakePayload)
try: try:
metadata = metadata_decoder.decode( handshake_payload = handshake_decoder.decode(handshake_bytes)
handshake_payload.agent_metadata_bytes except (msgspec.DecodeError, msgspec.ValidationError) as e:
) raise RuntimeError(
except (msgspec.DecodeError, msgspec.ValidationError) as e: f"Failed to decode NixlHandshakePayload. This likely indicates "
# This should not happen if hash matched f"an incompatibility between connector version. Error: {e}"
raise RuntimeError( ) from e
f"Failed to decode NixlAgentMetadata. Error: {e}"
) from e
# Ensure engine id matches. got_metadata_time = time.perf_counter()
if metadata.engine_id != expected_engine_id: logger.debug(
raise RuntimeError( "NIXL handshake: get metadata took: %s",
f"Remote NIXL agent engine ID mismatch. " got_metadata_time - start_time,
f"Expected {expected_engine_id},"
f"received {metadata.engine_id}."
) )
# Register Remote agent. # Check compatibility hash BEFORE decoding agent metadata
assert metadata.block_size <= self.block_size, ( if (
"nP > nD is not supported yet." self.enforce_compat_hash
) and handshake_payload.compatibility_hash != self.compat_hash
remote_agent_name = self.add_remote_agent( ):
metadata, p_remote_rank, remote_tp_size raise RuntimeError(
) f"NIXL compatibility hash mismatch. "
f"Local: {self.compat_hash}, "
f"Remote: {handshake_payload.compatibility_hash}. "
f"Prefill and decode instances have incompatible "
f"configurations. This may be due to: different vLLM versions,"
f" models, dtypes, KV cache layouts, attention backends, etc. "
f"Both instances must use identical configurations."
f"Disable this check using "
f'--kv-transfer-config \'{{"kv_connector_extra_config": '
f'{{"enforce_handshake_compat": false}}}}\''
)
setup_agent_time = time.perf_counter() logger.info(
logger.debug( "NIXL compatibility check passed (hash: %s)",
"NIXL handshake: add agent took: %s", handshake_payload.compatibility_hash,
setup_agent_time - got_metadata_time, )
)
# Remote rank -> agent name. # Decode agent metadata
return {p_remote_rank: remote_agent_name} metadata_decoder = msgspec.msgpack.Decoder(NixlAgentMetadata)
try:
metadata = metadata_decoder.decode(
handshake_payload.agent_metadata_bytes
)
except (msgspec.DecodeError, msgspec.ValidationError) as e:
# This should not happen if hash matched
raise RuntimeError(
f"Failed to decode NixlAgentMetadata. Error: {e}"
) from e
# Ensure engine id matches.
if metadata.engine_id != expected_engine_id:
raise RuntimeError(
f"Remote NIXL agent engine ID mismatch. "
f"Expected {expected_engine_id},"
f"received {metadata.engine_id}."
)
# Ensure engine id matches.
if metadata.engine_id != expected_engine_id:
raise RuntimeError(
f"Remote NIXL agent engine ID mismatch. "
f"Expected {expected_engine_id},"
f"received {metadata.engine_id}."
)
setup_agent_time = time.perf_counter()
# Register Remote agent.
remote_agent_name = self.add_remote_agent(
metadata, remote_rank, remote_tp_size
)
logger.debug(
"NIXL handshake: add agent took: %s",
setup_agent_time - got_metadata_time,
)
remote_rank_to_agent_name[remote_rank] = remote_agent_name
return remote_rank_to_agent_name
def initialize_host_xfer_buffer(self, kv_caches: dict[str, torch.Tensor]) -> None: def initialize_host_xfer_buffer(self, kv_caches: dict[str, torch.Tensor]) -> None:
""" """
@ -1283,7 +1290,7 @@ class NixlConnectorWorker:
assert len(self.block_len_per_layer) == len(seen_base_addresses) assert len(self.block_len_per_layer) == len(seen_base_addresses)
assert self.num_blocks != 0 assert self.num_blocks != 0
self.kv_caches_base_addr[self.engine_id] = seen_base_addresses self.kv_caches_base_addr[self.engine_id][self.tp_rank] = seen_base_addresses
self.num_regions = len(caches_data) self.num_regions = len(caches_data)
self.num_layers = len(xfer_buffers.keys()) self.num_layers = len(xfer_buffers.keys())
@ -1310,9 +1317,9 @@ class NixlConnectorWorker:
# Register local/src descr for NIXL xfer. # Register local/src descr for NIXL xfer.
self.seen_base_addresses = seen_base_addresses self.seen_base_addresses = seen_base_addresses
self.src_xfer_side_handle = self.register_local_xfer_handler(self.block_size) self.src_xfer_handles_by_block_size[self.block_size], self.src_blocks_data = (
self.register_local_xfer_handler(self.block_size)
self.src_xfer_side_handles[self.block_size] = self.src_xfer_side_handle )
# TODO(mgoin): Hybrid memory allocator is currently disabled for # TODO(mgoin): Hybrid memory allocator is currently disabled for
# models with local attention (Llama 4). Can remove this once enabled. # models with local attention (Llama 4). Can remove this once enabled.
@ -1340,8 +1347,8 @@ class NixlConnectorWorker:
agent_metadata = NixlAgentMetadata( agent_metadata = NixlAgentMetadata(
engine_id=self.engine_id, engine_id=self.engine_id,
agent_metadata=self.nixl_wrapper.get_agent_metadata(), agent_metadata=self.nixl_wrapper.get_agent_metadata(),
kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id],
device_id=self.device_id, device_id=self.device_id,
kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id][self.tp_rank],
num_blocks=self.num_blocks, num_blocks=self.num_blocks,
block_lens=self.block_len_per_layer, block_lens=self.block_len_per_layer,
kv_cache_layout=self.kv_cache_layout kv_cache_layout=self.kv_cache_layout
@ -1359,7 +1366,7 @@ class NixlConnectorWorker:
def register_local_xfer_handler( def register_local_xfer_handler(
self, self,
block_size: int, block_size: int,
) -> int: ) -> tuple[int, list[tuple[int, int, int]]]:
""" """
Function used for register local xfer handler with local block_size or Function used for register local xfer handler with local block_size or
Remote block_size. Remote block_size.
@ -1407,7 +1414,7 @@ class NixlConnectorWorker:
descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type) descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type)
# NIXL_INIT_AGENT to be used for preparations of local descs. # NIXL_INIT_AGENT to be used for preparations of local descs.
return self.nixl_wrapper.prep_xfer_dlist("NIXL_INIT_AGENT", descs) return self.nixl_wrapper.prep_xfer_dlist("NIXL_INIT_AGENT", descs), blocks_data
def add_remote_agent( def add_remote_agent(
self, self,
@ -1421,10 +1428,12 @@ class NixlConnectorWorker:
In particular, handle both homogeneous and heterogeneous TP. The former In particular, handle both homogeneous and heterogeneous TP. The former
requires local rank_i to read from remote rank_i. requires local rank_i to read from remote rank_i.
The latter, assuming D.world_size > P.world_size, requires that two or The latter, in the case of D.world_size < P.world_size, requires that a
more local TP worker share the xfer from a single TP worker. local (D) TP worker reads from multiple remote (P) TP workers.
Conversely, assuming D.world_size > P.world_size, two or more local TP
workers will read from a single remote TP worker.
Here's an example (non-MLA case): Here's an example for the last case described above (non-MLA):
rank_offset p_remote_tp_rank rank_offset p_remote_tp_rank
(kv split no) (kv split no)
@ -1474,9 +1483,6 @@ class NixlConnectorWorker:
nixl_agent_meta.agent_metadata nixl_agent_meta.agent_metadata
) )
# Handle tp_size>num_kv_heads: replicate KV cache.
replicates_kv_cache = self.kv_topo.replicates_kv_cache(engine_id)
# Create dst descs and xfer side handles. TP workers have same #blocks # Create dst descs and xfer side handles. TP workers have same #blocks
# so we only register once per engine_id. # so we only register once per engine_id.
# Example: # Example:
@ -1490,14 +1496,52 @@ class NixlConnectorWorker:
self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks
# Keep track of remote agent kv caches base addresses. # Keep track of remote agent kv caches base addresses.
self.kv_caches_base_addr[engine_id] = nixl_agent_meta.kv_caches_base_addr self.kv_caches_base_addr[engine_id][remote_tp_rank] = (
nixl_agent_meta.kv_caches_base_addr
)
self._validate_remote_agent_handshake(nixl_agent_meta, remote_tp_size) self._validate_remote_agent_handshake(nixl_agent_meta, remote_tp_size)
# Number of D TP workers reading from a single P TP worker. This is # This is 1 when P and D `--tensor-parallel-size` match. Otherwise,
# 1 when P and D `--tensor-parallel-size` match. # this is the ratio between the two sizes.
tp_ratio = self.kv_topo.tp_ratio_from_engine_id(engine_id) tp_ratio = self.kv_topo.tp_ratio_from_engine_id(engine_id)
# Handle tp_size>num_kv_heads: replicate KV cache.
indexes_into_remote = (
not self.kv_topo.replicates_kv_cache(engine_id) and tp_ratio > 0
)
logger.debug(
"Registering remote agent (%s, rank %s) memory regions with tp_ratio %s",
engine_id,
remote_tp_rank,
tp_ratio,
)
### (Optional) Register local agent memory regions. MLA is not split.
if (
tp_ratio < 0
and not self.use_mla
and tp_ratio not in self.src_xfer_handles_by_tp_ratio
):
# Remote tp_size > local tp_size: read from multiple remote ranks.
# Logically "split" own regions into |tp_ratio| chunks. Mind that
# we only do this once per remote tp_size (replica-friendly).
self.src_xfer_handles_by_tp_ratio[tp_ratio] = []
for i in range(-tp_ratio):
blocks_data = []
for memory_region in self.src_blocks_data:
addr, local_block_len, own_tp_rank = memory_region
# Computing block len layer by layer allows for different
# block sizes to be used.
remote_block_len = local_block_len // (-tp_ratio)
addr = addr + i * remote_block_len
blocks_data.append((addr, remote_block_len, own_tp_rank))
descs = self.nixl_wrapper.get_xfer_descs(
blocks_data, self.nixl_memory_type
)
handle = self.nixl_wrapper.prep_xfer_dlist("NIXL_INIT_AGENT", descs)
self.src_xfer_handles_by_tp_ratio[tp_ratio].append(handle)
### Register remote agent memory regions ### Register remote agent memory regions
blocks_data = [] blocks_data = []
# With homogeneous TP, D pulls the whole kv cache from corresponding # With homogeneous TP, D pulls the whole kv cache from corresponding
@ -1507,14 +1551,19 @@ class NixlConnectorWorker:
# Register all remote blocks, but only the corresponding kv heads. # Register all remote blocks, but only the corresponding kv heads.
for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr): for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr):
kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=i) # Read our whole local region size from remote.
remote_kv_block_len = kv_block_len // block_size_ratio local_block_len = self.get_backend_aware_kv_block_len(layer_idx=i)
remote_kv_block_len = local_block_len // block_size_ratio
if block_size_ratio > 1: if block_size_ratio > 1:
# using remote kv_block_len as transfer unit # using remote kv_block_len as transfer unit
kv_block_len = remote_kv_block_len local_block_len = remote_kv_block_len
if tp_ratio < 0 and not self.use_mla:
# Remote tp is bigger: read a chunk of local region from remote
local_block_len = local_block_len // (-tp_ratio)
rank_offset = ( rank_offset = (
self.tp_rank % tp_ratio * remote_kv_block_len self.tp_rank % tp_ratio * remote_kv_block_len
if not replicates_kv_cache if indexes_into_remote
else 0 else 0
) )
for block_id in range(nixl_agent_meta.num_blocks): for block_id in range(nixl_agent_meta.num_blocks):
@ -1524,7 +1573,7 @@ class NixlConnectorWorker:
# self.block_len == remote_block_len//tp_ratio bytes. # self.block_len == remote_block_len//tp_ratio bytes.
addr = base_addr + block_offset + rank_offset addr = base_addr + block_offset + rank_offset
# (addr, len, device id) # (addr, len, device id)
blocks_data.append((addr, kv_block_len, nixl_agent_meta.device_id)) blocks_data.append((addr, local_block_len, nixl_agent_meta.device_id))
if self.kv_topo.is_kv_layout_blocks_first: if self.kv_topo.is_kv_layout_blocks_first:
# With FlashInfer index V separately to allow head splitting. # With FlashInfer index V separately to allow head splitting.
@ -1533,7 +1582,7 @@ class NixlConnectorWorker:
addr = base_addr + block_offset + rank_offset addr = base_addr + block_offset + rank_offset
v_addr = addr + nixl_agent_meta.block_lens[i] // 2 v_addr = addr + nixl_agent_meta.block_lens[i] // 2
blocks_data.append( blocks_data.append(
(v_addr, kv_block_len, nixl_agent_meta.device_id) (v_addr, local_block_len, nixl_agent_meta.device_id)
) )
logger.debug( logger.debug(
@ -1546,15 +1595,15 @@ class NixlConnectorWorker:
# Register with NIXL. # Register with NIXL.
descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type) descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type)
self.dst_xfer_side_handles[engine_id] = self.nixl_wrapper.prep_xfer_dlist( self.dst_xfer_side_handles[engine_id][remote_tp_rank] = (
remote_agent_name, descs self.nixl_wrapper.prep_xfer_dlist(remote_agent_name, descs)
) )
if block_size_ratio > 1: if block_size_ratio > 1:
# when prefill with smaller block_size, we need to init a # when prefill with smaller block_size, we need to init a
# new handler with same block_len to match # new handler with same block_len to match
self.src_xfer_side_handles[nixl_agent_meta.block_size] = ( self.src_xfer_handles_by_block_size[nixl_agent_meta.block_size] = (
self.register_local_xfer_handler(nixl_agent_meta.block_size) self.register_local_xfer_handler(nixl_agent_meta.block_size)[0]
) )
return remote_agent_name return remote_agent_name
@ -1574,7 +1623,9 @@ class NixlConnectorWorker:
block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id( block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(
remote_engine_id remote_engine_id
) )
assert tp_ratio > 0, "Decode TP cannot be smaller than prefill TP" # tp_size > num_kv_heads (replicated KV cache) and P TP > D TP case, not supported
assert not (tp_ratio < 0 and self.kv_topo.is_kv_replicated(remote_engine_id))
assert not self._use_pallas or tp_ratio == 1, ( assert not self._use_pallas or tp_ratio == 1, (
"TPU (pallas_v1) DOES NOT support heterogeneous TP yet." "TPU (pallas_v1) DOES NOT support heterogeneous TP yet."
) )
@ -1616,17 +1667,29 @@ class NixlConnectorWorker:
"All remote layers must have the same block size" "All remote layers must have the same block size"
) )
assert ( if tp_ratio > 0:
remote_block_len # Remote tp is smaller: remote block_len size is bigger
== (self.block_len_per_layer[0] * tp_ratio) // block_size_ratio assert (
), ( remote_block_len
"Remote P worker KV layer cache must be of shape [2, N, " == (self.block_len_per_layer[0] * tp_ratio) // block_size_ratio
"local_kv_heads*tp_ratio, block_size, head_dim] and same dtype." ), (
) "Remote P worker KV layer cache must be of shape [2, N, "
"local_kv_heads*tp_ratio, page_size, head_dim] and same dtype."
) # noqa: E501
else:
assert block_size_ratio == 1, (
"Different local/remote block sizes are not supported when"
" P TP > D TP."
)
# Remote tp is bigger: remote block_len size is smaller
assert remote_block_len == self.block_len_per_layer[0] // (-tp_ratio), (
"Remote P worker KV layer cache must be of shape [2, N, "
"local_kv_heads/tp_ratio, page_size, head_dim] and same dtype."
) # noqa: E501
# TP workers have same #blocks. # TP workers that handshake with same remote have same #blocks.
assert self.dst_num_blocks[remote_engine_id] == nixl_agent_meta.num_blocks assert self.dst_num_blocks[remote_engine_id] == nixl_agent_meta.num_blocks
# Same number of regions/~layers.
assert len(nixl_agent_meta.kv_caches_base_addr) == len(self.block_len_per_layer) assert len(nixl_agent_meta.kv_caches_base_addr) == len(self.block_len_per_layer)
def sync_recved_kv_to_device(self, req_id: str, meta: ReqMeta): def sync_recved_kv_to_device(self, req_id: str, meta: ReqMeta):
@ -1710,7 +1773,7 @@ class NixlConnectorWorker:
) )
cache.index_copy_(0, indices, permuted_blocks) cache.index_copy_(0, indices, permuted_blocks)
def blocksize_post_process(self, block_ids_per_ratio: dict[float, list[list[int]]]): def blocksize_post_process(self, block_ids_per_ratio: dict[int, list[list[int]]]):
def _process_local_gt_remote(blocks_to_update, block_size_ratio): def _process_local_gt_remote(blocks_to_update, block_size_ratio):
n_kv_heads, block_size, head_size = blocks_to_update.shape[1:] n_kv_heads, block_size, head_size = blocks_to_update.shape[1:]
remote_block_size = block_size // block_size_ratio remote_block_size = block_size // block_size_ratio
@ -1840,7 +1903,7 @@ class NixlConnectorWorker:
notified_req_ids: set[str] = set() notified_req_ids: set[str] = set()
for notifs in self.nixl_wrapper.get_new_notifs().values(): for notifs in self.nixl_wrapper.get_new_notifs().values():
for notif in notifs: for notif in notifs:
req_id, tp_ratio = notif.decode("utf-8").rsplit(":", 1) req_id, tp_size = notif.decode("utf-8").rsplit(":", 1)
if ( if (
req_id not in self._reqs_to_send req_id not in self._reqs_to_send
and req_id not in self._reqs_to_process and req_id not in self._reqs_to_process
@ -1853,9 +1916,22 @@ class NixlConnectorWorker:
) )
continue continue
# NOTE: `tp_ratio` is the opposite when swapping local<>remote
n_consumers = int(tp_size)
tp_ratio = self.kv_topo.tp_ratio(n_consumers)
# Number of reads *per producer* to wait for.
# When remote D TP > local P TP we expect `tp_ratio` reads.
consumers_per_producer = (
-tp_ratio if n_consumers > self.world_size else 1
)
self.consumer_notification_counts_by_req[req_id] += 1 self.consumer_notification_counts_by_req[req_id] += 1
# Wait for all consumers (D) to be done reading before freeing. # Wait for all consumers (D) to be done reading before freeing.
if self.consumer_notification_counts_by_req[req_id] == int(tp_ratio): if (
self.consumer_notification_counts_by_req[req_id]
== consumers_per_producer
):
notified_req_ids.add(req_id) notified_req_ids.add(req_id)
del self.consumer_notification_counts_by_req[req_id] del self.consumer_notification_counts_by_req[req_id]
self._reqs_to_process.remove(req_id) self._reqs_to_process.remove(req_id)
@ -1872,7 +1948,7 @@ class NixlConnectorWorker:
""" """
done_req_ids: set[str] = set() done_req_ids: set[str] = set()
for req_id, handles in list(transfers.items()): for req_id, handles in list(transfers.items()):
in_progress = False in_progress = []
for handle in handles: for handle in handles:
try: try:
xfer_state = self.nixl_wrapper.check_xfer_state(handle) xfer_state = self.nixl_wrapper.check_xfer_state(handle)
@ -1882,7 +1958,7 @@ class NixlConnectorWorker:
self.xfer_stats.record_transfer(res) self.xfer_stats.record_transfer(res)
self.nixl_wrapper.release_xfer_handle(handle) self.nixl_wrapper.release_xfer_handle(handle)
elif xfer_state == "PROC": elif xfer_state == "PROC":
in_progress = True in_progress.append(handle)
continue continue
else: else:
logger.error( logger.error(
@ -1892,7 +1968,6 @@ class NixlConnectorWorker:
xfer_state, xfer_state,
) )
self._handle_failed_transfer(req_id, handle) self._handle_failed_transfer(req_id, handle)
in_progress = False
except Exception: except Exception:
logger.exception( logger.exception(
"NIXL transfer exception for request %s. " "NIXL transfer exception for request %s. "
@ -1900,11 +1975,13 @@ class NixlConnectorWorker:
req_id, req_id,
) )
self._handle_failed_transfer(req_id, handle) self._handle_failed_transfer(req_id, handle)
in_progress = False
if not in_progress: if not in_progress:
# Only report request as completed when all transfers are done.
done_req_ids.add(req_id) done_req_ids.add(req_id)
del transfers[req_id] del transfers[req_id]
else:
transfers[req_id] = in_progress
return done_req_ids return done_req_ids
def _handle_failed_transfer(self, req_id: str, handle: int): def _handle_failed_transfer(self, req_id: str, handle: int):
@ -1982,18 +2059,62 @@ class NixlConnectorWorker:
def _read_blocks_for_req(self, req_id: str, meta: ReqMeta): def _read_blocks_for_req(self, req_id: str, meta: ReqMeta):
assert meta.remote is not None assert meta.remote is not None
logger.debug( remote_ranks = self.kv_topo.get_target_remote_ranks_from_engine_id(
"Remote agent %s available, calling _read_blocks for req %s", meta.remote.engine_id
meta.remote.engine_id,
req_id,
)
self._read_blocks(
request_id=req_id,
dst_engine_id=meta.remote.engine_id,
remote_request_id=meta.remote.request_id,
local_block_ids=meta.local_physical_block_ids,
remote_block_ids=meta.remote.block_ids,
) )
tp_ratio = self.kv_topo.tp_ratio_from_engine_id(meta.remote.engine_id)
# D may have to perform multiple reads from different remote ranks.
for i, remote_rank in enumerate(remote_ranks):
if self.use_mla and tp_ratio < 0 and i > 0:
# MLA opt: when P TP > D TP, only a single read is executed for
# the first remote rank (cache is duplicated)..
break
remote_block_size = self.kv_topo.remote_block_size[meta.remote.engine_id]
logger.debug(
"Remote agent %s available, calling _read_blocks"
" on remote rank %s with remote block size %s for req %s",
meta.remote.engine_id,
remote_rank,
remote_block_size,
req_id,
)
# Get side handles.
if tp_ratio < 0 and not self.use_mla:
assert remote_block_size == self.block_size
# Remote tp_size > local tp_size: we must perform multiple
# reads. Get the memory chunk onto which we will write.
local_xfer_side_handle = self.src_xfer_handles_by_tp_ratio[tp_ratio][i]
else:
# Single read from remote, we write to the whole memory region.
# Also handle remote block size different from local block size.
local_xfer_side_handle = self.src_xfer_handles_by_block_size[
remote_block_size
]
# Destination handle: remote_engine_id -> remote_rank -> handle.
remote_xfer_side_handle = self.dst_xfer_side_handles[meta.remote.engine_id][
remote_rank
]
self._read_blocks(
request_id=req_id,
dst_engine_id=meta.remote.engine_id,
remote_request_id=meta.remote.request_id,
local_block_ids=meta.local_physical_block_ids,
remote_block_ids=meta.remote.block_ids,
remote_rank=remote_rank,
local_xfer_side_handle=local_xfer_side_handle,
remote_xfer_side_handle=remote_xfer_side_handle,
)
if self.use_mla and tp_ratio < 0:
# ..but we still need to notify the other remote ranks that we
# have the blocks we need so they can update the request state.
notif_id = f"{req_id}:{self.world_size}".encode()
remote_agents = self._remote_agents[meta.remote.engine_id]
for rank_to_notify, agent in remote_agents.items():
if rank_to_notify != remote_rank:
self.nixl_wrapper.send_notif(agent, notif_msg=notif_id)
def _read_blocks( def _read_blocks(
self, self,
@ -2002,7 +2123,14 @@ class NixlConnectorWorker:
dst_engine_id: str, dst_engine_id: str,
request_id: str, request_id: str,
remote_request_id: str, remote_request_id: str,
remote_rank: int,
local_xfer_side_handle: int,
remote_xfer_side_handle: int,
): ):
"""
Post a READ point-to-point xfer request from a single local worker to
a single remote worker.
"""
block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(dst_engine_id) block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(dst_engine_id)
if block_size_ratio > 1: if block_size_ratio > 1:
local_block_ids = self.get_mapped_blocks( local_block_ids = self.get_mapped_blocks(
@ -2031,18 +2159,14 @@ class NixlConnectorWorker:
# saturate IB with heterogeneous TP sizes. We should remove the staging # saturate IB with heterogeneous TP sizes. We should remove the staging
# blocks until we are ready. # blocks until we are ready.
# Number of D TP workers that will read from dst P. Propagate tp_ratio # Number of D TP workers that will read from dst P. Propagate info
# on notification so that dst worker can wait before freeing blocks. # on notification so that dst worker can wait before freeing blocks.
tp_ratio = self.kv_topo.tp_ratio_from_engine_id(dst_engine_id) notif_id = f"{remote_request_id}:{self.world_size}".encode()
notif_id = f"{remote_request_id}:{tp_ratio}".encode()
# Full prefix cache hit: do not need to read remote blocks, # Full prefix cache hit: do not need to read remote blocks,
# just notify P worker that we have the blocks we need. # just notify P worker that we have the blocks we need.
num_local_blocks = len(local_block_ids) num_local_blocks = len(local_block_ids)
if num_local_blocks == 0: if num_local_blocks == 0:
remote_rank = self.kv_topo.get_target_remote_rank_from_engine_id(
dst_engine_id
)
agent_name = self._remote_agents[dst_engine_id][remote_rank] agent_name = self._remote_agents[dst_engine_id][remote_rank]
try: try:
self.nixl_wrapper.send_notif(agent_name, notif_msg=notif_id) self.nixl_wrapper.send_notif(agent_name, notif_msg=notif_id)
@ -2062,13 +2186,6 @@ class NixlConnectorWorker:
if num_local_blocks < num_remote_blocks: if num_local_blocks < num_remote_blocks:
remote_block_ids = remote_block_ids[-num_local_blocks:] remote_block_ids = remote_block_ids[-num_local_blocks:]
# Get side handles.
remote_block_size = self.kv_topo.remote_block_size[dst_engine_id]
local_xfer_side_handle = self.src_xfer_side_handles.get(
remote_block_size, self.src_xfer_side_handle
)
remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id]
# NOTE (nicolo) With homogeneous TP, each TP worker loads KV from # NOTE (nicolo) With homogeneous TP, each TP worker loads KV from
# corresponding rank. With heterogeneous TP, fixing D>P, the D tp # corresponding rank. With heterogeneous TP, fixing D>P, the D tp
# workers will issue xfers to parts of the P worker remote kv caches. # workers will issue xfers to parts of the P worker remote kv caches.
@ -2230,7 +2347,7 @@ class NixlConnectorWorker:
block_ids_np, self._physical_blocks_per_logical_kv_block, block_arange block_ids_np, self._physical_blocks_per_logical_kv_block, block_arange
).tolist() ).tolist()
def get_backend_aware_kv_block_len(self, layer_idx: int): def get_backend_aware_kv_block_len(self, layer_idx: int) -> int:
""" """
Get the block length for one K/V element (K and V have the same size). Get the block length for one K/V element (K and V have the same size).
@ -2276,11 +2393,16 @@ class NixlConnectorWorker:
for handle in handles: for handle in handles:
self.nixl_wrapper.release_xfer_handle(handle) self.nixl_wrapper.release_xfer_handle(handle)
self._recving_transfers.clear() self._recving_transfers.clear()
if self.src_xfer_side_handle: for handle in self.src_xfer_handles_by_block_size.values():
self.nixl_wrapper.release_dlist_handle(self.src_xfer_side_handle) self.nixl_wrapper.release_dlist_handle(handle)
self.src_xfer_side_handle = 0 self.src_xfer_handles_by_block_size.clear()
for dst_xfer_side_handle in self.dst_xfer_side_handles.values(): for handles in self.src_xfer_handles_by_tp_ratio.values():
self.nixl_wrapper.release_dlist_handle(dst_xfer_side_handle) for handle in handles:
self.nixl_wrapper.release_dlist_handle(handle)
self.src_xfer_handles_by_tp_ratio.clear()
for dst_xfer_side_handles in self.dst_xfer_side_handles.values():
for dst_xfer_side_handle in dst_xfer_side_handles.values():
self.nixl_wrapper.release_dlist_handle(dst_xfer_side_handle)
self.dst_xfer_side_handles.clear() self.dst_xfer_side_handles.clear()
for remote_agents in self._remote_agents.values(): for remote_agents in self._remote_agents.values():
for agent_name in remote_agents.values(): for agent_name in remote_agents.values():