format ut

Signed-off-by: inkcherry <mingzhi.liu@amd.com>
This commit is contained in:
inkcherry 2025-12-23 09:04:37 +00:00
parent c4dcb3475e
commit 94a920fb0c

View File

@ -1,53 +1,14 @@
import pytest # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.platforms import current_platform
import contextlib
import inspect
import os import os
import tempfile from unittest.mock import MagicMock, patch
import textwrap
import time
import uuid
from collections import defaultdict
from unittest.mock import patch
from unittest.mock import MagicMock
import pytest
import ray
import torch
import msgspec import msgspec
from vllm import LLM import pytest
from vllm.config import KVTransferConfig import torch
from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator import zmq
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
from vllm.distributed.kv_transfer.kv_connector.v1.multi_connector import (
MultiKVConnectorStats,
)
from vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector import (
KVConnectorRole,
)
from vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_common import MoRIIOConstants
from vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_common import MoRIIOAgentMetadata
from vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector import ( from tests.conftest import _find_free_port
MoRIIOConnector,
MoRIIOConnectorScheduler,
MoRIIOConnectorWorker
)
from vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_common import MoRIIOConnectorMetadata
from vllm.distributed.kv_transfer.kv_transfer_state import (
ensure_kv_transfer_shutdown,
has_kv_transfer_group,
)
from vllm.forward_context import ForwardContext
from vllm.platforms.interface import Platform
from vllm.sampling_params import SamplingParams
from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
from vllm.v1.request import RequestStatus
from vllm.v1.structured_output import StructuredOutputManager
from .utils import create_request, create_scheduler
from vllm.config import ( from vllm.config import (
CacheConfig, CacheConfig,
DeviceConfig, DeviceConfig,
@ -56,13 +17,68 @@ from vllm.config import (
SchedulerConfig, SchedulerConfig,
VllmConfig, VllmConfig,
) )
from vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_common import (
MoRIIOAgentMetadata,
MoRIIOConnectorMetadata,
MoRIIOConstants,
zmq_ctx,
)
from vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector import (
KVConnectorRole,
MoRIIOConnector,
MoRIIOConnectorWorker,
)
from vllm.utils.network_utils import (
get_ip,
make_zmq_path,
)
class FakeMorIIOWrapper(): from .utils import create_request, create_scheduler
@pytest.fixture
def mock_parallel_groups():
    """Patch the parallel-state lookups used by the MoRIIO connector modules.

    While the fixture is active, the tensor-parallel rank/world-size helpers
    in ``moriio_common`` and ``moriio_connector`` are replaced with mocks
    returning ``0``, and ``get_world_group`` / ``get_tp_group`` in
    ``moriio_connector`` return one shared fake group.

    Yields:
        The fake group mock (``rank=0``, ``local_rank=0``, ``world_size=1``).
    """
    # One fake group object stands in for both the world group and TP group.
    fake_group = MagicMock()
    fake_group.rank = 0
    fake_group.local_rank = 0
    fake_group.world_size = 1

    common_patcher = patch.multiple(
        "vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_common",
        get_tensor_model_parallel_rank=MagicMock(return_value=0),
        get_tensor_model_parallel_world_size=MagicMock(return_value=0),
    )
    connector_patcher = patch.multiple(
        "vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector",
        get_tensor_model_parallel_world_size=MagicMock(return_value=0),
        get_world_group=MagicMock(return_value=fake_group),
        get_tp_group=MagicMock(return_value=fake_group),
    )

    # Patches are entered in the same order as before and undone on teardown.
    with common_patcher, connector_patcher:
        yield fake_group
def _setup_kv_transfer_request(request, remote_host="127.0.0.1", fake_port=4789):
"""Setup KV transfer parameters for a request."""
request.kv_transfer_params.update(
{
"remote_notify_port": fake_port,
"remote_block_ids": None,
"remote_host": remote_host,
"remote_port": fake_port,
"remote_handshake_port": fake_port,
"remote_engine_id": "test_engine",
}
)
return request
class FakeMorIIOWrapper:
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
pass pass
def set_moriio_engine(self, moriio_engine): def set_moriio_engine(self, moriio_engine):
pass pass
@ -76,19 +92,7 @@ class FakeMorIIOWrapper():
pass pass
def register_local_tensor(self, tensor: torch.Tensor): def register_local_tensor(self, tensor: torch.Tensor):
assert self.moriio_engine is not None, "MoRIIO engine must be set first" pass
try:
self.local_memory_metadata = self.moriio_engine.register_torch_tensor(
tensor
)
assert self.local_memory_metadata is not None, (
"register_torch_tensor returned None"
)
local_memory_metadata_packed = self.local_memory_metadata.pack()
except Exception as e:
raise MoRIIOError(f"Failed to register local memory: {e}") from e
self.local_memory_registered = True
return local_memory_metadata_packed
def get_unpack_memory_metadata(self, packed_memory_metadata): def get_unpack_memory_metadata(self, packed_memory_metadata):
pass pass
@ -98,18 +102,17 @@ class FakeMorIIOWrapper():
def read_remote_data( def read_remote_data(
self, transfer_size_byte, local_offset=0, remote_offset=0, session=None self, transfer_size_byte, local_offset=0, remote_offset=0, session=None
): ):
pass pass
def write_remote_data( def write_remote_data(
self, transfer_size_byte, local_offset=0, remote_offset=0, session=None self, transfer_size_byte, local_offset=0, remote_offset=0, session=None
): ):
pass pass
def write_remote_data_single( def write_remote_data_single(
self, transfer_size_byte, local_offset=0, remote_offset=0, sess_idx=0 self, transfer_size_byte, local_offset=0, remote_offset=0, sess_idx=0
): ):
pass pass
def waiting_for_transfer_complete(self): def waiting_for_transfer_complete(self):
@ -140,19 +143,15 @@ class FakeMorIIOWrapper():
pass pass
class FakeMoriIOConnectorWorker(MoRIIOConnectorWorker): class FakeMorIIOConnectorWorker(MoRIIOConnectorWorker):
REMOTE_ENGINE_ID = "remote_engine" REMOTE_ENGINE_ID = "remote_engine"
def __init__( def __init__(
self, *args, hand_shake_latency: float = 1.8, kv_cache_layout="HND", **kwargs self, *args, hand_shake_latency: float = 1.8, kv_cache_layout="HND", **kwargs
): ):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self._hand_shake_latency = hand_shake_latency
self.kv_cache_layout = kv_cache_layout
from vllm.v1.core.sched.scheduler import Scheduler, SchedulerOutput
def create_vllm_config( def create_vllm_config(
model: str = "facebook/opt-125m", model: str = "facebook/opt-125m",
max_num_seqs: int = 16, max_num_seqs: int = 16,
@ -161,9 +160,7 @@ def create_vllm_config(
max_model_len: int = 10000, max_model_len: int = 10000,
enable_chunked_prefill: bool = True, enable_chunked_prefill: bool = True,
enable_permute_local_kv: bool = False, enable_permute_local_kv: bool = False,
role="kv_consumer" role="kv_consumer",
# role="kv_producer"
) -> VllmConfig: ) -> VllmConfig:
"""Initialize VllmConfig For Testing.""" """Initialize VllmConfig For Testing."""
scheduler_config = SchedulerConfig( scheduler_config = SchedulerConfig(
@ -198,46 +195,20 @@ def create_vllm_config(
kv_transfer_config=kv_transfer_config, kv_transfer_config=kv_transfer_config,
device_config=DeviceConfig("cpu"), device_config=DeviceConfig("cpu"),
) )
from vllm.v1.kv_cache_interface import (
FullAttentionSpec,
KVCacheConfig,
KVCacheGroupSpec,
)
def create_scheduler(
vllm_config: VllmConfig,
num_blocks: int = 10000,
) -> Scheduler:
"""Initialize Scheduler For Testing."""
block_size = vllm_config.cache_config.block_size
kv_cache_config = KVCacheConfig(
num_blocks=num_blocks, # A large number of blocks to hold all requests
kv_cache_tensors=[],
kv_cache_groups=[
KVCacheGroupSpec(
["layer"], FullAttentionSpec(block_size, 1, 1, torch.float32, False)
)
],
)
vllm_config.cache_config.num_gpu_blocks = num_blocks
return Scheduler(
vllm_config=vllm_config,
kv_cache_config=kv_cache_config,
log_stats=True,
structured_output_manager=StructuredOutputManager(vllm_config),
block_size=block_size,
)
@pytest.fixture @pytest.fixture
def moriio_read_mode(): def moriio_read_mode():
"""Set VLLM_MORIIO_CONNECTOR_READ_MODE environment variable for all tests.""" """Set VLLM_MORIIO_CONNECTOR_READ_MODE environment variable for all tests."""
os.environ['VLLM_MORIIO_CONNECTOR_READ_MODE'] = 'True' os.environ["VLLM_MORIIO_CONNECTOR_READ_MODE"] = "True"
yield yield
# Cleanup after test # Cleanup after test
os.environ.pop('VLLM_MORIIO_CONNECTOR_READ_MODE', None) os.environ.pop("VLLM_MORIIO_CONNECTOR_READ_MODE", None)
def test_write_mode_basic_interface(): def test_write_mode_basic_interface():
"""Unit test for basic MoriioConnector interface functionality.""" """Unit test for basic MoriioConnector interface functionality."""
# Test Prefill wirte metadata # Test Prefill wirte metadata
vllm_config = create_vllm_config(role="kv_consumer") vllm_config = create_vllm_config(role="kv_consumer")
scheduler = create_scheduler(vllm_config) scheduler = create_scheduler(vllm_config)
@ -252,19 +223,15 @@ def test_write_mode_basic_interface():
block_size=BLOCK_SIZE, block_size=BLOCK_SIZE,
num_tokens=NUM_TOKENS, num_tokens=NUM_TOKENS,
do_remote_decode=True, do_remote_decode=True,
do_remote_prefill=False do_remote_prefill=False,
) )
request_id = request.request_id request_id = request.request_id
scheduler.add_request(request) scheduler.add_request(request)
# Fake # Fake Config
request.kv_transfer_params['remote_notify_port']=4789 request = _setup_kv_transfer_request(request)
request.kv_transfer_params['remote_block_ids']=None
request.kv_transfer_params["remote_host"]="127.0.0.1"
request.kv_transfer_params["remote_port"]=4789
request.kv_transfer_params["remote_handshake_port"]=4789
request.kv_transfer_params["remote_engine_id"]="test_engine"
# Remote Prefill, triggers NixlConnectorMetadata. # Remote Prefill, triggers NixlConnectorMetadata.
scheduler_output = scheduler.schedule() scheduler_output = scheduler.schedule()
kv_connector_metadata = scheduler_output.kv_connector_metadata kv_connector_metadata = scheduler_output.kv_connector_metadata
@ -288,36 +255,34 @@ def test_write_mode_basic_interface():
def test_write_mode_chunk_prefill(): def test_write_mode_chunk_prefill():
"""Unit test for basic MoriioConnector interface functionality.""" """Unit test for basic MoriioConnector interface functionality."""
MAX_NUM_BATCHED_TOKENS=64 MAX_NUM_BATCHED_TOKENS = 64
NUM_TOKENS = MAX_NUM_BATCHED_TOKENS*2+MAX_NUM_BATCHED_TOKENS//2 NUM_TOKENS = MAX_NUM_BATCHED_TOKENS * 2 + MAX_NUM_BATCHED_TOKENS // 2
# Test Prefill wirte metadata # Test Prefill wirte metadata
vllm_config = create_vllm_config(max_num_batched_tokens=MAX_NUM_BATCHED_TOKENS, role="kv_consumer") vllm_config = create_vllm_config(
max_num_batched_tokens=MAX_NUM_BATCHED_TOKENS, role="kv_consumer"
)
BLOCK_SIZE = vllm_config.cache_config.block_size BLOCK_SIZE = vllm_config.cache_config.block_size
scheduler = create_scheduler(vllm_config) scheduler = create_scheduler(vllm_config)
# 2 Full Blocks and 1 Half Block. # 2 Full Blocks and 1 Half Block.
request = create_request( request = create_request(
request_id=1, request_id=1,
block_size=BLOCK_SIZE, block_size=BLOCK_SIZE,
num_tokens=NUM_TOKENS, num_tokens=NUM_TOKENS,
do_remote_decode=True, do_remote_decode=True,
do_remote_prefill=False do_remote_prefill=False,
) )
request_id = request.request_id request_id = request.request_id
scheduler.add_request(request) scheduler.add_request(request)
# Fake # Fake Config
request.kv_transfer_params['remote_notify_port']=4789
request.kv_transfer_params['remote_block_ids']=None request = _setup_kv_transfer_request(request)
request.kv_transfer_params["remote_host"]="127.0.0.1"
request.kv_transfer_params["remote_port"]=4789
request.kv_transfer_params["remote_handshake_port"]=4789
request.kv_transfer_params["remote_engine_id"]="test_engine"
# Remote Prefill, triggers NixlConnectorMetadata. # Remote Prefill, triggers NixlConnectorMetadata.
scheduler_output = scheduler.schedule() scheduler_output = scheduler.schedule()
kv_connector_metadata = scheduler_output.kv_connector_metadata kv_connector_metadata = scheduler_output.kv_connector_metadata
@ -338,8 +303,8 @@ def test_write_mode_chunk_prefill():
): ):
assert block_id == block.block_id assert block_id == block.block_id
def test_read_mode_basic_interface(moriio_read_mode): def test_read_mode_basic_interface(moriio_read_mode):
# test decode read # test decode read
vllm_config = create_vllm_config(role="kv_consumer") vllm_config = create_vllm_config(role="kv_consumer")
scheduler = create_scheduler(vllm_config) scheduler = create_scheduler(vllm_config)
@ -354,20 +319,20 @@ def test_read_mode_basic_interface(moriio_read_mode):
block_size=BLOCK_SIZE, block_size=BLOCK_SIZE,
num_tokens=NUM_TOKENS, num_tokens=NUM_TOKENS,
do_remote_decode=False, do_remote_decode=False,
do_remote_prefill=True do_remote_prefill=True,
) )
request_id = request.request_id request_id = request.request_id
scheduler.add_request(request) scheduler.add_request(request)
block_list = scheduler.kv_cache_manager.coordinator.single_type_managers[0].req_to_blocks[request_id] block_list = scheduler.kv_cache_manager.coordinator.single_type_managers[
# Fake 0
request.kv_transfer_params['remote_notify_port']=4789 ].req_to_blocks[request_id]
request.kv_transfer_params['remote_block_ids']=block_list # Fake kv config
request.kv_transfer_params["remote_host"]="127.0.0.1" request = _setup_kv_transfer_request(request)
request.kv_transfer_params["remote_port"]=4789 request.kv_transfer_params["remote_block_ids"] = block_list
request.kv_transfer_params["remote_handshake_port"]=4789
request.kv_transfer_params["remote_engine_id"]="test_engine"
# Remote Prefill, triggers MorIIOConnectorMetadata. # Remote Prefill, triggers MorIIOConnectorMetadata.
scheduler_output = scheduler.schedule() scheduler_output = scheduler.schedule()
kv_connector_metadata = scheduler_output.kv_connector_metadata kv_connector_metadata = scheduler_output.kv_connector_metadata
assert kv_connector_metadata is not None assert kv_connector_metadata is not None
@ -375,7 +340,6 @@ def test_read_mode_basic_interface(moriio_read_mode):
assert len(kv_connector_metadata.reqs_to_save) == 0 assert len(kv_connector_metadata.reqs_to_save) == 0
assert len(kv_connector_metadata.reqs_to_recv) == 1 assert len(kv_connector_metadata.reqs_to_recv) == 1
assert len(kv_connector_metadata.reqs_to_send) == 0 assert len(kv_connector_metadata.reqs_to_send) == 0
# assert len(kv_connector_metadata.reqs_to_save) == 1
assert request_id in kv_connector_metadata.reqs_to_recv assert request_id in kv_connector_metadata.reqs_to_recv
req_meta = kv_connector_metadata.reqs_to_recv[request_id] req_meta = kv_connector_metadata.reqs_to_recv[request_id]
@ -388,16 +352,15 @@ def test_read_mode_basic_interface(moriio_read_mode):
assert block_id == block.block_id assert block_id == block.block_id
def test_register_kv_caches(): def test_register_kv_caches(mock_parallel_groups):
from vllm.utils.network_utils import get_ip ROLE = "kv_consumer"
IP = get_ip()
ROLE="kv_consumer"
IP=get_ip()
DEFAULT_PORT=6301
vllm_config = create_vllm_config(role=ROLE) vllm_config = create_vllm_config(role=ROLE)
TP_RANK=0 DEFAULT_PORT = 6301
DP_RANK=0 TP_RANK = 0
from vllm.v1.attention.backends.rocm_aiter_fa import AiterFlashAttentionBackend DP_RANK = 0
from vllm.v1.attention.backends.rocm_aiter_fa import AiterFlashAttentionBackend
backend_cls = AiterFlashAttentionBackend backend_cls = AiterFlashAttentionBackend
# Create test kv cache tensors using proper backend shape # Create test kv cache tensors using proper backend shape
@ -411,97 +374,80 @@ def test_register_kv_caches():
"layer1": unique_tensor, "layer1": unique_tensor,
"layer2": shared_tensor, "layer2": shared_tensor,
} }
# Store tensor info for validation
expected_tensor_size = shared_tensor[0].element_size() * shared_tensor[0].numel()
expected_base_addrs = [
shared_tensor[0].data_ptr(),
unique_tensor[1].data_ptr(),
shared_tensor[0].data_ptr(),
]
mock_group = MagicMock()
mock_group.rank = TP_RANK # 设置 rank
mock_group.local_rank = TP_RANK
mock_group.world_size = 1 # 设置 world_size
with ( with (
patch( patch(
"vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_engine.MoRIIOWrapper" "vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector.threading.Event"
) as mock_moriio_wrapper,
patch(
"vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_common.get_tensor_model_parallel_rank",
return_value=0
),
patch(
"vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_common.get_tensor_model_parallel_world_size",
return_value=0
),
patch(
"vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector.get_tensor_model_parallel_world_size",
return_value=0
), ),
patch( patch(
"vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector.get_world_group", "vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector.threading.Thread"
return_value=mock_group
), ),
patch(
"vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector.get_tp_group",
return_value=mock_group
),
patch(
"vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_engine.MoRIIOWrapper",
FakeMorIIOWrapper,
)
): ):
# Create connector # Create connector
vllm_config.kv_transfer_config.kv_connector_extra_config.update({ vllm_config.kv_transfer_config.kv_connector_extra_config.update(
"proxy_ip": "127.0.0.1", {
"proxy_ping_port": 12345, "proxy_ip": "127.0.0.1",
"http_port": 12346, "proxy_ping_port": 12345,
}) "http_port": 12346,
}
)
connector = MoRIIOConnector(vllm_config, KVConnectorRole.WORKER) connector = MoRIIOConnector(vllm_config, KVConnectorRole.WORKER)
connector.connector_worker = FakeMoriIOConnectorWorker( connector.connector_worker = FakeMorIIOConnectorWorker(
vllm_config, connector.engine_id, hand_shake_latency=0 vllm_config, connector.engine_id, hand_shake_latency=0
) )
# Get the mock instance
mock_wrapper_instance = mock_moriio_wrapper.return_value
# connector.connector_worker.moriio_wrapper = mock_wrapper_instance
# Reassure the shutdown() check that the thread is terminated
# mock_thread.return_value.is_alive.return_value = False
from mori.io import ( from mori.io import (
EngineDesc,
IOEngine,
MemoryDesc, MemoryDesc,
PollCqMode,
RdmaBackendConfig,
) )
# Execute register_kv_caches # Execute register_kv_caches
connector.register_kv_caches(kv_caches) connector.register_kv_caches(kv_caches)
shared_tensor[0].data_ptr() shared_tensor[0].data_ptr()
unique_tensor[1].data_ptr() unique_tensor[1].data_ptr()
shared_tensor[0].data_ptr() shared_tensor[0].data_ptr()
assert shared_tensor.data_ptr()==MemoryDesc.unpack(connector.connector_worker.layer_name_to_local_kv_cache_metadata["layer0"][0]).data assert (
assert unique_tensor.data_ptr()==MemoryDesc.unpack(connector.connector_worker.layer_name_to_local_kv_cache_metadata["layer1"][0]).data shared_tensor.data_ptr()
assert shared_tensor.data_ptr()==MemoryDesc.unpack(connector.connector_worker.layer_name_to_local_kv_cache_metadata["layer2"][0]).data == MemoryDesc.unpack(
connector.connector_worker.layer_name_to_local_kv_cache_metadata[
"layer0"
][0]
).data
)
assert (
unique_tensor.data_ptr()
== MemoryDesc.unpack(
connector.connector_worker.layer_name_to_local_kv_cache_metadata[
"layer1"
][0]
).data
)
assert (
shared_tensor.data_ptr()
== MemoryDesc.unpack(
connector.connector_worker.layer_name_to_local_kv_cache_metadata[
"layer2"
][0]
).data
)
expected_engine_key = f"{ROLE[3:]}:{IP}:{DEFAULT_PORT}:tp{TP_RANK}:dp{DP_RANK}" expected_engine_key = f"{ROLE[3:]}:{IP}:{DEFAULT_PORT}:tp{TP_RANK}:dp{DP_RANK}"
assert MemoryDesc.unpack(connector.connector_worker.layer_name_to_local_kv_cache_metadata["layer0"][0]).engine_key ==expected_engine_key assert (
MemoryDesc.unpack(
connector.connector_worker.layer_name_to_local_kv_cache_metadata[
"layer0"
][0]
).engine_key
== expected_engine_key
)
def test_moriio_handshake():
from vllm.utils.network_utils import get_ip
ROLE="kv_consumer" def test_moriio_handshake(mock_parallel_groups):
IP=get_ip() ROLE = "kv_consumer"
DEFAULT_PORT=6301
vllm_config = create_vllm_config(role=ROLE) vllm_config = create_vllm_config(role=ROLE)
TP_RANK=0 from vllm.v1.attention.backends.rocm_aiter_fa import AiterFlashAttentionBackend
DP_RANK=0
from vllm.v1.attention.backends.rocm_aiter_fa import AiterFlashAttentionBackend
backend_cls = AiterFlashAttentionBackend backend_cls = AiterFlashAttentionBackend
# Create test kv cache tensors using proper backend shape # Create test kv cache tensors using proper backend shape
@ -515,91 +461,37 @@ def test_moriio_handshake():
"layer1": unique_tensor, "layer1": unique_tensor,
"layer2": shared_tensor, "layer2": shared_tensor,
} }
# Store tensor info for validation
expected_tensor_size = shared_tensor[0].element_size() * shared_tensor[0].numel()
expected_base_addrs = [
shared_tensor[0].data_ptr(),
unique_tensor[1].data_ptr(),
shared_tensor[0].data_ptr(),
]
mock_group = MagicMock()
mock_group.rank = TP_RANK
mock_group.local_rank = TP_RANK
mock_group.world_size = 1
with ( with (
patch( patch(
"vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_common.get_tensor_model_parallel_rank", "vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_engine.MoRIIOWrapper",
return_value=0 FakeMorIIOWrapper,
), ),
patch(
"vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_common.get_tensor_model_parallel_world_size",
return_value=0
),
patch(
"vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector.get_tensor_model_parallel_world_size",
return_value=0
),
patch(
"vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector.get_world_group",
return_value=mock_group
),
patch(
"vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector.get_tp_group",
return_value=mock_group
),
patch(
"vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_engine.MoRIIOWrapper",
FakeMorIIOWrapper,
)
): ):
handshake_port = _find_free_port()
# Create connector # Create connector
vllm_config.kv_transfer_config.kv_connector_extra_config.update({ vllm_config.kv_transfer_config.kv_connector_extra_config.update(
"proxy_ip": "127.0.0.1", {
"proxy_ping_port": 12345, "proxy_ip": "127.0.0.1",
"http_port": 12346, "proxy_ping_port": 12345,
"handshake_port":5670 "http_port": 12346,
}) "handshake_port": handshake_port,
}
)
connector = MoRIIOConnector(vllm_config, KVConnectorRole.WORKER) connector = MoRIIOConnector(vllm_config, KVConnectorRole.WORKER)
from vllm.utils.network_utils import (
get_ip,
make_zmq_path,
make_zmq_socket,
)
from vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_common import zmq_ctx
import zmq
# Reassure the shutdown() check that the thread is terminated
# mock_thread.return_value.is_alive.return_value = False
from mori.io import (
EngineDesc,
IOEngine,
MemoryDesc,
PollCqMode,
RdmaBackendConfig,
)
# Execute register_kv_caches # Execute register_kv_caches
connector.register_kv_caches(kv_caches) connector.register_kv_caches(kv_caches)
# connector.layer_name_to_local_kv_cache_metadata["layer0"] expected_base_addrs = [ path = make_zmq_path("tcp", "127.0.0.1", handshake_port)
path = make_zmq_path("tcp", "127.0.0.1", 5670)
with zmq_ctx(zmq.DEALER, path) as sock: with zmq_ctx(zmq.DEALER, path) as sock:
sock.send(MoRIIOConstants.GET_META_MSG) sock.send(MoRIIOConstants.GET_META_MSG)
received_frame = sock.recv_multipart() received_frame = sock.recv_multipart()
if len(received_frame) != 2 or received_frame[0] != b"": if len(received_frame) != 2 or received_frame[0] != b"":
raise HandshakeError(f"Unexpected frame! {received_frame = }") raise ValueError(f"Unexpected frame! {received_frame = }")
metadata_bytes = received_frame[1] metadata_bytes = received_frame[1]
decoder = msgspec.msgpack.Decoder(MoRIIOAgentMetadata) decoder = msgspec.msgpack.Decoder(MoRIIOAgentMetadata)
metadata = decoder.decode(metadata_bytes) metadata = decoder.decode(metadata_bytes)
assert isinstance(metadata, MoRIIOAgentMetadata) assert isinstance(metadata, MoRIIOAgentMetadata)