From 94a920fb0c119e8a5baeb052a8f4b8045575b00f Mon Sep 17 00:00:00 2001 From: inkcherry Date: Tue, 23 Dec 2025 09:04:37 +0000 Subject: [PATCH] format ut Signed-off-by: inkcherry --- .../unit/test_moriio_connector.py | 450 +++++++----------- 1 file changed, 171 insertions(+), 279 deletions(-) diff --git a/tests/v1/kv_connector/unit/test_moriio_connector.py b/tests/v1/kv_connector/unit/test_moriio_connector.py index d5da774f78026..25b5663098272 100644 --- a/tests/v1/kv_connector/unit/test_moriio_connector.py +++ b/tests/v1/kv_connector/unit/test_moriio_connector.py @@ -1,53 +1,14 @@ -import pytest - -from vllm.platforms import current_platform -import contextlib -import inspect +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project import os -import tempfile -import textwrap -import time -import uuid -from collections import defaultdict -from unittest.mock import patch -from unittest.mock import MagicMock -import pytest -import ray -import torch +from unittest.mock import MagicMock, patch + import msgspec -from vllm import LLM -from vllm.config import KVTransferConfig -from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator -from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats -from vllm.distributed.kv_transfer.kv_connector.v1.multi_connector import ( - MultiKVConnectorStats, -) -from vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector import ( - KVConnectorRole, - ) -from vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_common import MoRIIOConstants -from vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_common import MoRIIOAgentMetadata +import pytest +import torch +import zmq -from vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector import ( - MoRIIOConnector, - MoRIIOConnectorScheduler, - MoRIIOConnectorWorker -) -from vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_common import MoRIIOConnectorMetadata - -from vllm.distributed.kv_transfer.kv_transfer_state import ( - ensure_kv_transfer_shutdown, - has_kv_transfer_group, -) -from vllm.forward_context import ForwardContext -from vllm.platforms.interface import Platform -from vllm.sampling_params import SamplingParams -from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend -from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput -from vllm.v1.request import RequestStatus -from vllm.v1.structured_output import StructuredOutputManager - -from .utils import create_request, create_scheduler +from tests.conftest import _find_free_port from vllm.config import ( CacheConfig, DeviceConfig, @@ -56,13 +17,68 @@ from vllm.config import ( SchedulerConfig, VllmConfig, ) +from vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_common import ( + MoRIIOAgentMetadata, + MoRIIOConnectorMetadata, + MoRIIOConstants, + zmq_ctx, +) +from vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector import ( + KVConnectorRole, + MoRIIOConnector, + MoRIIOConnectorWorker, +) +from vllm.utils.network_utils import ( + get_ip, + make_zmq_path, +) -class FakeMorIIOWrapper(): +from .utils import create_request, create_scheduler + +@pytest.fixture +def mock_parallel_groups(): + """Mock parallel group functions.""" + mock_group = MagicMock() + mock_group.rank = 0 + mock_group.local_rank = 0 + mock_group.world_size = 1 + + with ( + patch.multiple( + "vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_common", + get_tensor_model_parallel_rank=MagicMock(return_value=0), + get_tensor_model_parallel_world_size=MagicMock(return_value=0), + ), + patch.multiple( + "vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector", + get_tensor_model_parallel_world_size=MagicMock(return_value=0), + get_world_group=MagicMock(return_value=mock_group), + get_tp_group=MagicMock(return_value=mock_group), + ), + ): + yield mock_group + + +def _setup_kv_transfer_request(request, remote_host="127.0.0.1", fake_port=4789): + """Setup KV transfer parameters for a request.""" + request.kv_transfer_params.update( + { + "remote_notify_port": fake_port, + "remote_block_ids": None, + "remote_host": remote_host, + "remote_port": fake_port, + "remote_handshake_port": fake_port, + "remote_engine_id": "test_engine", + } + ) + return request + + +class FakeMorIIOWrapper: def __init__(self, *args, **kwargs): pass - def set_moriio_engine(self, moriio_engine): pass @@ -76,19 +92,7 @@ class FakeMorIIOWrapper(): pass def register_local_tensor(self, tensor: torch.Tensor): - assert self.moriio_engine is not None, "MoRIIO engine must be set first" - try: - self.local_memory_metadata = self.moriio_engine.register_torch_tensor( - tensor - ) - assert self.local_memory_metadata is not None, ( - "register_torch_tensor returned None" - ) - local_memory_metadata_packed = self.local_memory_metadata.pack() - except Exception as e: - raise MoRIIOError(f"Failed to register local memory: {e}") from e - self.local_memory_registered = True - return local_memory_metadata_packed + pass def get_unpack_memory_metadata(self, packed_memory_metadata): pass @@ -98,18 +102,17 @@ class FakeMorIIOWrapper(): def read_remote_data( self, transfer_size_byte, local_offset=0, remote_offset=0, session=None - ): + ): pass def write_remote_data( self, transfer_size_byte, local_offset=0, remote_offset=0, session=None - ): + ): pass - def write_remote_data_single( self, transfer_size_byte, local_offset=0, remote_offset=0, sess_idx=0 - ): + ): pass def waiting_for_transfer_complete(self): @@ -140,19 +143,15 @@ class FakeMorIIOWrapper(): pass -class FakeMoriIOConnectorWorker(MoRIIOConnectorWorker): +class FakeMorIIOConnectorWorker(MoRIIOConnectorWorker): REMOTE_ENGINE_ID = "remote_engine" def __init__( self, *args, hand_shake_latency: float = 1.8, kv_cache_layout="HND", **kwargs ): super().__init__(*args, **kwargs) - self._hand_shake_latency = hand_shake_latency - self.kv_cache_layout = kv_cache_layout -from vllm.v1.core.sched.scheduler import Scheduler, SchedulerOutput - def create_vllm_config( model: str = "facebook/opt-125m", max_num_seqs: int = 16, @@ -161,9 +160,7 @@ def create_vllm_config( max_model_len: int = 10000, enable_chunked_prefill: bool = True, enable_permute_local_kv: bool = False, - role="kv_consumer" - # role="kv_producer" - + role="kv_consumer", ) -> VllmConfig: """Initialize VllmConfig For Testing.""" scheduler_config = SchedulerConfig( @@ -198,46 +195,20 @@ def create_vllm_config( kv_transfer_config=kv_transfer_config, device_config=DeviceConfig("cpu"), ) -from vllm.v1.kv_cache_interface import ( - FullAttentionSpec, - KVCacheConfig, - KVCacheGroupSpec, -) -def create_scheduler( - vllm_config: VllmConfig, - num_blocks: int = 10000, -) -> Scheduler: - """Initialize Scheduler For Testing.""" - block_size = vllm_config.cache_config.block_size - kv_cache_config = KVCacheConfig( - num_blocks=num_blocks, # A large number of blocks to hold all requests - kv_cache_tensors=[], - kv_cache_groups=[ - KVCacheGroupSpec( - ["layer"], FullAttentionSpec(block_size, 1, 1, torch.float32, False) - ) - ], - ) - vllm_config.cache_config.num_gpu_blocks = num_blocks - return Scheduler( - vllm_config=vllm_config, - kv_cache_config=kv_cache_config, - log_stats=True, - structured_output_manager=StructuredOutputManager(vllm_config), - block_size=block_size, - ) + @pytest.fixture def moriio_read_mode(): """Set VLLM_MORIIO_CONNECTOR_READ_MODE environment variable for all tests.""" - os.environ['VLLM_MORIIO_CONNECTOR_READ_MODE'] = 'True' + os.environ["VLLM_MORIIO_CONNECTOR_READ_MODE"] = "True" yield # Cleanup after test - os.environ.pop('VLLM_MORIIO_CONNECTOR_READ_MODE', None) + os.environ.pop("VLLM_MORIIO_CONNECTOR_READ_MODE", None) + def test_write_mode_basic_interface(): """Unit test for basic MoriioConnector interface functionality.""" - + # Test Prefill wirte metadata vllm_config = create_vllm_config(role="kv_consumer") scheduler = create_scheduler(vllm_config) @@ -252,19 +223,15 @@ def test_write_mode_basic_interface(): block_size=BLOCK_SIZE, num_tokens=NUM_TOKENS, do_remote_decode=True, - do_remote_prefill=False + do_remote_prefill=False, ) request_id = request.request_id scheduler.add_request(request) - - # Fake - request.kv_transfer_params['remote_notify_port']=4789 - request.kv_transfer_params['remote_block_ids']=None - request.kv_transfer_params["remote_host"]="127.0.0.1" - request.kv_transfer_params["remote_port"]=4789 - request.kv_transfer_params["remote_handshake_port"]=4789 - request.kv_transfer_params["remote_engine_id"]="test_engine" + + # Fake Config + request = _setup_kv_transfer_request(request) + # Remote Prefill, triggers NixlConnectorMetadata. scheduler_output = scheduler.schedule() kv_connector_metadata = scheduler_output.kv_connector_metadata @@ -288,36 +255,34 @@ def test_write_mode_basic_interface(): def test_write_mode_chunk_prefill(): """Unit test for basic MoriioConnector interface functionality.""" - MAX_NUM_BATCHED_TOKENS=64 + MAX_NUM_BATCHED_TOKENS = 64 - NUM_TOKENS = MAX_NUM_BATCHED_TOKENS*2+MAX_NUM_BATCHED_TOKENS//2 + NUM_TOKENS = MAX_NUM_BATCHED_TOKENS * 2 + MAX_NUM_BATCHED_TOKENS // 2 # Test Prefill wirte metadata - vllm_config = create_vllm_config(max_num_batched_tokens=MAX_NUM_BATCHED_TOKENS, role="kv_consumer") + vllm_config = create_vllm_config( + max_num_batched_tokens=MAX_NUM_BATCHED_TOKENS, role="kv_consumer" + ) BLOCK_SIZE = vllm_config.cache_config.block_size scheduler = create_scheduler(vllm_config) # 2 Full Blocks and 1 Half Block. - + request = create_request( request_id=1, block_size=BLOCK_SIZE, num_tokens=NUM_TOKENS, do_remote_decode=True, - do_remote_prefill=False + do_remote_prefill=False, ) request_id = request.request_id scheduler.add_request(request) - - # Fake - request.kv_transfer_params['remote_notify_port']=4789 - request.kv_transfer_params['remote_block_ids']=None - request.kv_transfer_params["remote_host"]="127.0.0.1" - request.kv_transfer_params["remote_port"]=4789 - request.kv_transfer_params["remote_handshake_port"]=4789 - request.kv_transfer_params["remote_engine_id"]="test_engine" + + # Fake Config + + request = _setup_kv_transfer_request(request) # Remote Prefill, triggers NixlConnectorMetadata. scheduler_output = scheduler.schedule() kv_connector_metadata = scheduler_output.kv_connector_metadata @@ -338,8 +303,8 @@ def test_write_mode_chunk_prefill(): ): assert block_id == block.block_id + def test_read_mode_basic_interface(moriio_read_mode): - # test decode read vllm_config = create_vllm_config(role="kv_consumer") scheduler = create_scheduler(vllm_config) @@ -354,20 +319,20 @@ def test_read_mode_basic_interface(moriio_read_mode): block_size=BLOCK_SIZE, num_tokens=NUM_TOKENS, do_remote_decode=False, - do_remote_prefill=True + do_remote_prefill=True, ) request_id = request.request_id scheduler.add_request(request) - block_list = scheduler.kv_cache_manager.coordinator.single_type_managers[0].req_to_blocks[request_id] - # Fake - request.kv_transfer_params['remote_notify_port']=4789 - request.kv_transfer_params['remote_block_ids']=block_list - request.kv_transfer_params["remote_host"]="127.0.0.1" - request.kv_transfer_params["remote_port"]=4789 - request.kv_transfer_params["remote_handshake_port"]=4789 - request.kv_transfer_params["remote_engine_id"]="test_engine" + block_list = scheduler.kv_cache_manager.coordinator.single_type_managers[ + 0 + ].req_to_blocks[request_id] + # Fake kv config + request = _setup_kv_transfer_request(request) + request.kv_transfer_params["remote_block_ids"] = block_list + # Remote Prefill, triggers MorIIOConnectorMetadata. + scheduler_output = scheduler.schedule() kv_connector_metadata = scheduler_output.kv_connector_metadata assert kv_connector_metadata is not None @@ -375,7 +340,6 @@ def test_read_mode_basic_interface(moriio_read_mode): assert len(kv_connector_metadata.reqs_to_save) == 0 assert len(kv_connector_metadata.reqs_to_recv) == 1 assert len(kv_connector_metadata.reqs_to_send) == 0 - # assert len(kv_connector_metadata.reqs_to_save) == 1 assert request_id in kv_connector_metadata.reqs_to_recv req_meta = kv_connector_metadata.reqs_to_recv[request_id] @@ -388,16 +352,15 @@ def test_read_mode_basic_interface(moriio_read_mode): assert block_id == block.block_id -def test_register_kv_caches(): - from vllm.utils.network_utils import get_ip - - ROLE="kv_consumer" - IP=get_ip() - DEFAULT_PORT=6301 +def test_register_kv_caches(mock_parallel_groups): + ROLE = "kv_consumer" + IP = get_ip() vllm_config = create_vllm_config(role=ROLE) - TP_RANK=0 - DP_RANK=0 - from vllm.v1.attention.backends.rocm_aiter_fa import AiterFlashAttentionBackend + DEFAULT_PORT = 6301 + TP_RANK = 0 + DP_RANK = 0 + from vllm.v1.attention.backends.rocm_aiter_fa import AiterFlashAttentionBackend + backend_cls = AiterFlashAttentionBackend # Create test kv cache tensors using proper backend shape @@ -411,97 +374,80 @@ def test_register_kv_caches(): "layer1": unique_tensor, "layer2": shared_tensor, } - - # Store tensor info for validation - expected_tensor_size = shared_tensor[0].element_size() * shared_tensor[0].numel() - expected_base_addrs = [ - shared_tensor[0].data_ptr(), - unique_tensor[1].data_ptr(), - shared_tensor[0].data_ptr(), - ] - mock_group = MagicMock() - mock_group.rank = TP_RANK # 设置 rank - mock_group.local_rank = TP_RANK - mock_group.world_size = 1 # 设置 world_size with ( patch( - "vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_engine.MoRIIOWrapper" - ) as mock_moriio_wrapper, - patch( - "vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_common.get_tensor_model_parallel_rank", - return_value=0 - ), - patch( - "vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_common.get_tensor_model_parallel_world_size", - return_value=0 - ), - patch( - "vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector.get_tensor_model_parallel_world_size", - return_value=0 + "vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector.threading.Event" ), patch( - "vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector.get_world_group", - return_value=mock_group + "vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector.threading.Thread" ), - patch( - "vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector.get_tp_group", - return_value=mock_group - ), - patch( - "vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_engine.MoRIIOWrapper", - FakeMorIIOWrapper, - ) - ): # Create connector - vllm_config.kv_transfer_config.kv_connector_extra_config.update({ - "proxy_ip": "127.0.0.1", - "proxy_ping_port": 12345, - "http_port": 12346, - }) - + vllm_config.kv_transfer_config.kv_connector_extra_config.update( + { + "proxy_ip": "127.0.0.1", + "proxy_ping_port": 12345, + "http_port": 12346, + } + ) + connector = MoRIIOConnector(vllm_config, KVConnectorRole.WORKER) - connector.connector_worker = FakeMoriIOConnectorWorker( + connector.connector_worker = FakeMorIIOConnectorWorker( vllm_config, connector.engine_id, hand_shake_latency=0 ) - # Get the mock instance - mock_wrapper_instance = mock_moriio_wrapper.return_value - # connector.connector_worker.moriio_wrapper = mock_wrapper_instance - - # Reassure the shutdown() check that the thread is terminated - # mock_thread.return_value.is_alive.return_value = False from mori.io import ( - EngineDesc, - IOEngine, MemoryDesc, - PollCqMode, - RdmaBackendConfig, ) + # Execute register_kv_caches connector.register_kv_caches(kv_caches) shared_tensor[0].data_ptr() unique_tensor[1].data_ptr() shared_tensor[0].data_ptr() - assert shared_tensor.data_ptr()==MemoryDesc.unpack(connector.connector_worker.layer_name_to_local_kv_cache_metadata["layer0"][0]).data - assert unique_tensor.data_ptr()==MemoryDesc.unpack(connector.connector_worker.layer_name_to_local_kv_cache_metadata["layer1"][0]).data - assert shared_tensor.data_ptr()==MemoryDesc.unpack(connector.connector_worker.layer_name_to_local_kv_cache_metadata["layer2"][0]).data + assert ( + shared_tensor.data_ptr() + == MemoryDesc.unpack( + connector.connector_worker.layer_name_to_local_kv_cache_metadata[ + "layer0" + ][0] + ).data + ) + assert ( + unique_tensor.data_ptr() + == MemoryDesc.unpack( + connector.connector_worker.layer_name_to_local_kv_cache_metadata[ + "layer1" + ][0] + ).data + ) + assert ( + shared_tensor.data_ptr() + == MemoryDesc.unpack( + connector.connector_worker.layer_name_to_local_kv_cache_metadata[ + "layer2" + ][0] + ).data + ) expected_engine_key = f"{ROLE[3:]}:{IP}:{DEFAULT_PORT}:tp{TP_RANK}:dp{DP_RANK}" - assert MemoryDesc.unpack(connector.connector_worker.layer_name_to_local_kv_cache_metadata["layer0"][0]).engine_key ==expected_engine_key + assert ( + MemoryDesc.unpack( + connector.connector_worker.layer_name_to_local_kv_cache_metadata[ + "layer0" + ][0] + ).engine_key + == expected_engine_key + ) -def test_moriio_handshake(): - from vllm.utils.network_utils import get_ip - ROLE="kv_consumer" - IP=get_ip() - DEFAULT_PORT=6301 +def test_moriio_handshake(mock_parallel_groups): + ROLE = "kv_consumer" vllm_config = create_vllm_config(role=ROLE) - TP_RANK=0 - DP_RANK=0 - from vllm.v1.attention.backends.rocm_aiter_fa import AiterFlashAttentionBackend + from vllm.v1.attention.backends.rocm_aiter_fa import AiterFlashAttentionBackend + backend_cls = AiterFlashAttentionBackend # Create test kv cache tensors using proper backend shape @@ -515,91 +461,37 @@ def test_moriio_handshake(): "layer1": unique_tensor, "layer2": shared_tensor, } - - # Store tensor info for validation - expected_tensor_size = shared_tensor[0].element_size() * shared_tensor[0].numel() - expected_base_addrs = [ - shared_tensor[0].data_ptr(), - unique_tensor[1].data_ptr(), - shared_tensor[0].data_ptr(), - ] - mock_group = MagicMock() - mock_group.rank = TP_RANK - mock_group.local_rank = TP_RANK - mock_group.world_size = 1 with ( patch( - "vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_common.get_tensor_model_parallel_rank", - return_value=0 + "vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_engine.MoRIIOWrapper", + FakeMorIIOWrapper, ), - patch( - "vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_common.get_tensor_model_parallel_world_size", - return_value=0 - ), - patch( - "vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector.get_tensor_model_parallel_world_size", - return_value=0 - ), - patch( - "vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector.get_world_group", - return_value=mock_group - ), - patch( - "vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector.get_tp_group", - return_value=mock_group - ), - patch( - "vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_engine.MoRIIOWrapper", - FakeMorIIOWrapper, - ) - ): + handshake_port = _find_free_port() # Create connector - vllm_config.kv_transfer_config.kv_connector_extra_config.update({ - "proxy_ip": "127.0.0.1", - "proxy_ping_port": 12345, - "http_port": 12346, - "handshake_port":5670 - }) - + vllm_config.kv_transfer_config.kv_connector_extra_config.update( + { + "proxy_ip": "127.0.0.1", + "proxy_ping_port": 12345, + "http_port": 12346, + "handshake_port": handshake_port, + } + ) - connector = MoRIIOConnector(vllm_config, KVConnectorRole.WORKER) - - from vllm.utils.network_utils import ( - get_ip, - make_zmq_path, - make_zmq_socket, - ) - from vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_common import zmq_ctx - import zmq - - # Reassure the shutdown() check that the thread is terminated - # mock_thread.return_value.is_alive.return_value = False - from mori.io import ( - EngineDesc, - IOEngine, - MemoryDesc, - PollCqMode, - RdmaBackendConfig, - ) # Execute register_kv_caches connector.register_kv_caches(kv_caches) - # connector.layer_name_to_local_kv_cache_metadata["layer0"] expected_base_addrs = [ - - path = make_zmq_path("tcp", "127.0.0.1", 5670) + path = make_zmq_path("tcp", "127.0.0.1", handshake_port) with zmq_ctx(zmq.DEALER, path) as sock: sock.send(MoRIIOConstants.GET_META_MSG) received_frame = sock.recv_multipart() if len(received_frame) != 2 or received_frame[0] != b"": - raise HandshakeError(f"Unexpected frame! {received_frame = }") + raise ValueError(f"Unexpected frame! {received_frame = }") metadata_bytes = received_frame[1] decoder = msgspec.msgpack.Decoder(MoRIIOAgentMetadata) metadata = decoder.decode(metadata_bytes) assert isinstance(metadata, MoRIIOAgentMetadata) - - \ No newline at end of file