format ut

Signed-off-by: inkcherry <mingzhi.liu@amd.com>
This commit is contained in:
inkcherry 2025-12-23 09:04:37 +00:00
parent c4dcb3475e
commit 94a920fb0c

View File

@ -1,53 +1,14 @@
import pytest
from vllm.platforms import current_platform
import contextlib
import inspect
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
import tempfile
import textwrap
import time
import uuid
from collections import defaultdict
from unittest.mock import patch
from unittest.mock import MagicMock
import pytest
import ray
import torch
from unittest.mock import MagicMock, patch
import msgspec
from vllm import LLM
from vllm.config import KVTransferConfig
from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
from vllm.distributed.kv_transfer.kv_connector.v1.multi_connector import (
MultiKVConnectorStats,
)
from vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector import (
KVConnectorRole,
)
from vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_common import MoRIIOConstants
from vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_common import MoRIIOAgentMetadata
import pytest
import torch
import zmq
from vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector import (
MoRIIOConnector,
MoRIIOConnectorScheduler,
MoRIIOConnectorWorker
)
from vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_common import MoRIIOConnectorMetadata
from vllm.distributed.kv_transfer.kv_transfer_state import (
ensure_kv_transfer_shutdown,
has_kv_transfer_group,
)
from vllm.forward_context import ForwardContext
from vllm.platforms.interface import Platform
from vllm.sampling_params import SamplingParams
from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
from vllm.v1.request import RequestStatus
from vllm.v1.structured_output import StructuredOutputManager
from .utils import create_request, create_scheduler
from tests.conftest import _find_free_port
from vllm.config import (
CacheConfig,
DeviceConfig,
@ -56,13 +17,68 @@ from vllm.config import (
SchedulerConfig,
VllmConfig,
)
from vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_common import (
MoRIIOAgentMetadata,
MoRIIOConnectorMetadata,
MoRIIOConstants,
zmq_ctx,
)
from vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector import (
KVConnectorRole,
MoRIIOConnector,
MoRIIOConnectorWorker,
)
from vllm.utils.network_utils import (
get_ip,
make_zmq_path,
)
class FakeMorIIOWrapper():
from .utils import create_request, create_scheduler
@pytest.fixture
def mock_parallel_groups():
"""Mock parallel group functions."""
mock_group = MagicMock()
mock_group.rank = 0
mock_group.local_rank = 0
mock_group.world_size = 1
with (
patch.multiple(
"vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_common",
get_tensor_model_parallel_rank=MagicMock(return_value=0),
get_tensor_model_parallel_world_size=MagicMock(return_value=0),
),
patch.multiple(
"vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector",
get_tensor_model_parallel_world_size=MagicMock(return_value=0),
get_world_group=MagicMock(return_value=mock_group),
get_tp_group=MagicMock(return_value=mock_group),
),
):
yield mock_group
def _setup_kv_transfer_request(request, remote_host="127.0.0.1", fake_port=4789):
"""Setup KV transfer parameters for a request."""
request.kv_transfer_params.update(
{
"remote_notify_port": fake_port,
"remote_block_ids": None,
"remote_host": remote_host,
"remote_port": fake_port,
"remote_handshake_port": fake_port,
"remote_engine_id": "test_engine",
}
)
return request
class FakeMorIIOWrapper:
def __init__(self, *args, **kwargs):
pass
def set_moriio_engine(self, moriio_engine):
pass
@ -76,19 +92,7 @@ class FakeMorIIOWrapper():
pass
def register_local_tensor(self, tensor: torch.Tensor):
assert self.moriio_engine is not None, "MoRIIO engine must be set first"
try:
self.local_memory_metadata = self.moriio_engine.register_torch_tensor(
tensor
)
assert self.local_memory_metadata is not None, (
"register_torch_tensor returned None"
)
local_memory_metadata_packed = self.local_memory_metadata.pack()
except Exception as e:
raise MoRIIOError(f"Failed to register local memory: {e}") from e
self.local_memory_registered = True
return local_memory_metadata_packed
pass
def get_unpack_memory_metadata(self, packed_memory_metadata):
pass
@ -98,18 +102,17 @@ class FakeMorIIOWrapper():
def read_remote_data(
self, transfer_size_byte, local_offset=0, remote_offset=0, session=None
):
):
pass
def write_remote_data(
self, transfer_size_byte, local_offset=0, remote_offset=0, session=None
):
):
pass
def write_remote_data_single(
self, transfer_size_byte, local_offset=0, remote_offset=0, sess_idx=0
):
):
pass
def waiting_for_transfer_complete(self):
@ -140,19 +143,15 @@ class FakeMorIIOWrapper():
pass
class FakeMoriIOConnectorWorker(MoRIIOConnectorWorker):
class FakeMorIIOConnectorWorker(MoRIIOConnectorWorker):
REMOTE_ENGINE_ID = "remote_engine"
def __init__(
self, *args, hand_shake_latency: float = 1.8, kv_cache_layout="HND", **kwargs
):
super().__init__(*args, **kwargs)
self._hand_shake_latency = hand_shake_latency
self.kv_cache_layout = kv_cache_layout
from vllm.v1.core.sched.scheduler import Scheduler, SchedulerOutput
def create_vllm_config(
model: str = "facebook/opt-125m",
max_num_seqs: int = 16,
@ -161,9 +160,7 @@ def create_vllm_config(
max_model_len: int = 10000,
enable_chunked_prefill: bool = True,
enable_permute_local_kv: bool = False,
role="kv_consumer"
# role="kv_producer"
role="kv_consumer",
) -> VllmConfig:
"""Initialize VllmConfig For Testing."""
scheduler_config = SchedulerConfig(
@ -198,46 +195,20 @@ def create_vllm_config(
kv_transfer_config=kv_transfer_config,
device_config=DeviceConfig("cpu"),
)
from vllm.v1.kv_cache_interface import (
FullAttentionSpec,
KVCacheConfig,
KVCacheGroupSpec,
)
def create_scheduler(
vllm_config: VllmConfig,
num_blocks: int = 10000,
) -> Scheduler:
"""Initialize Scheduler For Testing."""
block_size = vllm_config.cache_config.block_size
kv_cache_config = KVCacheConfig(
num_blocks=num_blocks, # A large number of blocks to hold all requests
kv_cache_tensors=[],
kv_cache_groups=[
KVCacheGroupSpec(
["layer"], FullAttentionSpec(block_size, 1, 1, torch.float32, False)
)
],
)
vllm_config.cache_config.num_gpu_blocks = num_blocks
return Scheduler(
vllm_config=vllm_config,
kv_cache_config=kv_cache_config,
log_stats=True,
structured_output_manager=StructuredOutputManager(vllm_config),
block_size=block_size,
)
@pytest.fixture
def moriio_read_mode():
"""Set VLLM_MORIIO_CONNECTOR_READ_MODE environment variable for all tests."""
os.environ['VLLM_MORIIO_CONNECTOR_READ_MODE'] = 'True'
os.environ["VLLM_MORIIO_CONNECTOR_READ_MODE"] = "True"
yield
# Cleanup after test
os.environ.pop('VLLM_MORIIO_CONNECTOR_READ_MODE', None)
os.environ.pop("VLLM_MORIIO_CONNECTOR_READ_MODE", None)
def test_write_mode_basic_interface():
"""Unit test for basic MoriioConnector interface functionality."""
# Test Prefill wirte metadata
vllm_config = create_vllm_config(role="kv_consumer")
scheduler = create_scheduler(vllm_config)
@ -252,19 +223,15 @@ def test_write_mode_basic_interface():
block_size=BLOCK_SIZE,
num_tokens=NUM_TOKENS,
do_remote_decode=True,
do_remote_prefill=False
do_remote_prefill=False,
)
request_id = request.request_id
scheduler.add_request(request)
# Fake
request.kv_transfer_params['remote_notify_port']=4789
request.kv_transfer_params['remote_block_ids']=None
request.kv_transfer_params["remote_host"]="127.0.0.1"
request.kv_transfer_params["remote_port"]=4789
request.kv_transfer_params["remote_handshake_port"]=4789
request.kv_transfer_params["remote_engine_id"]="test_engine"
# Fake Config
request = _setup_kv_transfer_request(request)
# Remote Prefill, triggers NixlConnectorMetadata.
scheduler_output = scheduler.schedule()
kv_connector_metadata = scheduler_output.kv_connector_metadata
@ -288,36 +255,34 @@ def test_write_mode_basic_interface():
def test_write_mode_chunk_prefill():
"""Unit test for basic MoriioConnector interface functionality."""
MAX_NUM_BATCHED_TOKENS=64
MAX_NUM_BATCHED_TOKENS = 64
NUM_TOKENS = MAX_NUM_BATCHED_TOKENS*2+MAX_NUM_BATCHED_TOKENS//2
NUM_TOKENS = MAX_NUM_BATCHED_TOKENS * 2 + MAX_NUM_BATCHED_TOKENS // 2
# Test Prefill wirte metadata
vllm_config = create_vllm_config(max_num_batched_tokens=MAX_NUM_BATCHED_TOKENS, role="kv_consumer")
vllm_config = create_vllm_config(
max_num_batched_tokens=MAX_NUM_BATCHED_TOKENS, role="kv_consumer"
)
BLOCK_SIZE = vllm_config.cache_config.block_size
scheduler = create_scheduler(vllm_config)
# 2 Full Blocks and 1 Half Block.
request = create_request(
request_id=1,
block_size=BLOCK_SIZE,
num_tokens=NUM_TOKENS,
do_remote_decode=True,
do_remote_prefill=False
do_remote_prefill=False,
)
request_id = request.request_id
scheduler.add_request(request)
# Fake
request.kv_transfer_params['remote_notify_port']=4789
request.kv_transfer_params['remote_block_ids']=None
request.kv_transfer_params["remote_host"]="127.0.0.1"
request.kv_transfer_params["remote_port"]=4789
request.kv_transfer_params["remote_handshake_port"]=4789
request.kv_transfer_params["remote_engine_id"]="test_engine"
# Fake Config
request = _setup_kv_transfer_request(request)
# Remote Prefill, triggers NixlConnectorMetadata.
scheduler_output = scheduler.schedule()
kv_connector_metadata = scheduler_output.kv_connector_metadata
@ -338,8 +303,8 @@ def test_write_mode_chunk_prefill():
):
assert block_id == block.block_id
def test_read_mode_basic_interface(moriio_read_mode):
# test decode read
vllm_config = create_vllm_config(role="kv_consumer")
scheduler = create_scheduler(vllm_config)
@ -354,20 +319,20 @@ def test_read_mode_basic_interface(moriio_read_mode):
block_size=BLOCK_SIZE,
num_tokens=NUM_TOKENS,
do_remote_decode=False,
do_remote_prefill=True
do_remote_prefill=True,
)
request_id = request.request_id
scheduler.add_request(request)
block_list = scheduler.kv_cache_manager.coordinator.single_type_managers[0].req_to_blocks[request_id]
# Fake
request.kv_transfer_params['remote_notify_port']=4789
request.kv_transfer_params['remote_block_ids']=block_list
request.kv_transfer_params["remote_host"]="127.0.0.1"
request.kv_transfer_params["remote_port"]=4789
request.kv_transfer_params["remote_handshake_port"]=4789
request.kv_transfer_params["remote_engine_id"]="test_engine"
block_list = scheduler.kv_cache_manager.coordinator.single_type_managers[
0
].req_to_blocks[request_id]
# Fake kv config
request = _setup_kv_transfer_request(request)
request.kv_transfer_params["remote_block_ids"] = block_list
# Remote Prefill, triggers MorIIOConnectorMetadata.
scheduler_output = scheduler.schedule()
kv_connector_metadata = scheduler_output.kv_connector_metadata
assert kv_connector_metadata is not None
@ -375,7 +340,6 @@ def test_read_mode_basic_interface(moriio_read_mode):
assert len(kv_connector_metadata.reqs_to_save) == 0
assert len(kv_connector_metadata.reqs_to_recv) == 1
assert len(kv_connector_metadata.reqs_to_send) == 0
# assert len(kv_connector_metadata.reqs_to_save) == 1
assert request_id in kv_connector_metadata.reqs_to_recv
req_meta = kv_connector_metadata.reqs_to_recv[request_id]
@ -388,16 +352,15 @@ def test_read_mode_basic_interface(moriio_read_mode):
assert block_id == block.block_id
def test_register_kv_caches():
from vllm.utils.network_utils import get_ip
ROLE="kv_consumer"
IP=get_ip()
DEFAULT_PORT=6301
def test_register_kv_caches(mock_parallel_groups):
ROLE = "kv_consumer"
IP = get_ip()
vllm_config = create_vllm_config(role=ROLE)
TP_RANK=0
DP_RANK=0
from vllm.v1.attention.backends.rocm_aiter_fa import AiterFlashAttentionBackend
DEFAULT_PORT = 6301
TP_RANK = 0
DP_RANK = 0
from vllm.v1.attention.backends.rocm_aiter_fa import AiterFlashAttentionBackend
backend_cls = AiterFlashAttentionBackend
# Create test kv cache tensors using proper backend shape
@ -411,97 +374,80 @@ def test_register_kv_caches():
"layer1": unique_tensor,
"layer2": shared_tensor,
}
# Store tensor info for validation
expected_tensor_size = shared_tensor[0].element_size() * shared_tensor[0].numel()
expected_base_addrs = [
shared_tensor[0].data_ptr(),
unique_tensor[1].data_ptr(),
shared_tensor[0].data_ptr(),
]
mock_group = MagicMock()
mock_group.rank = TP_RANK # 设置 rank
mock_group.local_rank = TP_RANK
mock_group.world_size = 1 # 设置 world_size
with (
patch(
"vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_engine.MoRIIOWrapper"
) as mock_moriio_wrapper,
patch(
"vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_common.get_tensor_model_parallel_rank",
return_value=0
),
patch(
"vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_common.get_tensor_model_parallel_world_size",
return_value=0
),
patch(
"vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector.get_tensor_model_parallel_world_size",
return_value=0
"vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector.threading.Event"
),
patch(
"vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector.get_world_group",
return_value=mock_group
"vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector.threading.Thread"
),
patch(
"vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector.get_tp_group",
return_value=mock_group
),
patch(
"vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_engine.MoRIIOWrapper",
FakeMorIIOWrapper,
)
):
# Create connector
vllm_config.kv_transfer_config.kv_connector_extra_config.update({
"proxy_ip": "127.0.0.1",
"proxy_ping_port": 12345,
"http_port": 12346,
})
vllm_config.kv_transfer_config.kv_connector_extra_config.update(
{
"proxy_ip": "127.0.0.1",
"proxy_ping_port": 12345,
"http_port": 12346,
}
)
connector = MoRIIOConnector(vllm_config, KVConnectorRole.WORKER)
connector.connector_worker = FakeMoriIOConnectorWorker(
connector.connector_worker = FakeMorIIOConnectorWorker(
vllm_config, connector.engine_id, hand_shake_latency=0
)
# Get the mock instance
mock_wrapper_instance = mock_moriio_wrapper.return_value
# connector.connector_worker.moriio_wrapper = mock_wrapper_instance
# Reassure the shutdown() check that the thread is terminated
# mock_thread.return_value.is_alive.return_value = False
from mori.io import (
EngineDesc,
IOEngine,
MemoryDesc,
PollCqMode,
RdmaBackendConfig,
)
# Execute register_kv_caches
connector.register_kv_caches(kv_caches)
shared_tensor[0].data_ptr()
unique_tensor[1].data_ptr()
shared_tensor[0].data_ptr()
assert shared_tensor.data_ptr()==MemoryDesc.unpack(connector.connector_worker.layer_name_to_local_kv_cache_metadata["layer0"][0]).data
assert unique_tensor.data_ptr()==MemoryDesc.unpack(connector.connector_worker.layer_name_to_local_kv_cache_metadata["layer1"][0]).data
assert shared_tensor.data_ptr()==MemoryDesc.unpack(connector.connector_worker.layer_name_to_local_kv_cache_metadata["layer2"][0]).data
assert (
shared_tensor.data_ptr()
== MemoryDesc.unpack(
connector.connector_worker.layer_name_to_local_kv_cache_metadata[
"layer0"
][0]
).data
)
assert (
unique_tensor.data_ptr()
== MemoryDesc.unpack(
connector.connector_worker.layer_name_to_local_kv_cache_metadata[
"layer1"
][0]
).data
)
assert (
shared_tensor.data_ptr()
== MemoryDesc.unpack(
connector.connector_worker.layer_name_to_local_kv_cache_metadata[
"layer2"
][0]
).data
)
expected_engine_key = f"{ROLE[3:]}:{IP}:{DEFAULT_PORT}:tp{TP_RANK}:dp{DP_RANK}"
assert MemoryDesc.unpack(connector.connector_worker.layer_name_to_local_kv_cache_metadata["layer0"][0]).engine_key ==expected_engine_key
assert (
MemoryDesc.unpack(
connector.connector_worker.layer_name_to_local_kv_cache_metadata[
"layer0"
][0]
).engine_key
== expected_engine_key
)
def test_moriio_handshake():
from vllm.utils.network_utils import get_ip
ROLE="kv_consumer"
IP=get_ip()
DEFAULT_PORT=6301
def test_moriio_handshake(mock_parallel_groups):
ROLE = "kv_consumer"
vllm_config = create_vllm_config(role=ROLE)
TP_RANK=0
DP_RANK=0
from vllm.v1.attention.backends.rocm_aiter_fa import AiterFlashAttentionBackend
from vllm.v1.attention.backends.rocm_aiter_fa import AiterFlashAttentionBackend
backend_cls = AiterFlashAttentionBackend
# Create test kv cache tensors using proper backend shape
@ -515,91 +461,37 @@ def test_moriio_handshake():
"layer1": unique_tensor,
"layer2": shared_tensor,
}
# Store tensor info for validation
expected_tensor_size = shared_tensor[0].element_size() * shared_tensor[0].numel()
expected_base_addrs = [
shared_tensor[0].data_ptr(),
unique_tensor[1].data_ptr(),
shared_tensor[0].data_ptr(),
]
mock_group = MagicMock()
mock_group.rank = TP_RANK
mock_group.local_rank = TP_RANK
mock_group.world_size = 1
with (
patch(
"vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_common.get_tensor_model_parallel_rank",
return_value=0
"vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_engine.MoRIIOWrapper",
FakeMorIIOWrapper,
),
patch(
"vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_common.get_tensor_model_parallel_world_size",
return_value=0
),
patch(
"vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector.get_tensor_model_parallel_world_size",
return_value=0
),
patch(
"vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector.get_world_group",
return_value=mock_group
),
patch(
"vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector.get_tp_group",
return_value=mock_group
),
patch(
"vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_engine.MoRIIOWrapper",
FakeMorIIOWrapper,
)
):
handshake_port = _find_free_port()
# Create connector
vllm_config.kv_transfer_config.kv_connector_extra_config.update({
"proxy_ip": "127.0.0.1",
"proxy_ping_port": 12345,
"http_port": 12346,
"handshake_port":5670
})
vllm_config.kv_transfer_config.kv_connector_extra_config.update(
{
"proxy_ip": "127.0.0.1",
"proxy_ping_port": 12345,
"http_port": 12346,
"handshake_port": handshake_port,
}
)
connector = MoRIIOConnector(vllm_config, KVConnectorRole.WORKER)
from vllm.utils.network_utils import (
get_ip,
make_zmq_path,
make_zmq_socket,
)
from vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_common import zmq_ctx
import zmq
# Reassure the shutdown() check that the thread is terminated
# mock_thread.return_value.is_alive.return_value = False
from mori.io import (
EngineDesc,
IOEngine,
MemoryDesc,
PollCqMode,
RdmaBackendConfig,
)
# Execute register_kv_caches
connector.register_kv_caches(kv_caches)
# connector.layer_name_to_local_kv_cache_metadata["layer0"] expected_base_addrs = [
path = make_zmq_path("tcp", "127.0.0.1", 5670)
path = make_zmq_path("tcp", "127.0.0.1", handshake_port)
with zmq_ctx(zmq.DEALER, path) as sock:
sock.send(MoRIIOConstants.GET_META_MSG)
received_frame = sock.recv_multipart()
if len(received_frame) != 2 or received_frame[0] != b"":
raise HandshakeError(f"Unexpected frame! {received_frame = }")
raise ValueError(f"Unexpected frame! {received_frame = }")
metadata_bytes = received_frame[1]
decoder = msgspec.msgpack.Decoder(MoRIIOAgentMetadata)
metadata = decoder.decode(metadata_bytes)
assert isinstance(metadata, MoRIIOAgentMetadata)