From b36893b3056256e6929abc441f40e303937930f1 Mon Sep 17 00:00:00 2001 From: inkcherry Date: Tue, 23 Dec 2025 10:09:45 +0000 Subject: [PATCH] refine ut Signed-off-by: inkcherry --- .../unit/test_moriio_connector.py | 141 ++++++++++++------ 1 file changed, 94 insertions(+), 47 deletions(-) diff --git a/tests/v1/kv_connector/unit/test_moriio_connector.py b/tests/v1/kv_connector/unit/test_moriio_connector.py index 25b5663098272..c31d5d843e85a 100644 --- a/tests/v1/kv_connector/unit/test_moriio_connector.py +++ b/tests/v1/kv_connector/unit/test_moriio_connector.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import importlib.util import os from unittest.mock import MagicMock, patch @@ -28,6 +29,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector import MoRIIOConnector, MoRIIOConnectorWorker, ) +from vllm.platforms import current_platform from vllm.utils.network_utils import ( get_ip, make_zmq_path, @@ -35,10 +37,17 @@ from vllm.utils.network_utils import ( from .utils import create_request, create_scheduler +aiter_available = importlib.util.find_spec("aiter") is not None +mori_available = importlib.util.find_spec("mori") is not None +pytestmark = pytest.mark.skipif( + not (current_platform.is_rocm() and mori_available), + reason="MoRIIOs are only available on ROCm with aiter package installed", +) + @pytest.fixture def mock_parallel_groups(): - """Mock parallel group functions.""" + """Mock tensor/data parallel group functions for single-rank tests.""" mock_group = MagicMock() mock_group.rank = 0 mock_group.local_rank = 0 @@ -76,6 +85,7 @@ def _setup_kv_transfer_request(request, remote_host="127.0.0.1", fake_port=4789) class FakeMorIIOWrapper: + # A fake MoRIIOWrapper for testing purposes def __init__(self, *args, **kwargs): pass @@ -144,6 +154,7 @@ class FakeMorIIOWrapper: class FakeMorIIOConnectorWorker(MoRIIOConnectorWorker): + # Define a fake remote engine id for testing REMOTE_ENGINE_ID = "remote_engine" def __init__( @@ -162,7 +173,7 @@ def create_vllm_config( enable_permute_local_kv: bool = False, role="kv_consumer", ) -> VllmConfig: - """Initialize VllmConfig For Testing.""" + """Initialize VllmConfig for testing.""" scheduler_config = SchedulerConfig( max_num_seqs=max_num_seqs, max_num_batched_tokens=max_num_batched_tokens, @@ -199,18 +210,18 @@ def create_vllm_config( @pytest.fixture def moriio_read_mode(): - """Set VLLM_MORIIO_CONNECTOR_READ_MODE environment variable for all tests.""" + """Force the connector into read mode via env for tests.""" os.environ["VLLM_MORIIO_CONNECTOR_READ_MODE"] = "True" yield # Cleanup after test os.environ.pop("VLLM_MORIIO_CONNECTOR_READ_MODE", None) -def test_write_mode_basic_interface(): - """Unit test for basic MoriioConnector interface functionality.""" +def test_write_mode_saves_local_block_ids(): + """Write mode records local block ids in MoRIIOConnectorMetadata.reqs_to_save.""" - # Test Prefill wirte metadata - vllm_config = create_vllm_config(role="kv_consumer") + # Setup Scheduler and Request + vllm_config = create_vllm_config(role="kv_producer") scheduler = create_scheduler(vllm_config) # 2 Full Blocks and 1 Half Block. @@ -235,13 +246,21 @@ def test_write_mode_basic_interface(): # Remote Prefill, triggers NixlConnectorMetadata. scheduler_output = scheduler.schedule() kv_connector_metadata = scheduler_output.kv_connector_metadata - assert kv_connector_metadata is not None + assert kv_connector_metadata is not None, "kv_connector_metadata is None" assert isinstance(kv_connector_metadata, MoRIIOConnectorMetadata) - assert len(kv_connector_metadata.reqs_to_save) == 1 - assert len(kv_connector_metadata.reqs_to_recv) == 0 - assert len(kv_connector_metadata.reqs_to_send) == 0 - assert request_id in kv_connector_metadata.reqs_to_save + assert len(kv_connector_metadata.reqs_to_save) == 1, ( + "Unexpected number of reqs_to_save" + ) + assert len(kv_connector_metadata.reqs_to_recv) == 0, ( + "Unexpected number of reqs_to_recv" + ) + assert len(kv_connector_metadata.reqs_to_send) == 0, ( + "Unexpected number of reqs_to_send" + ) + assert request_id in kv_connector_metadata.reqs_to_save, ( + "Request ID not in reqs_to_save" + ) req_meta = kv_connector_metadata.reqs_to_save[request_id] for block_id, block in zip( @@ -250,18 +269,17 @@ def test_write_mode_basic_interface(): request_id ], ): - assert block_id == block.block_id + assert block_id == block.block_id, f"{block_id} != {block.block_id}" -def test_write_mode_chunk_prefill(): - """Unit test for basic MoriioConnector interface functionality.""" +def test_write_mode_with_chunked_prefill_saves_local_block_ids(): + """Write mode with chunked prefill still records correct local block ids.""" + # Setup Scheduler and Request MAX_NUM_BATCHED_TOKENS = 64 - NUM_TOKENS = MAX_NUM_BATCHED_TOKENS * 2 + MAX_NUM_BATCHED_TOKENS // 2 - # Test Prefill wirte metadata vllm_config = create_vllm_config( - max_num_batched_tokens=MAX_NUM_BATCHED_TOKENS, role="kv_consumer" + max_num_batched_tokens=MAX_NUM_BATCHED_TOKENS, role="kv_producer" ) BLOCK_SIZE = vllm_config.cache_config.block_size @@ -281,18 +299,22 @@ def test_write_mode_chunk_prefill(): scheduler.add_request(request) # Fake Config - request = _setup_kv_transfer_request(request) - # Remote Prefill, triggers NixlConnectorMetadata. - scheduler_output = scheduler.schedule() - kv_connector_metadata = scheduler_output.kv_connector_metadata - assert kv_connector_metadata is not None - assert isinstance(kv_connector_metadata, MoRIIOConnectorMetadata) - assert len(kv_connector_metadata.reqs_to_save) == 1 - assert len(kv_connector_metadata.reqs_to_recv) == 0 - assert len(kv_connector_metadata.reqs_to_send) == 0 - assert request_id in kv_connector_metadata.reqs_to_save + # Remote Prefill with chunked prefill, triggers multiple schedules. + expected_counts = [(0, 0, 0), (0, 0, 0), (1, 0, 0)] + kv_connector_metadata = None + for _, (expected_save, expected_recv, expected_send) in enumerate(expected_counts): + scheduler_output = scheduler.schedule() + kv_connector_metadata = scheduler_output.kv_connector_metadata + + assert len(kv_connector_metadata.reqs_to_save) == expected_save + assert len(kv_connector_metadata.reqs_to_recv) == expected_recv + assert len(kv_connector_metadata.reqs_to_send) == expected_send + assert kv_connector_metadata is not None, "kv_connector_metadata is None" + assert request_id in kv_connector_metadata.reqs_to_save, ( + "Request ID not in reqs_to_save" + ) req_meta = kv_connector_metadata.reqs_to_save[request_id] for block_id, block in zip( @@ -301,14 +323,16 @@ def test_write_mode_chunk_prefill(): request_id ], ): - assert block_id == block.block_id + assert block_id == block.block_id, f"{block_id} != {block.block_id}" -def test_read_mode_basic_interface(moriio_read_mode): - # test decode read +def test_read_mode_loads_remote_block_ids(moriio_read_mode): + """Read mode loads remote block ids into local cache mapping.""" + + # Setup Scheduler and Request vllm_config = create_vllm_config(role="kv_consumer") scheduler = create_scheduler(vllm_config) - # + # 2 Full Blocks and 1 Half Block. BLOCK_SIZE = vllm_config.cache_config.block_size NUM_EXTERNAL_FULL_BLOCKS = 2 @@ -327,20 +351,32 @@ def test_read_mode_basic_interface(moriio_read_mode): block_list = scheduler.kv_cache_manager.coordinator.single_type_managers[ 0 ].req_to_blocks[request_id] - # Fake kv config + request = _setup_kv_transfer_request(request) + + # Set remote block ids to be fetched. request.kv_transfer_params["remote_block_ids"] = block_list # Remote Prefill, triggers MorIIOConnectorMetadata. scheduler_output = scheduler.schedule() kv_connector_metadata = scheduler_output.kv_connector_metadata - assert kv_connector_metadata is not None - assert isinstance(kv_connector_metadata, MoRIIOConnectorMetadata) - assert len(kv_connector_metadata.reqs_to_save) == 0 - assert len(kv_connector_metadata.reqs_to_recv) == 1 - assert len(kv_connector_metadata.reqs_to_send) == 0 - assert request_id in kv_connector_metadata.reqs_to_recv + assert kv_connector_metadata is not None, "kv_connector_metadata is None" + assert isinstance(kv_connector_metadata, MoRIIOConnectorMetadata), ( + "kv_connector_metadata is not MoRIIOConnectorMetadata" + ) + assert len(kv_connector_metadata.reqs_to_save) == 0, ( + "Unexpected number of reqs_to_save" + ) + assert len(kv_connector_metadata.reqs_to_recv) == 1, ( + "Unexpected number of reqs_to_recv" + ) + assert len(kv_connector_metadata.reqs_to_send) == 0, ( + "Unexpected number of reqs_to_send" + ) + assert request_id in kv_connector_metadata.reqs_to_recv, ( + "Request ID not in reqs_to_recv" + ) req_meta = kv_connector_metadata.reqs_to_recv[request_id] for block_id, block in zip( @@ -349,10 +385,14 @@ def test_read_mode_basic_interface(moriio_read_mode): request_id ], ): - assert block_id == block.block_id + assert block_id == block.block_id, f"{block_id} != {block.block_id}" +@pytest.mark.skipif( + not aiter_available, reason="Requires aiter package for ROCm FlashAttention backend" +) def test_register_kv_caches(mock_parallel_groups): + """Test that MoRIIOConnector.register_kv_caches correctly registers kv caches.""" ROLE = "kv_consumer" IP = get_ip() vllm_config = create_vllm_config(role=ROLE) @@ -403,10 +443,8 @@ def test_register_kv_caches(mock_parallel_groups): # Execute register_kv_caches connector.register_kv_caches(kv_caches) - shared_tensor[0].data_ptr() - unique_tensor[1].data_ptr() - shared_tensor[0].data_ptr() + # Verify that the MemoryDesc stored in layer_name_to_local_kv_cache_metadata assert ( shared_tensor.data_ptr() == MemoryDesc.unpack( @@ -431,8 +469,9 @@ def test_register_kv_caches(mock_parallel_groups): ][0] ).data ) - expected_engine_key = f"{ROLE[3:]}:{IP}:{DEFAULT_PORT}:tp{TP_RANK}:dp{DP_RANK}" + # Verify engine keys + expected_engine_key = f"{ROLE[3:]}:{IP}:{DEFAULT_PORT}:tp{TP_RANK}:dp{DP_RANK}" assert ( MemoryDesc.unpack( connector.connector_worker.layer_name_to_local_kv_cache_metadata[ @@ -443,7 +482,12 @@ def test_register_kv_caches(mock_parallel_groups): ) -def test_moriio_handshake(mock_parallel_groups): +@pytest.mark.skipif( + not aiter_available, reason="Requires aiter package for ROCm FlashAttention backend" +) +def test_moriio_handshake_returns_metadata(mock_parallel_groups): + """MoRIIO handshake socket returns valid agent metadata over ZMQ.""" + ROLE = "kv_consumer" vllm_config = create_vllm_config(role=ROLE) from vllm.v1.attention.backends.rocm_aiter_fa import AiterFlashAttentionBackend @@ -478,11 +522,12 @@ def test_moriio_handshake(mock_parallel_groups): "handshake_port": handshake_port, } ) - connector = MoRIIOConnector(vllm_config, KVConnectorRole.WORKER) # Execute register_kv_caches connector.register_kv_caches(kv_caches) + + # Connect to handshake socket and request metadata path = make_zmq_path("tcp", "127.0.0.1", handshake_port) with zmq_ctx(zmq.DEALER, path) as sock: sock.send(MoRIIOConstants.GET_META_MSG) @@ -494,4 +539,6 @@ def test_moriio_handshake(mock_parallel_groups): metadata_bytes = received_frame[1] decoder = msgspec.msgpack.Decoder(MoRIIOAgentMetadata) metadata = decoder.decode(metadata_bytes) - assert isinstance(metadata, MoRIIOAgentMetadata) + assert isinstance(metadata, MoRIIOAgentMetadata), ( + "Decoded metadata is not MoRIIOAgentMetadata" + )