mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-10 23:17:09 +08:00
add and update tests
Signed-off-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
parent
9f38fed93c
commit
6601c9c5be
@ -291,6 +291,7 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
|
||||
def __init__(self, *args, hand_shake_latency: float = 1.8, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._hand_shake_latency = hand_shake_latency
|
||||
self.kv_cache_layout = "HND"
|
||||
|
||||
def _nixl_handshake(
|
||||
self, host: str, port: int, remote_tp_size: int, expected_engine_id: str
|
||||
@ -307,21 +308,42 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
|
||||
|
||||
assert expected_engine_id == self.REMOTE_ENGINE_ID
|
||||
|
||||
remote_agent_name = self.add_remote_agent(
|
||||
NixlAgentMetadata(
|
||||
engine_id=self.REMOTE_ENGINE_ID,
|
||||
agent_metadata=FakeNixlWrapper.AGENT_METADATA,
|
||||
kv_caches_base_addr=[0],
|
||||
num_blocks=1,
|
||||
block_lens=self.block_len_per_layer,
|
||||
attn_backend_name=self.backend_name,
|
||||
# `self.kv_cache_layout` is only forced to HND when vllm engine
|
||||
# is started. We mock HND here.
|
||||
kv_cache_layout="HND",
|
||||
),
|
||||
remote_tp_size=remote_tp_size,
|
||||
)
|
||||
return {0: remote_agent_name}
|
||||
# Adjust remote block length metadata to satisfy heterogeneous TP
|
||||
# invariants enforced during handshake validation.
|
||||
remote_block_lens = list(self.block_len_per_layer)
|
||||
tp_ratio = self.kv_info.tp_ratio(remote_tp_size=remote_tp_size)
|
||||
if remote_tp_size > self.world_size:
|
||||
# P TP > D TP case, block_len of remote is smaller
|
||||
remote_block_lens = [
|
||||
block_len // (-tp_ratio) for block_len in remote_block_lens
|
||||
]
|
||||
elif remote_tp_size < self.world_size:
|
||||
remote_block_lens = [
|
||||
block_len * tp_ratio for block_len in remote_block_lens
|
||||
]
|
||||
|
||||
# When remote tp_size > local tp_size, handshake with multiple
|
||||
# remote ranks.
|
||||
num_hanshakes = 1 if tp_ratio > 0 else -tp_ratio
|
||||
remote_agents: dict[int, str] = {}
|
||||
for remote_tp_rank in range(num_hanshakes):
|
||||
remote_agent_name = self.add_remote_agent(
|
||||
NixlAgentMetadata(
|
||||
engine_id=self.REMOTE_ENGINE_ID,
|
||||
agent_metadata=FakeNixlWrapper.AGENT_METADATA,
|
||||
kv_caches_base_addr=[0],
|
||||
num_blocks=1,
|
||||
block_lens=remote_block_lens,
|
||||
attn_backend_name=self.backend_name,
|
||||
# `self.kv_cache_layout` is only forced to HND when vllm engine
|
||||
# is started. We mock HND here.
|
||||
kv_cache_layout="HND",
|
||||
),
|
||||
remote_tp_rank=remote_tp_rank,
|
||||
remote_tp_size=remote_tp_size,
|
||||
)
|
||||
remote_agents[remote_tp_rank] = remote_agent_name
|
||||
return remote_agents
|
||||
|
||||
|
||||
class TestNixlHandshake:
|
||||
@ -352,7 +374,13 @@ class TestNixlHandshake:
|
||||
vllm_config, connector.engine_id, hand_shake_latency=0
|
||||
)
|
||||
assert isinstance(connector.connector_worker.nixl_wrapper, FakeNixlWrapper)
|
||||
connector.connector_worker.nixl_wrapper.set_cycles_before_xfer_done(3)
|
||||
worker = connector.connector_worker
|
||||
worker.nixl_wrapper.set_cycles_before_xfer_done(3)
|
||||
# simulate handshake
|
||||
worker.dst_xfer_side_handles = {
|
||||
FakeNixlConnectorWorker.REMOTE_ENGINE_ID: {0: 1}
|
||||
}
|
||||
worker.kv_cache_layout = "HND"
|
||||
num_xfers = 4
|
||||
while True:
|
||||
# For the same request_id, initiate multiple xfers across different
|
||||
@ -464,6 +492,70 @@ class TestNixlHandshake:
|
||||
return
|
||||
raise TimeoutError("Took too long to complete async handshake.")
|
||||
|
||||
@patch(
|
||||
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
|
||||
FakeNixlWrapper,
|
||||
)
|
||||
@pytest.mark.parametrize("local_tp_size", [1, 2])
|
||||
def test_prefill_tp_size_greater_than_decode_tp_size(
|
||||
self, local_tp_size: int, dist_init
|
||||
):
|
||||
"""
|
||||
Verify remote TP > local TP handshake succeeds with different
|
||||
remote configurations.
|
||||
"""
|
||||
|
||||
vllm_config = create_vllm_config()
|
||||
local_tp_size = 1
|
||||
vllm_config.parallel_config.tensor_parallel_size = local_tp_size
|
||||
|
||||
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
|
||||
connector.connector_worker = FakeNixlConnectorWorker(
|
||||
vllm_config, connector.engine_id, hand_shake_latency=0
|
||||
)
|
||||
worker = connector.connector_worker
|
||||
|
||||
# Minimal local registration params used by add_remote_agent
|
||||
worker.slot_size_per_layer = [4096]
|
||||
worker.block_len_per_layer = [4096 * worker.block_size]
|
||||
worker.num_blocks = 1
|
||||
worker.dst_num_blocks[worker.engine_id] = worker.num_blocks
|
||||
worker.src_blocks_data = [(0, worker.block_len_per_layer[0], worker.tp_rank)]
|
||||
|
||||
def check_handshake(remote_tp_size: int):
|
||||
tp_ratio = remote_tp_size // local_tp_size
|
||||
assert set(remote_agents.keys()) == set(range(tp_ratio))
|
||||
|
||||
remote_engine_id = worker.REMOTE_ENGINE_ID
|
||||
assert worker._tp_size[remote_engine_id] == remote_tp_size
|
||||
assert -tp_ratio == worker.kv_info.tp_ratio(remote_engine_id)
|
||||
# ensure src_xfer_side_chunked_handles is populated with tpratio chunks
|
||||
assert -tp_ratio in worker.src_xfer_side_chunked_handles
|
||||
assert len(worker.src_xfer_side_chunked_handles[-tp_ratio]) == tp_ratio
|
||||
assert remote_engine_id in worker.dst_xfer_side_handles
|
||||
assert set(worker.dst_xfer_side_handles[remote_engine_id].keys()) == set(
|
||||
range(tp_ratio)
|
||||
)
|
||||
|
||||
remote_agents = worker._nixl_handshake(
|
||||
host="localhost",
|
||||
port=1234,
|
||||
remote_tp_size=2,
|
||||
expected_engine_id=worker.REMOTE_ENGINE_ID,
|
||||
)
|
||||
check_handshake(2)
|
||||
|
||||
# NOTE flexiblity: a second remote with higher number of ranks
|
||||
# is discovered
|
||||
worker.REMOTE_ENGINE_ID = "remote_engine_2"
|
||||
remote_agents = worker._nixl_handshake(
|
||||
host="localhost",
|
||||
port=1234,
|
||||
remote_tp_size=6,
|
||||
expected_engine_id=worker.REMOTE_ENGINE_ID,
|
||||
)
|
||||
check_handshake(6)
|
||||
|
||||
@patch(
|
||||
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
|
||||
FakeNixlWrapper,
|
||||
@ -564,10 +656,9 @@ class TestNixlHandshake:
|
||||
kv_cache_layout=mismatched_layout,
|
||||
)
|
||||
|
||||
# We don't check layout for homogeneous TP and MLA for now, as the
|
||||
# whole block is moved.
|
||||
worker.add_remote_agent(meta, remote_tp_size=2)
|
||||
# Layout check done for both homogeneous and heterogeneous TP.
|
||||
with pytest.raises(AssertionError):
|
||||
worker.add_remote_agent(meta, remote_tp_size=2)
|
||||
worker.add_remote_agent(meta, remote_tp_size=1)
|
||||
|
||||
|
||||
@ -1057,7 +1148,8 @@ def test_shutdown_cleans_up_resources(dist_init):
|
||||
):
|
||||
worker._recving_transfers = {"req1": [(123, time.perf_counter())]}
|
||||
worker.src_xfer_side_handle = 456
|
||||
worker.dst_xfer_side_handles = {"engine1": 789}
|
||||
worker.src_xfer_side_chunked_handles = {-2: [456]}
|
||||
worker.dst_xfer_side_handles = {"engine1": {0: 789}}
|
||||
worker._remote_agents = {"engine1": {0: "agent1"}}
|
||||
worker._registered_descs = ["desc1", "desc2"]
|
||||
|
||||
@ -1071,7 +1163,7 @@ def test_shutdown_cleans_up_resources(dist_init):
|
||||
mock_listener.join.assert_called_once_with(timeout=0)
|
||||
|
||||
mock_rel_xfer.assert_called_once_with(123)
|
||||
assert mock_rel_dlist.call_count == 2
|
||||
assert mock_rel_dlist.call_count == 3
|
||||
mock_rel_dlist.assert_any_call(456) # src handle
|
||||
mock_rel_dlist.assert_any_call(789) # dst handle
|
||||
mock_rem_agent.assert_called_once_with("agent1")
|
||||
|
||||
@ -1171,7 +1171,14 @@ class NixlConnectorWorker:
|
||||
|
||||
# Handle tp_size>num_kv_heads: replicate KV cache.
|
||||
indexes_into_remote = (
|
||||
not self.kv_info.replicates_kv_cache(engine_id) and tp_ratio < 0
|
||||
not self.kv_info.replicates_kv_cache(engine_id) and tp_ratio > 0
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
"Registering remote agent (%s, rank %s) memory regions with tp_ratio %s",
|
||||
engine_id,
|
||||
remote_tp_rank,
|
||||
tp_ratio,
|
||||
)
|
||||
|
||||
### (Optional) Register local agent memory regions
|
||||
@ -1724,8 +1731,8 @@ class NixlConnectorWorker:
|
||||
if self._nixl_handshake_listener_t is not None:
|
||||
self._nixl_handshake_listener_t.join(timeout=0)
|
||||
self._nixl_handshake_listener_t = None
|
||||
for handles in self._recving_transfers.values():
|
||||
for handle, _ in handles:
|
||||
for rcv_handles in self._recving_transfers.values():
|
||||
for handle, _ in rcv_handles:
|
||||
self.nixl_wrapper.release_xfer_handle(handle)
|
||||
self._recving_transfers.clear()
|
||||
if self.src_xfer_side_handle:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user