mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-03 05:27:03 +08:00
refine ut
Signed-off-by: inkcherry <mingzhi.liu@amd.com>
This commit is contained in:
parent
94a920fb0c
commit
b36893b305
@ -1,5 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import importlib.util
|
||||
import os
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
@ -28,6 +29,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector import
|
||||
MoRIIOConnector,
|
||||
MoRIIOConnectorWorker,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.network_utils import (
|
||||
get_ip,
|
||||
make_zmq_path,
|
||||
@ -35,10 +37,17 @@ from vllm.utils.network_utils import (
|
||||
|
||||
from .utils import create_request, create_scheduler
|
||||
|
||||
aiter_available = importlib.util.find_spec("aiter") is not None
|
||||
mori_available = importlib.util.find_spec("mori") is not None
|
||||
pytestmark = pytest.mark.skipif(
|
||||
not (current_platform.is_rocm() and mori_available),
|
||||
reason="MoRIIOs are only available on ROCm with aiter package installed",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_parallel_groups():
|
||||
"""Mock parallel group functions."""
|
||||
"""Mock tensor/data parallel group functions for single-rank tests."""
|
||||
mock_group = MagicMock()
|
||||
mock_group.rank = 0
|
||||
mock_group.local_rank = 0
|
||||
@ -76,6 +85,7 @@ def _setup_kv_transfer_request(request, remote_host="127.0.0.1", fake_port=4789)
|
||||
|
||||
|
||||
class FakeMorIIOWrapper:
|
||||
# A fake MoRIIOWrapper for testing purposes
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
@ -144,6 +154,7 @@ class FakeMorIIOWrapper:
|
||||
|
||||
|
||||
class FakeMorIIOConnectorWorker(MoRIIOConnectorWorker):
|
||||
# Define a fake remote engine id for testing
|
||||
REMOTE_ENGINE_ID = "remote_engine"
|
||||
|
||||
def __init__(
|
||||
@ -162,7 +173,7 @@ def create_vllm_config(
|
||||
enable_permute_local_kv: bool = False,
|
||||
role="kv_consumer",
|
||||
) -> VllmConfig:
|
||||
"""Initialize VllmConfig For Testing."""
|
||||
"""Initialize VllmConfig for testing."""
|
||||
scheduler_config = SchedulerConfig(
|
||||
max_num_seqs=max_num_seqs,
|
||||
max_num_batched_tokens=max_num_batched_tokens,
|
||||
@ -199,18 +210,18 @@ def create_vllm_config(
|
||||
|
||||
@pytest.fixture
|
||||
def moriio_read_mode():
|
||||
"""Set VLLM_MORIIO_CONNECTOR_READ_MODE environment variable for all tests."""
|
||||
"""Force the connector into read mode via env for tests."""
|
||||
os.environ["VLLM_MORIIO_CONNECTOR_READ_MODE"] = "True"
|
||||
yield
|
||||
# Cleanup after test
|
||||
os.environ.pop("VLLM_MORIIO_CONNECTOR_READ_MODE", None)
|
||||
|
||||
|
||||
def test_write_mode_basic_interface():
|
||||
"""Unit test for basic MoriioConnector interface functionality."""
|
||||
def test_write_mode_saves_local_block_ids():
|
||||
"""Write mode records local block ids in MoRIIOConnectorMetadata.reqs_to_save."""
|
||||
|
||||
# Test Prefill write metadata
|
||||
vllm_config = create_vllm_config(role="kv_consumer")
|
||||
# Setup Scheduler and Request
|
||||
vllm_config = create_vllm_config(role="kv_producer")
|
||||
scheduler = create_scheduler(vllm_config)
|
||||
|
||||
# 2 Full Blocks and 1 Half Block.
|
||||
@ -235,13 +246,21 @@ def test_write_mode_basic_interface():
|
||||
# Remote Prefill, triggers MoRIIOConnectorMetadata.
|
||||
scheduler_output = scheduler.schedule()
|
||||
kv_connector_metadata = scheduler_output.kv_connector_metadata
|
||||
assert kv_connector_metadata is not None
|
||||
assert kv_connector_metadata is not None, "kv_connector_metadata is None"
|
||||
assert isinstance(kv_connector_metadata, MoRIIOConnectorMetadata)
|
||||
|
||||
assert len(kv_connector_metadata.reqs_to_save) == 1
|
||||
assert len(kv_connector_metadata.reqs_to_recv) == 0
|
||||
assert len(kv_connector_metadata.reqs_to_send) == 0
|
||||
assert request_id in kv_connector_metadata.reqs_to_save
|
||||
assert len(kv_connector_metadata.reqs_to_save) == 1, (
|
||||
"Unexpected number of reqs_to_save"
|
||||
)
|
||||
assert len(kv_connector_metadata.reqs_to_recv) == 0, (
|
||||
"Unexpected number of reqs_to_recv"
|
||||
)
|
||||
assert len(kv_connector_metadata.reqs_to_send) == 0, (
|
||||
"Unexpected number of reqs_to_send"
|
||||
)
|
||||
assert request_id in kv_connector_metadata.reqs_to_save, (
|
||||
"Request ID not in reqs_to_save"
|
||||
)
|
||||
req_meta = kv_connector_metadata.reqs_to_save[request_id]
|
||||
|
||||
for block_id, block in zip(
|
||||
@ -250,18 +269,17 @@ def test_write_mode_basic_interface():
|
||||
request_id
|
||||
],
|
||||
):
|
||||
assert block_id == block.block_id
|
||||
assert block_id == block.block_id, f"{block_id} != {block.block_id}"
|
||||
|
||||
|
||||
def test_write_mode_chunk_prefill():
|
||||
"""Unit test for basic MoriioConnector interface functionality."""
|
||||
def test_write_mode_with_chunked_prefill_saves_local_block_ids():
|
||||
"""Write mode with chunked prefill still records correct local block ids."""
|
||||
# Setup Scheduler and Request
|
||||
MAX_NUM_BATCHED_TOKENS = 64
|
||||
|
||||
NUM_TOKENS = MAX_NUM_BATCHED_TOKENS * 2 + MAX_NUM_BATCHED_TOKENS // 2
|
||||
|
||||
# Test Prefill write metadata
|
||||
vllm_config = create_vllm_config(
|
||||
max_num_batched_tokens=MAX_NUM_BATCHED_TOKENS, role="kv_consumer"
|
||||
max_num_batched_tokens=MAX_NUM_BATCHED_TOKENS, role="kv_producer"
|
||||
)
|
||||
BLOCK_SIZE = vllm_config.cache_config.block_size
|
||||
|
||||
@ -281,18 +299,22 @@ def test_write_mode_chunk_prefill():
|
||||
scheduler.add_request(request)
|
||||
|
||||
# Fake Config
|
||||
|
||||
request = _setup_kv_transfer_request(request)
|
||||
# Remote Prefill, triggers MoRIIOConnectorMetadata.
|
||||
scheduler_output = scheduler.schedule()
|
||||
kv_connector_metadata = scheduler_output.kv_connector_metadata
|
||||
assert kv_connector_metadata is not None
|
||||
assert isinstance(kv_connector_metadata, MoRIIOConnectorMetadata)
|
||||
|
||||
assert len(kv_connector_metadata.reqs_to_save) == 1
|
||||
assert len(kv_connector_metadata.reqs_to_recv) == 0
|
||||
assert len(kv_connector_metadata.reqs_to_send) == 0
|
||||
assert request_id in kv_connector_metadata.reqs_to_save
|
||||
# Remote Prefill with chunked prefill, triggers multiple schedules.
|
||||
expected_counts = [(0, 0, 0), (0, 0, 0), (1, 0, 0)]
|
||||
kv_connector_metadata = None
|
||||
for _, (expected_save, expected_recv, expected_send) in enumerate(expected_counts):
|
||||
scheduler_output = scheduler.schedule()
|
||||
kv_connector_metadata = scheduler_output.kv_connector_metadata
|
||||
|
||||
assert len(kv_connector_metadata.reqs_to_save) == expected_save
|
||||
assert len(kv_connector_metadata.reqs_to_recv) == expected_recv
|
||||
assert len(kv_connector_metadata.reqs_to_send) == expected_send
|
||||
assert kv_connector_metadata is not None, "kv_connector_metadata is None"
|
||||
assert request_id in kv_connector_metadata.reqs_to_save, (
|
||||
"Request ID not in reqs_to_save"
|
||||
)
|
||||
req_meta = kv_connector_metadata.reqs_to_save[request_id]
|
||||
|
||||
for block_id, block in zip(
|
||||
@ -301,14 +323,16 @@ def test_write_mode_chunk_prefill():
|
||||
request_id
|
||||
],
|
||||
):
|
||||
assert block_id == block.block_id
|
||||
assert block_id == block.block_id, f"{block_id} != {block.block_id}"
|
||||
|
||||
|
||||
def test_read_mode_basic_interface(moriio_read_mode):
|
||||
# test decode read
|
||||
def test_read_mode_loads_remote_block_ids(moriio_read_mode):
|
||||
"""Read mode loads remote block ids into local cache mapping."""
|
||||
|
||||
# Setup Scheduler and Request
|
||||
vllm_config = create_vllm_config(role="kv_consumer")
|
||||
scheduler = create_scheduler(vllm_config)
|
||||
#
|
||||
|
||||
# 2 Full Blocks and 1 Half Block.
|
||||
BLOCK_SIZE = vllm_config.cache_config.block_size
|
||||
NUM_EXTERNAL_FULL_BLOCKS = 2
|
||||
@ -327,20 +351,32 @@ def test_read_mode_basic_interface(moriio_read_mode):
|
||||
block_list = scheduler.kv_cache_manager.coordinator.single_type_managers[
|
||||
0
|
||||
].req_to_blocks[request_id]
|
||||
# Fake kv config
|
||||
|
||||
request = _setup_kv_transfer_request(request)
|
||||
|
||||
# Set remote block ids to be fetched.
|
||||
request.kv_transfer_params["remote_block_ids"] = block_list
|
||||
|
||||
# Remote Prefill, triggers MoRIIOConnectorMetadata.
|
||||
|
||||
scheduler_output = scheduler.schedule()
|
||||
kv_connector_metadata = scheduler_output.kv_connector_metadata
|
||||
assert kv_connector_metadata is not None
|
||||
assert isinstance(kv_connector_metadata, MoRIIOConnectorMetadata)
|
||||
assert len(kv_connector_metadata.reqs_to_save) == 0
|
||||
assert len(kv_connector_metadata.reqs_to_recv) == 1
|
||||
assert len(kv_connector_metadata.reqs_to_send) == 0
|
||||
assert request_id in kv_connector_metadata.reqs_to_recv
|
||||
assert kv_connector_metadata is not None, "kv_connector_metadata is None"
|
||||
assert isinstance(kv_connector_metadata, MoRIIOConnectorMetadata), (
|
||||
"kv_connector_metadata is not MoRIIOConnectorMetadata"
|
||||
)
|
||||
assert len(kv_connector_metadata.reqs_to_save) == 0, (
|
||||
"Unexpected number of reqs_to_save"
|
||||
)
|
||||
assert len(kv_connector_metadata.reqs_to_recv) == 1, (
|
||||
"Unexpected number of reqs_to_recv"
|
||||
)
|
||||
assert len(kv_connector_metadata.reqs_to_send) == 0, (
|
||||
"Unexpected number of reqs_to_send"
|
||||
)
|
||||
assert request_id in kv_connector_metadata.reqs_to_recv, (
|
||||
"Request ID not in reqs_to_recv"
|
||||
)
|
||||
req_meta = kv_connector_metadata.reqs_to_recv[request_id]
|
||||
|
||||
for block_id, block in zip(
|
||||
@ -349,10 +385,14 @@ def test_read_mode_basic_interface(moriio_read_mode):
|
||||
request_id
|
||||
],
|
||||
):
|
||||
assert block_id == block.block_id
|
||||
assert block_id == block.block_id, f"{block_id} != {block.block_id}"
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not aiter_available, reason="Requires aiter package for ROCm FlashAttention backend"
|
||||
)
|
||||
def test_register_kv_caches(mock_parallel_groups):
|
||||
"""Test that MoRIIOConnector.register_kv_caches correctly registers kv caches."""
|
||||
ROLE = "kv_consumer"
|
||||
IP = get_ip()
|
||||
vllm_config = create_vllm_config(role=ROLE)
|
||||
@ -403,10 +443,8 @@ def test_register_kv_caches(mock_parallel_groups):
|
||||
|
||||
# Execute register_kv_caches
|
||||
connector.register_kv_caches(kv_caches)
|
||||
shared_tensor[0].data_ptr()
|
||||
unique_tensor[1].data_ptr()
|
||||
shared_tensor[0].data_ptr()
|
||||
|
||||
# Verify that the MemoryDesc stored in layer_name_to_local_kv_cache_metadata
|
||||
assert (
|
||||
shared_tensor.data_ptr()
|
||||
== MemoryDesc.unpack(
|
||||
@ -431,8 +469,9 @@ def test_register_kv_caches(mock_parallel_groups):
|
||||
][0]
|
||||
).data
|
||||
)
|
||||
expected_engine_key = f"{ROLE[3:]}:{IP}:{DEFAULT_PORT}:tp{TP_RANK}:dp{DP_RANK}"
|
||||
|
||||
# Verify engine keys
|
||||
expected_engine_key = f"{ROLE[3:]}:{IP}:{DEFAULT_PORT}:tp{TP_RANK}:dp{DP_RANK}"
|
||||
assert (
|
||||
MemoryDesc.unpack(
|
||||
connector.connector_worker.layer_name_to_local_kv_cache_metadata[
|
||||
@ -443,7 +482,12 @@ def test_register_kv_caches(mock_parallel_groups):
|
||||
)
|
||||
|
||||
|
||||
def test_moriio_handshake(mock_parallel_groups):
|
||||
@pytest.mark.skipif(
|
||||
not aiter_available, reason="Requires aiter package for ROCm FlashAttention backend"
|
||||
)
|
||||
def test_moriio_handshake_returns_metadata(mock_parallel_groups):
|
||||
"""MoRIIO handshake socket returns valid agent metadata over ZMQ."""
|
||||
|
||||
ROLE = "kv_consumer"
|
||||
vllm_config = create_vllm_config(role=ROLE)
|
||||
from vllm.v1.attention.backends.rocm_aiter_fa import AiterFlashAttentionBackend
|
||||
@ -478,11 +522,12 @@ def test_moriio_handshake(mock_parallel_groups):
|
||||
"handshake_port": handshake_port,
|
||||
}
|
||||
)
|
||||
|
||||
connector = MoRIIOConnector(vllm_config, KVConnectorRole.WORKER)
|
||||
|
||||
# Execute register_kv_caches
|
||||
connector.register_kv_caches(kv_caches)
|
||||
|
||||
# Connect to handshake socket and request metadata
|
||||
path = make_zmq_path("tcp", "127.0.0.1", handshake_port)
|
||||
with zmq_ctx(zmq.DEALER, path) as sock:
|
||||
sock.send(MoRIIOConstants.GET_META_MSG)
|
||||
@ -494,4 +539,6 @@ def test_moriio_handshake(mock_parallel_groups):
|
||||
metadata_bytes = received_frame[1]
|
||||
decoder = msgspec.msgpack.Decoder(MoRIIOAgentMetadata)
|
||||
metadata = decoder.decode(metadata_bytes)
|
||||
assert isinstance(metadata, MoRIIOAgentMetadata)
|
||||
assert isinstance(metadata, MoRIIOAgentMetadata), (
|
||||
"Decoded metadata is not MoRIIOAgentMetadata"
|
||||
)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user