[NIXL] Support P tensor-parallel-size > D tensor-parallel-size (#27274)
Signed-off-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
parent fd8afdf38d
commit bc3700e0cd
@@ -8,9 +8,12 @@ SCRIPT="v1/kv_connector/nixl_integration/run_accuracy_test.sh"
configs=(
"GPU_MEMORY_UTILIZATION=0.6 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=2"
"GPU_MEMORY_UTILIZATION=0.6 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2"
"GPU_MEMORY_UTILIZATION=0.6 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=1"
"GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" # MLA case
"GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny"
"GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=1 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny"
"DP_EP=1 GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" # MLA+P-TP1, D-DPEP=2 (TP=1)
"DP_EP=1 GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" # MLA+P-TP2, D-DPEP=2 (TP=1)
)

run_tests() {
@@ -391,6 +391,8 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
super().__init__(*args, **kwargs)
self._hand_shake_latency = hand_shake_latency
self.kv_cache_layout = kv_cache_layout
# Mock register_kv_caches attribute needed for tests that do not call it.
self.src_xfer_handles_by_block_size = {self.block_size: 1}

def _nixl_handshake(
self, host: str, port: int, remote_tp_size: int, expected_engine_id: str
@@ -407,22 +409,43 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):

assert expected_engine_id == self.REMOTE_ENGINE_ID

remote_agent_name = self.add_remote_agent(
NixlAgentMetadata(
engine_id=self.REMOTE_ENGINE_ID,
agent_metadata=FakeNixlWrapper.AGENT_METADATA,
kv_caches_base_addr=[0],
device_id=0,
num_blocks=1,
block_lens=self.block_len_per_layer,
# `self.kv_cache_layout` is only forced to HND when vllm engine
# is started. We mock HND here.
kv_cache_layout="HND",
block_size=self.block_size,
),
remote_tp_size=remote_tp_size,
)
return {0: remote_agent_name}
# Adjust remote block length metadata to satisfy heterogeneous TP
# invariants enforced during handshake validation.
remote_block_lens = list(self.block_len_per_layer)
tp_ratio = self.kv_topo.tp_ratio(remote_tp_size)
if remote_tp_size > self.world_size:
# P TP > D TP case, block_len of remote is smaller
remote_block_lens = [
block_len // (-tp_ratio) for block_len in remote_block_lens
]
elif remote_tp_size < self.world_size:
remote_block_lens = [
block_len * tp_ratio for block_len in remote_block_lens
]

# When remote tp_size > local tp_size, handshake with multiple
# remote ranks.
num_handshakes = 1 if tp_ratio > 0 else -tp_ratio
remote_agents: dict[int, str] = {}
for remote_tp_rank in range(num_handshakes):
remote_agent_name = self.add_remote_agent(
NixlAgentMetadata(
engine_id=self.REMOTE_ENGINE_ID,
agent_metadata=FakeNixlWrapper.AGENT_METADATA,
kv_caches_base_addr=[0],
device_id=remote_tp_rank,
num_blocks=1,
block_lens=remote_block_lens,
# `self.kv_cache_layout` is only forced to HND when vllm engine
# is started. We mock HND here.
kv_cache_layout="HND",
block_size=self.block_size,
),
remote_tp_rank=remote_tp_rank,
remote_tp_size=remote_tp_size,
)
remote_agents[remote_tp_rank] = remote_agent_name
return remote_agents


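Illustrative aside (not part of the commit): the fake handshake above sizes its remote metadata from the signed tp_ratio. A self-contained restatement of that arithmetic with made-up numbers; the helper name below is ours, not a vLLM API.

def fake_handshake_plan(local_tp: int, remote_tp: int, block_len: int) -> tuple[int, int]:
    """Return (num_handshakes, remote_block_len) the way the fake worker derives them."""
    # Signed ratio: positive when local TP >= remote TP, negative otherwise.
    tp_ratio = local_tp // remote_tp if local_tp >= remote_tp else -(remote_tp // local_tp)
    num_handshakes = 1 if tp_ratio > 0 else -tp_ratio
    if remote_tp > local_tp:
        remote_block_len = block_len // (-tp_ratio)   # P TP > D TP: remote blocks are smaller
    elif remote_tp < local_tp:
        remote_block_len = block_len * tp_ratio       # P TP < D TP: remote blocks are bigger
    else:
        remote_block_len = block_len
    return num_handshakes, remote_block_len

assert fake_handshake_plan(local_tp=1, remote_tp=2, block_len=4096) == (2, 2048)
assert fake_handshake_plan(local_tp=2, remote_tp=1, block_len=4096) == (1, 8192)
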
class TestNixlHandshake:
@@ -453,7 +476,13 @@ class TestNixlHandshake:
vllm_config, connector.engine_id, hand_shake_latency=0
)
assert isinstance(connector.connector_worker.nixl_wrapper, FakeNixlWrapper)
connector.connector_worker.nixl_wrapper.set_cycles_before_xfer_done(3)
worker = connector.connector_worker
worker.nixl_wrapper.set_cycles_before_xfer_done(3)
# simulate handshake
worker.dst_xfer_side_handles = {
FakeNixlConnectorWorker.REMOTE_ENGINE_ID: {0: 1}
}
worker.kv_cache_layout = "HND"
num_xfers = 4
while True:
# For the same request_id, initiate multiple xfers across different
@@ -567,6 +596,171 @@ class TestNixlHandshake:
return
raise TimeoutError("Took too long to complete async handshake.")

@patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
FakeNixlWrapper,
)
@pytest.mark.parametrize("local_tp_size", [1, 2])
def test_prefill_tp_size_greater_than_decode_tp_size(
self, local_tp_size: int, dist_init
):
"""
Verify remote TP > local TP handshake succeeds with different
remote configurations.
"""

vllm_config = create_vllm_config()
local_tp_size = 1
vllm_config.parallel_config.tensor_parallel_size = local_tp_size

connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
connector.connector_worker = FakeNixlConnectorWorker(
vllm_config, connector.engine_id, hand_shake_latency=0
)
worker = connector.connector_worker

# Minimal local registration params used by add_remote_agent
worker.slot_size_per_layer = [4096]
worker.block_len_per_layer = [4096 * worker.block_size]
worker.num_blocks = 1
worker.dst_num_blocks[worker.engine_id] = worker.num_blocks
worker.src_blocks_data = [(0, worker.block_len_per_layer[0], worker.tp_rank)]

def check_handshake(remote_tp_size: int):
tp_ratio = remote_tp_size // local_tp_size
assert set(remote_agents.keys()) == set(range(tp_ratio))

remote_engine_id = worker.REMOTE_ENGINE_ID
assert worker._tp_size[remote_engine_id] == remote_tp_size
assert -tp_ratio == worker.kv_topo.tp_ratio_from_engine_id(remote_engine_id)
# ensure src_xfer_handles_by_tp_ratio is populated with tp_ratio chunks
assert -tp_ratio in worker.src_xfer_handles_by_tp_ratio
assert len(worker.src_xfer_handles_by_tp_ratio[-tp_ratio]) == tp_ratio
assert remote_engine_id in worker.dst_xfer_side_handles
assert set(worker.dst_xfer_side_handles[remote_engine_id].keys()) == set(
range(tp_ratio)
)

remote_agents = worker._nixl_handshake(
host="localhost",
port=1234,
remote_tp_size=2,
expected_engine_id=worker.REMOTE_ENGINE_ID,
)
check_handshake(2)

# NOTE flexibility: a second remote with higher number of ranks is
# discovered. This is not a scenario we actively support right now, but
# the connector allows it.
worker.REMOTE_ENGINE_ID = "remote_engine_2"
remote_agents = worker._nixl_handshake(
host="localhost",
port=1234,
remote_tp_size=6,
expected_engine_id=worker.REMOTE_ENGINE_ID,
)
check_handshake(6)

@patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
FakeNixlWrapper,
)
@pytest.mark.parametrize("local_tp_size", [1, 2])
def test_prefill_tp_size_greater_than_decode_tp_size_mla(
self, local_tp_size: int, dist_init
):
"""
Verify remote TP > local TP handshake succeeds with different
remote configurations for an MLA model.
"""
vllm_config = create_vllm_config()
d_tp_size = 1
p_tp_size = 2

# Build two separate connectors/workers to emulate P TP=2 ranks.
conn_p0 = NixlConnector(vllm_config, KVConnectorRole.WORKER)
conn_p1 = NixlConnector(vllm_config, KVConnectorRole.WORKER)
conn_p0.connector_worker = FakeNixlConnectorWorker(
vllm_config, conn_p0.engine_id, hand_shake_latency=0
)
conn_p1.connector_worker = FakeNixlConnectorWorker(
vllm_config, conn_p1.engine_id, hand_shake_latency=0
)

# Force P world size to 2 for both workers and emulate distinct tp_ranks.
# Also enable MLA path so that expected_finished_count is updated.
for rank, worker in enumerate(
(conn_p0.connector_worker, conn_p1.connector_worker)
):
worker.world_size = p_tp_size
worker.kv_topo.remote_tp_size = {worker.engine_id: p_tp_size}
worker.tp_rank = rank
worker.use_mla = True

req_id = "req-ep-dp2-p0"
now = time.perf_counter()
# Register a request on P that is waiting for consumers to read
# (both workers track it).
conn_p0.connector_worker._reqs_to_send[req_id] = now + 10.0
conn_p0.connector_worker._reqs_to_process.add(req_id)
conn_p1.connector_worker._reqs_to_send[req_id] = now + 10.0
conn_p1.connector_worker._reqs_to_process.add(req_id)

# Simulate a read notification coming from D with (tp=1, dp=2).
notif = f"{req_id}:{d_tp_size}".encode()
# D0-0->P0 notif
conn_p0.connector_worker.nixl_wrapper.get_new_notifs = lambda: {
"agent": [notif]
} # type: ignore[method-assign]
conn_p1.connector_worker.nixl_wrapper.get_new_notifs = lambda: {
"agent": [notif]
} # type: ignore[method-assign]

# Trigger notification processing via get_finished().
done_sending0, _ = conn_p0.get_finished(finished_req_ids=set())
done_sending1, _ = conn_p1.get_finished(finished_req_ids=set())
assert req_id in done_sending0 and req_id in done_sending1

# E2E aggregation: ensure the aggregated output marks the request
# as finished using the connector's expected_finished_count.
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput

aggregator = KVOutputAggregator.from_connector(conn_p0, world_size=2)

out0 = ModelRunnerOutput(
req_ids=[req_id],
req_id_to_index={req_id: 0},
sampled_token_ids=[[0]],
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[None],
kv_connector_output=KVConnectorOutput(
finished_sending=done_sending0,
finished_recving=None,
),
)
out1 = ModelRunnerOutput(
req_ids=[req_id],
req_id_to_index={req_id: 0},
sampled_token_ids=[[0]],
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[None],
kv_connector_output=KVConnectorOutput(
finished_sending=done_sending1,
finished_recving=None,
),
)
aggregated = aggregator.aggregate([out0, out1], output_rank=0)
assert aggregated.kv_connector_output is not None
assert aggregated.kv_connector_output.finished_sending == {req_id}

# Producers cleaned up state for the finished request.
assert req_id not in conn_p0.connector_worker._reqs_to_send
assert req_id not in conn_p0.connector_worker._reqs_to_process
assert req_id not in conn_p1.connector_worker._reqs_to_send
assert req_id not in conn_p1.connector_worker._reqs_to_process

@patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
FakeNixlWrapper,
@@ -585,6 +779,9 @@ class TestNixlHandshake:
connector.connector_worker = FakeNixlConnectorWorker(
vllm_config, connector.engine_id
)
# Register (mocked) local xfer handler
# worker = connector.connector_worker
# worker.src_xfer_handles_by_block_size = {worker.block_size: 1}
metadata = NixlConnectorMetadata()
total_reqs = 5
for i in range(total_reqs):
@@ -672,7 +869,6 @@ class TestNixlHandshake:
with pytest.raises(RuntimeError):
# mismatched layout is expected to fail
worker.add_remote_agent(meta, remote_tp_size=2)
with pytest.raises(AssertionError):
worker.add_remote_agent(meta, remote_tp_size=1)

@patch(
@@ -1357,8 +1553,11 @@ def test_shutdown_cleans_up_resources(dist_init):
patch.object(nixl_wrapper, "deregister_memory") as mock_dereg,
):
worker._recving_transfers = {"req1": [123]}
worker.src_xfer_side_handle = 456
worker.dst_xfer_side_handles = {"engine1": 789}
# Mock register_kv_cache which registers local handle
worker.src_xfer_handles_by_block_size = {worker.block_size: 455}
# P TP = 2 * D TP case, we should register 2 local handles
worker.src_xfer_handles_by_tp_ratio = {-2: [456, 457]}
worker.dst_xfer_side_handles = {"engine1": {0: 789}}
worker._remote_agents = {"engine1": {0: "agent1"}}
worker._registered_descs = ["desc1", "desc2"]

@@ -1379,8 +1578,10 @@ def test_shutdown_cleans_up_resources(dist_init):
mock_listener.join.assert_called_once()

mock_rel_xfer.assert_called_once_with(123)
assert mock_rel_dlist.call_count == 2
mock_rel_dlist.assert_any_call(456) # src handle
assert mock_rel_dlist.call_count == 4
mock_rel_dlist.assert_any_call(455) # src handle (whole region)
mock_rel_dlist.assert_any_call(456) # src handle (1st chunk)
mock_rel_dlist.assert_any_call(457) # src handle (2nd chunk)
mock_rel_dlist.assert_any_call(789) # dst handle
mock_rem_agent.assert_called_once_with("agent1")
assert mock_dereg.call_count == 2

@@ -21,6 +21,8 @@ if TYPE_CHECKING:

logger = init_logger(__name__)

EngineId = str


def get_kv_connector_cache_layout():
# NOTE (NickLucche) When running disaggregated PD with NIXL, HND layout is
@@ -209,12 +211,12 @@ class TpKVTopology:
"""

tp_rank: int
remote_tp_size: dict[str, int]
remote_tp_size: dict[EngineId, int]
is_mla: bool
total_num_kv_heads: int
attn_backend: type[AttentionBackend]
engine_id: str
remote_block_size: dict[str, int]
engine_id: EngineId
remote_block_size: dict[EngineId, int]

def __post_init__(self):
# Figure out whether the first dimension of the cache is K/V
@@ -256,18 +258,28 @@ class TpKVTopology:
Calculate the tensor parallel ratio between local and remote TP.
We can think of it as the number of local TP workers-per-remote TP
workers. Local workers will read from the same remote TP worker in
groups of size `tp_ratio`.
groups of size `tp_ratio`. If remote tp_size > local tp_size, the
ratio is flipped (remote_size/local_size) and the returned value is
negative.
"""
assert self.tp_size % remote_tp_size == 0, (
f"Local tensor parallel size {self.tp_size} is not divisible "
f"by remote tensor parallel size {remote_tp_size}."
if self.tp_size >= remote_tp_size:
assert self.tp_size % remote_tp_size == 0, (
f"Local tensor parallel size {self.tp_size} is not divisible "
f"by remote tensor parallel size {remote_tp_size}."
)
return self.tp_size // remote_tp_size

assert remote_tp_size % self.tp_size == 0, (
f"Remote tensor parallel size {remote_tp_size} is not divisible "
f"by local tensor parallel size {self.tp_size}."
)
return self.tp_size // remote_tp_size
# P TP > D TP case, return the ratio as negative
return -remote_tp_size // self.tp_size

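Illustrative aside (not part of the commit): the signed-ratio convention documented above, restated as a standalone function with a few worked values. The free function is ours; in vLLM this lives on TpKVTopology.

def tp_ratio(local_tp_size: int, remote_tp_size: int) -> int:
    # Positive: number of local (D) workers per remote (P) worker.
    if local_tp_size >= remote_tp_size:
        assert local_tp_size % remote_tp_size == 0
        return local_tp_size // remote_tp_size
    # Negative: P TP > D TP, magnitude is remote workers per local worker.
    assert remote_tp_size % local_tp_size == 0
    return -(remote_tp_size // local_tp_size)

assert tp_ratio(4, 2) == 2    # two D workers read from each P worker
assert tp_ratio(2, 2) == 1    # homogeneous TP
assert tp_ratio(2, 6) == -3   # each D worker reads from three P workers
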
def block_size_ratio(
self,
remote_block_size: int,
) -> float:
) -> int:
"""
Calculate the block size ratio between local and remote TP.
"""
@@ -279,19 +291,19 @@ class TpKVTopology:

def tp_ratio_from_engine_id(
self,
remote_engine_id: str,
remote_engine_id: EngineId,
) -> int:
remote_tp_size = self.remote_tp_size[remote_engine_id]
return self.tp_ratio(remote_tp_size)

def block_size_ratio_from_engine_id(
self,
remote_engine_id: str,
) -> float:
remote_engine_id: EngineId,
) -> int:
remote_block_size = self.remote_block_size[remote_engine_id]
return self.block_size_ratio(remote_block_size)

def is_kv_replicated(self, engine_id: str) -> bool:
def is_kv_replicated(self, engine_id: EngineId) -> bool:
"""
Whether the KV cache is replicated across TP workers due to the
number of TP workers being greater than the number of KV heads.
@@ -299,24 +311,30 @@ class TpKVTopology:
tp_size = self.remote_tp_size[engine_id]
return tp_size // self.total_num_kv_heads >= 1

def replicates_kv_cache(self, remote_engine_id: str) -> bool:
def replicates_kv_cache(self, remote_engine_id: EngineId) -> bool:
# MLA is always replicated as the hidden dim can't be split.
return self.is_mla or self.is_kv_replicated(remote_engine_id)

def get_target_remote_rank(
def get_target_remote_ranks(
self,
remote_tp_size: int,
) -> int:
) -> list[int]:
"""
Get the remote TP rank (on P) that the current local TP rank
(on D) will read from.
(on D) will read from. When remote tp_size > local tp_size, we
read from multiple remote ranks.
"""
tp_ratio = self.tp_ratio(remote_tp_size)
return self.tp_rank // tp_ratio
if tp_ratio > 0:
return [self.tp_rank // tp_ratio]

def get_target_remote_rank_from_engine_id(
# P TP > D TP case, D reads from |tp_ratio| remote workers.
tp_ratio = -tp_ratio
return [self.tp_rank * tp_ratio + i for i in range(tp_ratio)]

def get_target_remote_ranks_from_engine_id(
self,
remote_engine_id: str,
) -> int:
remote_engine_id: EngineId,
) -> list[int]:
remote_tp_size = self.remote_tp_size[remote_engine_id]
return self.get_target_remote_rank(remote_tp_size)
return self.get_target_remote_ranks(remote_tp_size)

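Illustrative aside (not part of the commit): the D-rank-to-P-rank mapping implemented by get_target_remote_ranks, restated standalone; the function name and arguments below are ours.

def target_remote_ranks(local_rank: int, local_tp: int, remote_tp: int) -> list[int]:
    if local_tp >= remote_tp:
        # Groups of (local_tp // remote_tp) D workers share one P worker.
        return [local_rank // (local_tp // remote_tp)]
    # P TP > D TP: each D worker reads a contiguous span of P workers.
    span = remote_tp // local_tp
    return [local_rank * span + i for i in range(span)]

assert target_remote_ranks(local_rank=3, local_tp=4, remote_tp=2) == [1]
assert target_remote_ranks(local_rank=1, local_tp=2, remote_tp=6) == [3, 4, 5]
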
@@ -23,7 +23,7 @@ from vllm import envs
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.selector import get_attn_backend
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.utils import TpKVTopology
from vllm.distributed.kv_transfer.kv_connector.utils import EngineId, TpKVTopology
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
CopyBlocksOp,
KVConnectorBase_V1,
@@ -56,7 +56,6 @@ if TYPE_CHECKING:
from vllm.v1.request import Request

TransferHandle = int
EngineId = str
ReqId = str

#
@@ -873,9 +872,10 @@ class NixlConnectorWorker:
self.copy_blocks: CopyBlocksOp | None = None

# Map of engine_id -> kv_caches_base_addr. For TP case, each local
# rank will still only pull from a single remote TP worker.
self.kv_caches_base_addr: dict[EngineId, list[int]] = {}
self.device_id: int = 0
# Current rank may pull from multiple remote TP workers.
# EngineId, dict[int, list[int]] -> engine_id, tp_rank, base_addr_for_layer
self.kv_caches_base_addr = defaultdict[EngineId, dict[int, list[int]]](dict)

# Number of NIXL regions. Currently one region per cache
# (so 1 per layer for MLA, otherwise 2 per layer)
@@ -883,10 +883,12 @@ class NixlConnectorWorker:
self.num_layers = 0

# nixl_prepped_dlist_handle.
self.src_xfer_side_handle: int = 0
self.src_xfer_side_handles: dict[int, int] = {}
# Map of engine_id -> nixl_prepped_dlist_handle (int)].
self.dst_xfer_side_handles: dict[EngineId, int] = {}
self.src_xfer_handles_by_block_size: dict[int, int] = {}
# Populated dynamically during handshake based on remote configuration.
# Keep track of regions at different tp_ratio values. tp_ratio->handles
self.src_xfer_handles_by_tp_ratio: dict[int, list[int]] = {}
# Map of engine_id -> {tp_rank: nixl_prepped_dlist_handle (int)}.
self.dst_xfer_side_handles = defaultdict[EngineId, dict[int, int]](dict)

# Map of engine_id -> num_blocks. All ranks in the same deployment will
# have the same number of blocks.
@@ -977,103 +979,108 @@ class NixlConnectorWorker:
expected_engine_id: str,
) -> dict[int, str]:
"""Do a NIXL handshake with a remote instance."""

start_time = time.perf_counter()

# NOTE(rob): we need each rank to have a unique port. This is
# a hack to keep us moving. We will switch when moving to etcd
# or where we have a single ZMQ socket in the scheduler.

# Handshake only with the remote TP rank that current local rank will
# pull from. With homogeneous TP it happens to be the same rank_i.
p_remote_rank = self.kv_topo.get_target_remote_rank(remote_tp_size)
# When target instance TP > local TP, we need to perform multiple
# handshakes. Do it in a single background job for simplicity.
# Regardless, only handshake with the remote TP rank(s) that current
# local rank will read from. Note that with homogeneous TP,
# this happens to be the same single rank_i.
p_remote_ranks = self.kv_topo.get_target_remote_ranks(remote_tp_size)
remote_rank_to_agent_name = {}
path = make_zmq_path("tcp", host, port)
logger.debug(
"Querying metadata on path: %s at remote tp rank %s", path, p_remote_rank
)

# Send query for the request.
with zmq_ctx(zmq.REQ, path) as sock:
msg = msgspec.msgpack.encode((GET_META_MSG, p_remote_rank))
# Set receive timeout to 5 seconds to avoid hanging on dead server
sock.setsockopt(zmq.RCVTIMEO, 5000) # milliseconds
sock.send(msg)
handshake_bytes = sock.recv()

# Decode handshake payload to get compatibility hash
handshake_decoder = msgspec.msgpack.Decoder(NixlHandshakePayload)
try:
handshake_payload = handshake_decoder.decode(handshake_bytes)
except (msgspec.DecodeError, msgspec.ValidationError) as e:
raise RuntimeError(
f"Failed to decode NixlHandshakePayload. This likely indicates "
f"an incompatibility between connector version. Error: {e}"
) from e

got_metadata_time = time.perf_counter()
logger.debug(
"NIXL handshake: get metadata took: %s", got_metadata_time - start_time
)

# Check compatibility hash BEFORE decoding agent metadata
if (
self.enforce_compat_hash
and handshake_payload.compatibility_hash != self.compat_hash
):
raise RuntimeError(
f"NIXL compatibility hash mismatch. "
f"Local: {self.compat_hash}, "
f"Remote: {handshake_payload.compatibility_hash}. "
f"Prefill and decode instances have incompatible configurations. "
f"This may be due to: different vLLM versions, models, dtypes, "
f"KV cache layouts, attention backends, etc. "
f"Both instances must use identical configurations."
f"Disable this check using "
f'--kv-transfer-config \'{{"kv_connector_extra_config": '
f'{{"enforce_handshake_compat": false}}}}\''
for remote_rank in p_remote_ranks:
logger.debug(
"Querying metadata on path: %s at remote tp rank %s",
path,
remote_rank,
)

logger.info(
"NIXL compatibility check passed (hash: %s)",
handshake_payload.compatibility_hash,
)
start_time = time.perf_counter()
# Send query for the request.
msg = msgspec.msgpack.encode((GET_META_MSG, remote_rank))
# Set receive timeout to 5 seconds to avoid hanging on dead server
sock.setsockopt(zmq.RCVTIMEO, 5000) # milliseconds
sock.send(msg)
handshake_bytes = sock.recv()

# Decode agent metadata
metadata_decoder = msgspec.msgpack.Decoder(NixlAgentMetadata)
try:
metadata = metadata_decoder.decode(
handshake_payload.agent_metadata_bytes
)
except (msgspec.DecodeError, msgspec.ValidationError) as e:
# This should not happen if hash matched
raise RuntimeError(
f"Failed to decode NixlAgentMetadata. Error: {e}"
) from e
# Decode handshake payload to get compatibility hash
handshake_decoder = msgspec.msgpack.Decoder(NixlHandshakePayload)
try:
handshake_payload = handshake_decoder.decode(handshake_bytes)
except (msgspec.DecodeError, msgspec.ValidationError) as e:
raise RuntimeError(
f"Failed to decode NixlHandshakePayload. This likely indicates "
f"an incompatibility between connector version. Error: {e}"
) from e

# Ensure engine id matches.
if metadata.engine_id != expected_engine_id:
raise RuntimeError(
f"Remote NIXL agent engine ID mismatch. "
f"Expected {expected_engine_id},"
f"received {metadata.engine_id}."
got_metadata_time = time.perf_counter()
logger.debug(
"NIXL handshake: get metadata took: %s",
got_metadata_time - start_time,
)

# Register Remote agent.
assert metadata.block_size <= self.block_size, (
"nP > nD is not supported yet."
)
remote_agent_name = self.add_remote_agent(
metadata, p_remote_rank, remote_tp_size
)
# Check compatibility hash BEFORE decoding agent metadata
if (
self.enforce_compat_hash
and handshake_payload.compatibility_hash != self.compat_hash
):
raise RuntimeError(
f"NIXL compatibility hash mismatch. "
f"Local: {self.compat_hash}, "
f"Remote: {handshake_payload.compatibility_hash}. "
f"Prefill and decode instances have incompatible "
f"configurations. This may be due to: different vLLM versions,"
f" models, dtypes, KV cache layouts, attention backends, etc. "
f"Both instances must use identical configurations."
f"Disable this check using "
f'--kv-transfer-config \'{{"kv_connector_extra_config": '
f'{{"enforce_handshake_compat": false}}}}\''
)

setup_agent_time = time.perf_counter()
logger.debug(
"NIXL handshake: add agent took: %s",
setup_agent_time - got_metadata_time,
)
logger.info(
"NIXL compatibility check passed (hash: %s)",
handshake_payload.compatibility_hash,
)

# Remote rank -> agent name.
return {p_remote_rank: remote_agent_name}
# Decode agent metadata
metadata_decoder = msgspec.msgpack.Decoder(NixlAgentMetadata)
try:
metadata = metadata_decoder.decode(
handshake_payload.agent_metadata_bytes
)
except (msgspec.DecodeError, msgspec.ValidationError) as e:
# This should not happen if hash matched
raise RuntimeError(
f"Failed to decode NixlAgentMetadata. Error: {e}"
) from e

# Ensure engine id matches.
if metadata.engine_id != expected_engine_id:
raise RuntimeError(
f"Remote NIXL agent engine ID mismatch. "
f"Expected {expected_engine_id},"
f"received {metadata.engine_id}."
)
# Ensure engine id matches.
if metadata.engine_id != expected_engine_id:
raise RuntimeError(
f"Remote NIXL agent engine ID mismatch. "
f"Expected {expected_engine_id},"
f"received {metadata.engine_id}."
)
setup_agent_time = time.perf_counter()

# Register Remote agent.
remote_agent_name = self.add_remote_agent(
metadata, remote_rank, remote_tp_size
)
logger.debug(
"NIXL handshake: add agent took: %s",
setup_agent_time - got_metadata_time,
)
remote_rank_to_agent_name[remote_rank] = remote_agent_name
return remote_rank_to_agent_name

def initialize_host_xfer_buffer(self, kv_caches: dict[str, torch.Tensor]) -> None:
"""
@@ -1283,7 +1290,7 @@ class NixlConnectorWorker:
assert len(self.block_len_per_layer) == len(seen_base_addresses)
assert self.num_blocks != 0

self.kv_caches_base_addr[self.engine_id] = seen_base_addresses
self.kv_caches_base_addr[self.engine_id][self.tp_rank] = seen_base_addresses
self.num_regions = len(caches_data)
self.num_layers = len(xfer_buffers.keys())

@@ -1310,9 +1317,9 @@ class NixlConnectorWorker:

# Register local/src descr for NIXL xfer.
self.seen_base_addresses = seen_base_addresses
self.src_xfer_side_handle = self.register_local_xfer_handler(self.block_size)

self.src_xfer_side_handles[self.block_size] = self.src_xfer_side_handle
self.src_xfer_handles_by_block_size[self.block_size], self.src_blocks_data = (
self.register_local_xfer_handler(self.block_size)
)

# TODO(mgoin): Hybrid memory allocator is currently disabled for
# models with local attention (Llama 4). Can remove this once enabled.
@@ -1340,8 +1347,8 @@ class NixlConnectorWorker:
agent_metadata = NixlAgentMetadata(
engine_id=self.engine_id,
agent_metadata=self.nixl_wrapper.get_agent_metadata(),
kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id],
device_id=self.device_id,
kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id][self.tp_rank],
num_blocks=self.num_blocks,
block_lens=self.block_len_per_layer,
kv_cache_layout=self.kv_cache_layout
@@ -1359,7 +1366,7 @@ class NixlConnectorWorker:
def register_local_xfer_handler(
self,
block_size: int,
) -> int:
) -> tuple[int, list[tuple[int, int, int]]]:
"""
Function used to register the local xfer handler with the local block_size or
remote block_size.
@@ -1407,7 +1414,7 @@ class NixlConnectorWorker:

descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type)
# NIXL_INIT_AGENT to be used for preparations of local descs.
return self.nixl_wrapper.prep_xfer_dlist("NIXL_INIT_AGENT", descs)
return self.nixl_wrapper.prep_xfer_dlist("NIXL_INIT_AGENT", descs), blocks_data

def add_remote_agent(
self,
@@ -1421,10 +1428,12 @@ class NixlConnectorWorker:

In particular, handle both homogeneous and heterogeneous TP. The former
requires local rank_i to read from remote rank_i.
The latter, assuming D.world_size > P.world_size, requires that two or
more local TP worker share the xfer from a single TP worker.
The latter, in the case of D.world_size < P.world_size, requires that a
local (D) TP worker reads from multiple remote (P) TP workers.
Conversely, assuming D.world_size > P.world_size, two or more local TP
workers will read from a single remote TP worker.

Here's an example (non-MLA case):
Here's an example for the last case described above (non-MLA):

rank_offset p_remote_tp_rank
(kv split no)
@@ -1474,9 +1483,6 @@ class NixlConnectorWorker:
nixl_agent_meta.agent_metadata
)

# Handle tp_size>num_kv_heads: replicate KV cache.
replicates_kv_cache = self.kv_topo.replicates_kv_cache(engine_id)

# Create dst descs and xfer side handles. TP workers have same #blocks
# so we only register once per engine_id.
# Example:
@@ -1490,14 +1496,52 @@ class NixlConnectorWorker:
self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks

# Keep track of remote agent kv caches base addresses.
self.kv_caches_base_addr[engine_id] = nixl_agent_meta.kv_caches_base_addr

self.kv_caches_base_addr[engine_id][remote_tp_rank] = (
nixl_agent_meta.kv_caches_base_addr
)
self._validate_remote_agent_handshake(nixl_agent_meta, remote_tp_size)

# Number of D TP workers reading from a single P TP worker. This is
# 1 when P and D `--tensor-parallel-size` match.
# This is 1 when P and D `--tensor-parallel-size` match. Otherwise,
# this is the ratio between the two sizes.
tp_ratio = self.kv_topo.tp_ratio_from_engine_id(engine_id)

# Handle tp_size>num_kv_heads: replicate KV cache.
indexes_into_remote = (
not self.kv_topo.replicates_kv_cache(engine_id) and tp_ratio > 0
)

logger.debug(
"Registering remote agent (%s, rank %s) memory regions with tp_ratio %s",
engine_id,
remote_tp_rank,
tp_ratio,
)

### (Optional) Register local agent memory regions. MLA is not split.
if (
tp_ratio < 0
and not self.use_mla
and tp_ratio not in self.src_xfer_handles_by_tp_ratio
):
# Remote tp_size > local tp_size: read from multiple remote ranks.
# Logically "split" own regions into |tp_ratio| chunks. Mind that
# we only do this once per remote tp_size (replica-friendly).
self.src_xfer_handles_by_tp_ratio[tp_ratio] = []
for i in range(-tp_ratio):
blocks_data = []
for memory_region in self.src_blocks_data:
addr, local_block_len, own_tp_rank = memory_region
# Computing block len layer by layer allows for different
# block sizes to be used.
remote_block_len = local_block_len // (-tp_ratio)
addr = addr + i * remote_block_len
blocks_data.append((addr, remote_block_len, own_tp_rank))
descs = self.nixl_wrapper.get_xfer_descs(
blocks_data, self.nixl_memory_type
)
handle = self.nixl_wrapper.prep_xfer_dlist("NIXL_INIT_AGENT", descs)
self.src_xfer_handles_by_tp_ratio[tp_ratio].append(handle)

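Illustrative aside (not part of the commit): how the local region is logically split into |tp_ratio| chunks above; chunk i keeps the remote-sized block length and offsets the base address by i chunk lengths. A toy standalone version with made-up addresses (names are ours):

def split_local_region(base_addr: int, local_block_len: int, tp_ratio: int) -> list[list[tuple[int, int]]]:
    """tp_ratio is negative (P TP > D TP); return one (addr, len) descriptor list per chunk."""
    assert tp_ratio < 0
    chunk_len = local_block_len // (-tp_ratio)
    return [[(base_addr + i * chunk_len, chunk_len)] for i in range(-tp_ratio)]

# Local D region of 8192 bytes split for tp_ratio == -2 (remote TP twice as large):
assert split_local_region(0x1000, 8192, -2) == [[(0x1000, 4096)], [(0x2000, 4096)]]
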
### Register remote agent memory regions
blocks_data = []
# With homogeneous TP, D pulls the whole kv cache from corresponding
@@ -1507,14 +1551,19 @@ class NixlConnectorWorker:

# Register all remote blocks, but only the corresponding kv heads.
for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr):
kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=i)
remote_kv_block_len = kv_block_len // block_size_ratio
# Read our whole local region size from remote.
local_block_len = self.get_backend_aware_kv_block_len(layer_idx=i)
remote_kv_block_len = local_block_len // block_size_ratio
if block_size_ratio > 1:
# using remote kv_block_len as transfer unit
kv_block_len = remote_kv_block_len
local_block_len = remote_kv_block_len

if tp_ratio < 0 and not self.use_mla:
# Remote tp is bigger: read a chunk of local region from remote
local_block_len = local_block_len // (-tp_ratio)
rank_offset = (
self.tp_rank % tp_ratio * remote_kv_block_len
if not replicates_kv_cache
if indexes_into_remote
else 0
)
for block_id in range(nixl_agent_meta.num_blocks):
@@ -1524,7 +1573,7 @@ class NixlConnectorWorker:
# self.block_len == remote_block_len//tp_ratio bytes.
addr = base_addr + block_offset + rank_offset
# (addr, len, device id)
blocks_data.append((addr, kv_block_len, nixl_agent_meta.device_id))
blocks_data.append((addr, local_block_len, nixl_agent_meta.device_id))

if self.kv_topo.is_kv_layout_blocks_first:
# With FlashInfer index V separately to allow head splitting.
@@ -1533,7 +1582,7 @@ class NixlConnectorWorker:
addr = base_addr + block_offset + rank_offset
v_addr = addr + nixl_agent_meta.block_lens[i] // 2
blocks_data.append(
(v_addr, kv_block_len, nixl_agent_meta.device_id)
(v_addr, local_block_len, nixl_agent_meta.device_id)
)

logger.debug(
@@ -1546,15 +1595,15 @@ class NixlConnectorWorker:

# Register with NIXL.
descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type)
self.dst_xfer_side_handles[engine_id] = self.nixl_wrapper.prep_xfer_dlist(
remote_agent_name, descs
self.dst_xfer_side_handles[engine_id][remote_tp_rank] = (
self.nixl_wrapper.prep_xfer_dlist(remote_agent_name, descs)
)

if block_size_ratio > 1:
# when prefill with smaller block_size, we need to init a
# new handler with same block_len to match
self.src_xfer_side_handles[nixl_agent_meta.block_size] = (
self.register_local_xfer_handler(nixl_agent_meta.block_size)
self.src_xfer_handles_by_block_size[nixl_agent_meta.block_size] = (
self.register_local_xfer_handler(nixl_agent_meta.block_size)[0]
)

return remote_agent_name
@@ -1574,7 +1623,9 @@ class NixlConnectorWorker:
block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(
remote_engine_id
)
assert tp_ratio > 0, "Decode TP cannot be smaller than prefill TP"
# Num kv_heads > tp_size and P TP > D TP case, not supported
assert not (tp_ratio < 0 and self.kv_topo.is_kv_replicated(remote_engine_id))

assert not self._use_pallas or tp_ratio == 1, (
"TPU (pallas_v1) DOES NOT support heterogeneous TP yet."
)
@@ -1616,17 +1667,29 @@ class NixlConnectorWorker:
"All remote layers must have the same block size"
)

assert (
remote_block_len
== (self.block_len_per_layer[0] * tp_ratio) // block_size_ratio
), (
"Remote P worker KV layer cache must be of shape [2, N, "
"local_kv_heads*tp_ratio, block_size, head_dim] and same dtype."
)
if tp_ratio > 0:
# Remote tp is smaller: remote block_len size is bigger
assert (
remote_block_len
== (self.block_len_per_layer[0] * tp_ratio) // block_size_ratio
), (
"Remote P worker KV layer cache must be of shape [2, N, "
"local_kv_heads*tp_ratio, page_size, head_dim] and same dtype."
) # noqa: E501
else:
assert block_size_ratio == 1, (
"Different local/remote block sizes are not supported when"
" P TP > D TP."
)
# Remote tp is bigger: remote block_len size is smaller
assert remote_block_len == self.block_len_per_layer[0] // (-tp_ratio), (
"Remote P worker KV layer cache must be of shape [2, N, "
"local_kv_heads/tp_ratio, page_size, head_dim] and same dtype."
) # noqa: E501

# TP workers have same #blocks.
# TP workers that handshake with same remote have same #blocks.
assert self.dst_num_blocks[remote_engine_id] == nixl_agent_meta.num_blocks

# Same number of regions/~layers.
assert len(nixl_agent_meta.kv_caches_base_addr) == len(self.block_len_per_layer)

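Illustrative aside (not part of the commit): the block-length invariant checked above, written out as plain arithmetic for both signs of tp_ratio; the values are made up.

local_block_len = 4096          # bytes per layer per block on this (D) worker

# tp_ratio > 0 (P TP <= D TP): remote blocks hold tp_ratio times more heads.
tp_ratio, block_size_ratio = 2, 1
assert (local_block_len * tp_ratio) // block_size_ratio == 8192   # expected remote block_len

# tp_ratio < 0 (P TP > D TP): remote blocks hold |tp_ratio| times fewer heads,
# and mismatched block sizes are rejected (block_size_ratio must be 1).
tp_ratio = -2
assert local_block_len // (-tp_ratio) == 2048                     # expected remote block_len
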
def sync_recved_kv_to_device(self, req_id: str, meta: ReqMeta):
@@ -1710,7 +1773,7 @@ class NixlConnectorWorker:
)
cache.index_copy_(0, indices, permuted_blocks)

def blocksize_post_process(self, block_ids_per_ratio: dict[float, list[list[int]]]):
def blocksize_post_process(self, block_ids_per_ratio: dict[int, list[list[int]]]):
def _process_local_gt_remote(blocks_to_update, block_size_ratio):
n_kv_heads, block_size, head_size = blocks_to_update.shape[1:]
remote_block_size = block_size // block_size_ratio
@@ -1840,7 +1903,7 @@ class NixlConnectorWorker:
notified_req_ids: set[str] = set()
for notifs in self.nixl_wrapper.get_new_notifs().values():
for notif in notifs:
req_id, tp_ratio = notif.decode("utf-8").rsplit(":", 1)
req_id, tp_size = notif.decode("utf-8").rsplit(":", 1)
if (
req_id not in self._reqs_to_send
and req_id not in self._reqs_to_process
@@ -1853,9 +1916,22 @@ class NixlConnectorWorker:
)
continue

# NOTE: `tp_ratio` is the opposite when swapping local<>remote
n_consumers = int(tp_size)
tp_ratio = self.kv_topo.tp_ratio(n_consumers)

# Number of reads *per producer* to wait for.
# When remote D TP > local P TP we expect `tp_ratio` reads.
consumers_per_producer = (
-tp_ratio if n_consumers > self.world_size else 1
)

self.consumer_notification_counts_by_req[req_id] += 1
# Wait for all consumers (D) to be done reading before freeing.
if self.consumer_notification_counts_by_req[req_id] == int(tp_ratio):
if (
self.consumer_notification_counts_by_req[req_id]
== consumers_per_producer
):
notified_req_ids.add(req_id)
del self.consumer_notification_counts_by_req[req_id]
self._reqs_to_process.remove(req_id)
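Illustrative aside (not part of the commit): the producer-side wait count derived above from the consumer tp_size carried in the notification, restated standalone; the function name is ours.

def reads_to_wait_for(producer_tp: int, consumer_tp: int) -> int:
    """How many read notifications a single P worker waits for before freeing blocks."""
    if consumer_tp > producer_tp:
        # More D than P workers: each P worker serves consumer_tp // producer_tp readers.
        return consumer_tp // producer_tp
    # Otherwise exactly one reader per P worker (including the MLA duplicate-notify path).
    return 1

assert reads_to_wait_for(producer_tp=1, consumer_tp=4) == 4
assert reads_to_wait_for(producer_tp=2, consumer_tp=1) == 1
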
@@ -1872,7 +1948,7 @@ class NixlConnectorWorker:
"""
done_req_ids: set[str] = set()
for req_id, handles in list(transfers.items()):
in_progress = False
in_progress = []
for handle in handles:
try:
xfer_state = self.nixl_wrapper.check_xfer_state(handle)
@@ -1882,7 +1958,7 @@ class NixlConnectorWorker:
self.xfer_stats.record_transfer(res)
self.nixl_wrapper.release_xfer_handle(handle)
elif xfer_state == "PROC":
in_progress = True
in_progress.append(handle)
continue
else:
logger.error(
@@ -1892,7 +1968,6 @@ class NixlConnectorWorker:
xfer_state,
)
self._handle_failed_transfer(req_id, handle)
in_progress = False
except Exception:
logger.exception(
"NIXL transfer exception for request %s. "
@@ -1900,11 +1975,13 @@ class NixlConnectorWorker:
req_id,
)
self._handle_failed_transfer(req_id, handle)
in_progress = False

if not in_progress:
# Only report request as completed when all transfers are done.
done_req_ids.add(req_id)
del transfers[req_id]
else:
transfers[req_id] = in_progress
return done_req_ids

def _handle_failed_transfer(self, req_id: str, handle: int):
@@ -1982,18 +2059,62 @@ class NixlConnectorWorker:

def _read_blocks_for_req(self, req_id: str, meta: ReqMeta):
assert meta.remote is not None
logger.debug(
"Remote agent %s available, calling _read_blocks for req %s",
meta.remote.engine_id,
req_id,
)
self._read_blocks(
request_id=req_id,
dst_engine_id=meta.remote.engine_id,
remote_request_id=meta.remote.request_id,
local_block_ids=meta.local_physical_block_ids,
remote_block_ids=meta.remote.block_ids,
remote_ranks = self.kv_topo.get_target_remote_ranks_from_engine_id(
meta.remote.engine_id
)
tp_ratio = self.kv_topo.tp_ratio_from_engine_id(meta.remote.engine_id)
# D may have to perform multiple reads from different remote ranks.
for i, remote_rank in enumerate(remote_ranks):
if self.use_mla and tp_ratio < 0 and i > 0:
# MLA opt: when P TP > D TP, only a single read is executed for
# the first remote rank (cache is duplicated)..
break

remote_block_size = self.kv_topo.remote_block_size[meta.remote.engine_id]
logger.debug(
"Remote agent %s available, calling _read_blocks"
" on remote rank %s with remote block size %s for req %s",
meta.remote.engine_id,
remote_rank,
remote_block_size,
req_id,
)
# Get side handles.
if tp_ratio < 0 and not self.use_mla:
assert remote_block_size == self.block_size
# Remote tp_size > local tp_size: we must perform multiple
# reads. Get the memory chunk we will write to.
local_xfer_side_handle = self.src_xfer_handles_by_tp_ratio[tp_ratio][i]
else:
# Single read from remote, we write to the whole memory region.
# Also handle remote block size different from local block size.
local_xfer_side_handle = self.src_xfer_handles_by_block_size[
remote_block_size
]

# Destination handle: remote_engine_id -> remote_rank -> handle.
remote_xfer_side_handle = self.dst_xfer_side_handles[meta.remote.engine_id][
remote_rank
]
self._read_blocks(
request_id=req_id,
dst_engine_id=meta.remote.engine_id,
remote_request_id=meta.remote.request_id,
local_block_ids=meta.local_physical_block_ids,
remote_block_ids=meta.remote.block_ids,
remote_rank=remote_rank,
local_xfer_side_handle=local_xfer_side_handle,
remote_xfer_side_handle=remote_xfer_side_handle,
)

if self.use_mla and tp_ratio < 0:
# ..but we still need to notify the other remote ranks that we
# have the blocks we need so they can update the request state.
notif_id = f"{req_id}:{self.world_size}".encode()
remote_agents = self._remote_agents[meta.remote.engine_id]
for rank_to_notify, agent in remote_agents.items():
if rank_to_notify != remote_rank:
self.nixl_wrapper.send_notif(agent, notif_msg=notif_id)

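Illustrative aside (not part of the commit): the handle selection in the read loop above, shown with toy values; the dictionary shapes mirror the worker attributes but the handles and ranks below are made up.

use_mla = False
tp_ratio = -2                                   # P TP is twice D TP in this toy setup
src_handles_by_tp_ratio = {-2: [101, 102]}      # one local chunk handle per remote rank read
src_handles_by_block_size = {16: 100}           # whole-region handle keyed by block size
remote_block_size = 16

for i, remote_rank in enumerate([2, 3]):        # ranks this D worker reads from
    if tp_ratio < 0 and not use_mla:
        local_handle = src_handles_by_tp_ratio[tp_ratio][i]
    else:
        local_handle = src_handles_by_block_size[remote_block_size]
    print(remote_rank, local_handle)            # -> (2, 101) then (3, 102)
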
def _read_blocks(
self,
@@ -2002,7 +2123,14 @@ class NixlConnectorWorker:
dst_engine_id: str,
request_id: str,
remote_request_id: str,
remote_rank: int,
local_xfer_side_handle: int,
remote_xfer_side_handle: int,
):
"""
Post a READ point-to-point xfer request from a single local worker to
a single remote worker.
"""
block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(dst_engine_id)
if block_size_ratio > 1:
local_block_ids = self.get_mapped_blocks(
@@ -2031,18 +2159,14 @@ class NixlConnectorWorker:
# saturate IB with heterogeneous TP sizes. We should remove the staging
# blocks until we are ready.

# Number of D TP workers that will read from dst P. Propagate tp_ratio
# Number of D TP workers that will read from dst P. Propagate info
# on notification so that dst worker can wait before freeing blocks.
tp_ratio = self.kv_topo.tp_ratio_from_engine_id(dst_engine_id)
notif_id = f"{remote_request_id}:{tp_ratio}".encode()
notif_id = f"{remote_request_id}:{self.world_size}".encode()

# Full prefix cache hit: do not need to read remote blocks,
# just notify P worker that we have the blocks we need.
num_local_blocks = len(local_block_ids)
if num_local_blocks == 0:
remote_rank = self.kv_topo.get_target_remote_rank_from_engine_id(
dst_engine_id
)
agent_name = self._remote_agents[dst_engine_id][remote_rank]
try:
self.nixl_wrapper.send_notif(agent_name, notif_msg=notif_id)
@@ -2062,13 +2186,6 @@ class NixlConnectorWorker:
if num_local_blocks < num_remote_blocks:
remote_block_ids = remote_block_ids[-num_local_blocks:]

# Get side handles.
remote_block_size = self.kv_topo.remote_block_size[dst_engine_id]
local_xfer_side_handle = self.src_xfer_side_handles.get(
remote_block_size, self.src_xfer_side_handle
)
remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id]

# NOTE (nicolo) With homogeneous TP, each TP worker loads KV from
# corresponding rank. With heterogeneous TP, fixing D>P, the D tp
# workers will issue xfers to parts of the P worker remote kv caches.
@@ -2230,7 +2347,7 @@ class NixlConnectorWorker:
block_ids_np, self._physical_blocks_per_logical_kv_block, block_arange
).tolist()

def get_backend_aware_kv_block_len(self, layer_idx: int):
def get_backend_aware_kv_block_len(self, layer_idx: int) -> int:
"""
Get the block length for one K/V element (K and V have the same size).

@@ -2276,11 +2393,16 @@ class NixlConnectorWorker:
for handle in handles:
self.nixl_wrapper.release_xfer_handle(handle)
self._recving_transfers.clear()
if self.src_xfer_side_handle:
self.nixl_wrapper.release_dlist_handle(self.src_xfer_side_handle)
self.src_xfer_side_handle = 0
for dst_xfer_side_handle in self.dst_xfer_side_handles.values():
self.nixl_wrapper.release_dlist_handle(dst_xfer_side_handle)
for handle in self.src_xfer_handles_by_block_size.values():
self.nixl_wrapper.release_dlist_handle(handle)
self.src_xfer_handles_by_block_size.clear()
for handles in self.src_xfer_handles_by_tp_ratio.values():
for handle in handles:
self.nixl_wrapper.release_dlist_handle(handle)
self.src_xfer_handles_by_tp_ratio.clear()
for dst_xfer_side_handles in self.dst_xfer_side_handles.values():
for dst_xfer_side_handle in dst_xfer_side_handles.values():
self.nixl_wrapper.release_dlist_handle(dst_xfer_side_handle)
self.dst_xfer_side_handles.clear()
for remote_agents in self._remote_agents.values():
for agent_name in remote_agents.values():