refine ut

Signed-off-by: inkcherry <mingzhi.liu@amd.com>
This commit is contained in:
inkcherry 2025-12-23 10:09:45 +00:00
parent 94a920fb0c
commit b36893b305

View File

@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import importlib.util
import os
from unittest.mock import MagicMock, patch
@ -28,6 +29,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector import
MoRIIOConnector,
MoRIIOConnectorWorker,
)
from vllm.platforms import current_platform
from vllm.utils.network_utils import (
get_ip,
make_zmq_path,
@ -35,10 +37,17 @@ from vllm.utils.network_utils import (
from .utils import create_request, create_scheduler
# Optional-dependency probes: find_spec locates a top-level package without
# importing it, so these are safe on machines where the package is absent.
aiter_available = importlib.util.find_spec("aiter") is not None
mori_available = importlib.util.find_spec("mori") is not None

# Module-wide skip: tests in this file require a ROCm platform and the mori
# package. The reason string names mori because that is what the condition
# checks; aiter is gated separately via `aiter_available`.
pytestmark = pytest.mark.skipif(
    not (current_platform.is_rocm() and mori_available),
    reason="MoRIIO is only available on ROCm with the mori package installed",
)
@pytest.fixture
def mock_parallel_groups():
"""Mock parallel group functions."""
"""Mock tensor/data parallel group functions for single-rank tests."""
mock_group = MagicMock()
mock_group.rank = 0
mock_group.local_rank = 0
@ -76,6 +85,7 @@ def _setup_kv_transfer_request(request, remote_host="127.0.0.1", fake_port=4789)
class FakeMorIIOWrapper:
# A fake MoRIIOWrapper for testing purposes
def __init__(self, *args, **kwargs):
pass
@ -144,6 +154,7 @@ class FakeMorIIOWrapper:
class FakeMorIIOConnectorWorker(MoRIIOConnectorWorker):
# Define a fake remote engine id for testing
REMOTE_ENGINE_ID = "remote_engine"
def __init__(
@ -162,7 +173,7 @@ def create_vllm_config(
enable_permute_local_kv: bool = False,
role="kv_consumer",
) -> VllmConfig:
"""Initialize VllmConfig For Testing."""
"""Initialize VllmConfig for testing."""
scheduler_config = SchedulerConfig(
max_num_seqs=max_num_seqs,
max_num_batched_tokens=max_num_batched_tokens,
@ -199,18 +210,18 @@ def create_vllm_config(
@pytest.fixture
def moriio_read_mode():
    """Force the connector into read mode via env for tests.

    Sets ``VLLM_MORIIO_CONNECTOR_READ_MODE`` to ``"True"`` before the test
    body runs and puts the variable back to its prior state afterwards.
    """
    env_key = "VLLM_MORIIO_CONNECTOR_READ_MODE"
    # Remember any pre-existing value so teardown does not clobber it.
    previous_value = os.environ.get(env_key)
    os.environ[env_key] = "True"
    yield
    # Cleanup after test: restore the earlier value, or remove the variable
    # if it was not set before this fixture ran.
    if previous_value is None:
        os.environ.pop(env_key, None)
    else:
        os.environ[env_key] = previous_value
def test_write_mode_basic_interface():
"""Unit test for basic MoriioConnector interface functionality."""
def test_write_mode_saves_local_block_ids():
"""Write mode records local block ids in MoRIIOConnectorMetadata.reqs_to_save."""
# Test Prefill write metadata
vllm_config = create_vllm_config(role="kv_consumer")
# Setup Scheduler and Request
vllm_config = create_vllm_config(role="kv_producer")
scheduler = create_scheduler(vllm_config)
# 2 Full Blocks and 1 Half Block.
@ -235,13 +246,21 @@ def test_write_mode_basic_interface():
# Remote Prefill, triggers MoRIIOConnectorMetadata.
scheduler_output = scheduler.schedule()
kv_connector_metadata = scheduler_output.kv_connector_metadata
assert kv_connector_metadata is not None
assert kv_connector_metadata is not None, "kv_connector_metadata is None"
assert isinstance(kv_connector_metadata, MoRIIOConnectorMetadata)
assert len(kv_connector_metadata.reqs_to_save) == 1
assert len(kv_connector_metadata.reqs_to_recv) == 0
assert len(kv_connector_metadata.reqs_to_send) == 0
assert request_id in kv_connector_metadata.reqs_to_save
assert len(kv_connector_metadata.reqs_to_save) == 1, (
"Unexpected number of reqs_to_save"
)
assert len(kv_connector_metadata.reqs_to_recv) == 0, (
"Unexpected number of reqs_to_recv"
)
assert len(kv_connector_metadata.reqs_to_send) == 0, (
"Unexpected number of reqs_to_send"
)
assert request_id in kv_connector_metadata.reqs_to_save, (
"Request ID not in reqs_to_save"
)
req_meta = kv_connector_metadata.reqs_to_save[request_id]
for block_id, block in zip(
@ -250,18 +269,17 @@ def test_write_mode_basic_interface():
request_id
],
):
assert block_id == block.block_id
assert block_id == block.block_id, f"{block_id} != {block.block_id}"
def test_write_mode_chunk_prefill():
"""Unit test for basic MoriioConnector interface functionality."""
def test_write_mode_with_chunked_prefill_saves_local_block_ids():
"""Write mode with chunked prefill still records correct local block ids."""
# Setup Scheduler and Request
MAX_NUM_BATCHED_TOKENS = 64
NUM_TOKENS = MAX_NUM_BATCHED_TOKENS * 2 + MAX_NUM_BATCHED_TOKENS // 2
# Test Prefill write metadata
vllm_config = create_vllm_config(
max_num_batched_tokens=MAX_NUM_BATCHED_TOKENS, role="kv_consumer"
max_num_batched_tokens=MAX_NUM_BATCHED_TOKENS, role="kv_producer"
)
BLOCK_SIZE = vllm_config.cache_config.block_size
@ -281,18 +299,22 @@ def test_write_mode_chunk_prefill():
scheduler.add_request(request)
# Fake Config
request = _setup_kv_transfer_request(request)
# Remote Prefill, triggers MoRIIOConnectorMetadata.
scheduler_output = scheduler.schedule()
kv_connector_metadata = scheduler_output.kv_connector_metadata
assert kv_connector_metadata is not None
assert isinstance(kv_connector_metadata, MoRIIOConnectorMetadata)
assert len(kv_connector_metadata.reqs_to_save) == 1
assert len(kv_connector_metadata.reqs_to_recv) == 0
assert len(kv_connector_metadata.reqs_to_send) == 0
assert request_id in kv_connector_metadata.reqs_to_save
# Remote Prefill with chunked prefill, triggers multiple schedules.
expected_counts = [(0, 0, 0), (0, 0, 0), (1, 0, 0)]
kv_connector_metadata = None
for _, (expected_save, expected_recv, expected_send) in enumerate(expected_counts):
scheduler_output = scheduler.schedule()
kv_connector_metadata = scheduler_output.kv_connector_metadata
assert len(kv_connector_metadata.reqs_to_save) == expected_save
assert len(kv_connector_metadata.reqs_to_recv) == expected_recv
assert len(kv_connector_metadata.reqs_to_send) == expected_send
assert kv_connector_metadata is not None, "kv_connector_metadata is None"
assert request_id in kv_connector_metadata.reqs_to_save, (
"Request ID not in reqs_to_save"
)
req_meta = kv_connector_metadata.reqs_to_save[request_id]
for block_id, block in zip(
@ -301,14 +323,16 @@ def test_write_mode_chunk_prefill():
request_id
],
):
assert block_id == block.block_id
assert block_id == block.block_id, f"{block_id} != {block.block_id}"
def test_read_mode_basic_interface(moriio_read_mode):
# test decode read
def test_read_mode_loads_remote_block_ids(moriio_read_mode):
"""Read mode loads remote block ids into local cache mapping."""
# Setup Scheduler and Request
vllm_config = create_vllm_config(role="kv_consumer")
scheduler = create_scheduler(vllm_config)
#
# 2 Full Blocks and 1 Half Block.
BLOCK_SIZE = vllm_config.cache_config.block_size
NUM_EXTERNAL_FULL_BLOCKS = 2
@ -327,20 +351,32 @@ def test_read_mode_basic_interface(moriio_read_mode):
block_list = scheduler.kv_cache_manager.coordinator.single_type_managers[
0
].req_to_blocks[request_id]
# Fake kv config
request = _setup_kv_transfer_request(request)
# Set remote block ids to be fetched.
request.kv_transfer_params["remote_block_ids"] = block_list
# Remote Prefill, triggers MoRIIOConnectorMetadata.
scheduler_output = scheduler.schedule()
kv_connector_metadata = scheduler_output.kv_connector_metadata
assert kv_connector_metadata is not None
assert isinstance(kv_connector_metadata, MoRIIOConnectorMetadata)
assert len(kv_connector_metadata.reqs_to_save) == 0
assert len(kv_connector_metadata.reqs_to_recv) == 1
assert len(kv_connector_metadata.reqs_to_send) == 0
assert request_id in kv_connector_metadata.reqs_to_recv
assert kv_connector_metadata is not None, "kv_connector_metadata is None"
assert isinstance(kv_connector_metadata, MoRIIOConnectorMetadata), (
"kv_connector_metadata is not MoRIIOConnectorMetadata"
)
assert len(kv_connector_metadata.reqs_to_save) == 0, (
"Unexpected number of reqs_to_save"
)
assert len(kv_connector_metadata.reqs_to_recv) == 1, (
"Unexpected number of reqs_to_recv"
)
assert len(kv_connector_metadata.reqs_to_send) == 0, (
"Unexpected number of reqs_to_send"
)
assert request_id in kv_connector_metadata.reqs_to_recv, (
"Request ID not in reqs_to_recv"
)
req_meta = kv_connector_metadata.reqs_to_recv[request_id]
for block_id, block in zip(
@ -349,10 +385,14 @@ def test_read_mode_basic_interface(moriio_read_mode):
request_id
],
):
assert block_id == block.block_id
assert block_id == block.block_id, f"{block_id} != {block.block_id}"
@pytest.mark.skipif(
not aiter_available, reason="Requires aiter package for ROCm FlashAttention backend"
)
def test_register_kv_caches(mock_parallel_groups):
"""Test that MoRIIOConnector.register_kv_caches correctly registers kv caches."""
ROLE = "kv_consumer"
IP = get_ip()
vllm_config = create_vllm_config(role=ROLE)
@ -403,10 +443,8 @@ def test_register_kv_caches(mock_parallel_groups):
# Execute register_kv_caches
connector.register_kv_caches(kv_caches)
shared_tensor[0].data_ptr()
unique_tensor[1].data_ptr()
shared_tensor[0].data_ptr()
# Verify that the MemoryDesc stored in layer_name_to_local_kv_cache_metadata
# carries the same data pointer as the registered tensor
assert (
shared_tensor.data_ptr()
== MemoryDesc.unpack(
@ -431,8 +469,9 @@ def test_register_kv_caches(mock_parallel_groups):
][0]
).data
)
expected_engine_key = f"{ROLE[3:]}:{IP}:{DEFAULT_PORT}:tp{TP_RANK}:dp{DP_RANK}"
# Verify engine keys
expected_engine_key = f"{ROLE[3:]}:{IP}:{DEFAULT_PORT}:tp{TP_RANK}:dp{DP_RANK}"
assert (
MemoryDesc.unpack(
connector.connector_worker.layer_name_to_local_kv_cache_metadata[
@ -443,7 +482,12 @@ def test_register_kv_caches(mock_parallel_groups):
)
def test_moriio_handshake(mock_parallel_groups):
@pytest.mark.skipif(
not aiter_available, reason="Requires aiter package for ROCm FlashAttention backend"
)
def test_moriio_handshake_returns_metadata(mock_parallel_groups):
"""MoRIIO handshake socket returns valid agent metadata over ZMQ."""
ROLE = "kv_consumer"
vllm_config = create_vllm_config(role=ROLE)
from vllm.v1.attention.backends.rocm_aiter_fa import AiterFlashAttentionBackend
@ -478,11 +522,12 @@ def test_moriio_handshake(mock_parallel_groups):
"handshake_port": handshake_port,
}
)
connector = MoRIIOConnector(vllm_config, KVConnectorRole.WORKER)
# Execute register_kv_caches
connector.register_kv_caches(kv_caches)
# Connect to handshake socket and request metadata
path = make_zmq_path("tcp", "127.0.0.1", handshake_port)
with zmq_ctx(zmq.DEALER, path) as sock:
sock.send(MoRIIOConstants.GET_META_MSG)
@ -494,4 +539,6 @@ def test_moriio_handshake(mock_parallel_groups):
metadata_bytes = received_frame[1]
decoder = msgspec.msgpack.Decoder(MoRIIOAgentMetadata)
metadata = decoder.decode(metadata_bytes)
assert isinstance(metadata, MoRIIOAgentMetadata)
assert isinstance(metadata, MoRIIOAgentMetadata), (
"Decoded metadata is not MoRIIOAgentMetadata"
)