mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-24 16:04:28 +08:00
format ut
Signed-off-by: inkcherry <mingzhi.liu@amd.com>
This commit is contained in:
parent
c4dcb3475e
commit
94a920fb0c
@ -1,53 +1,14 @@
|
|||||||
import pytest
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
from vllm.platforms import current_platform
|
|
||||||
import contextlib
|
|
||||||
import inspect
|
|
||||||
import os
|
import os
|
||||||
import tempfile
|
from unittest.mock import MagicMock, patch
|
||||||
import textwrap
|
|
||||||
import time
|
|
||||||
import uuid
|
|
||||||
from collections import defaultdict
|
|
||||||
from unittest.mock import patch
|
|
||||||
from unittest.mock import MagicMock
|
|
||||||
import pytest
|
|
||||||
import ray
|
|
||||||
import torch
|
|
||||||
import msgspec
|
import msgspec
|
||||||
from vllm import LLM
|
import pytest
|
||||||
from vllm.config import KVTransferConfig
|
import torch
|
||||||
from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
|
import zmq
|
||||||
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
|
|
||||||
from vllm.distributed.kv_transfer.kv_connector.v1.multi_connector import (
|
|
||||||
MultiKVConnectorStats,
|
|
||||||
)
|
|
||||||
from vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector import (
|
|
||||||
KVConnectorRole,
|
|
||||||
)
|
|
||||||
from vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_common import MoRIIOConstants
|
|
||||||
from vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_common import MoRIIOAgentMetadata
|
|
||||||
|
|
||||||
from vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector import (
|
from tests.conftest import _find_free_port
|
||||||
MoRIIOConnector,
|
|
||||||
MoRIIOConnectorScheduler,
|
|
||||||
MoRIIOConnectorWorker
|
|
||||||
)
|
|
||||||
from vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_common import MoRIIOConnectorMetadata
|
|
||||||
|
|
||||||
from vllm.distributed.kv_transfer.kv_transfer_state import (
|
|
||||||
ensure_kv_transfer_shutdown,
|
|
||||||
has_kv_transfer_group,
|
|
||||||
)
|
|
||||||
from vllm.forward_context import ForwardContext
|
|
||||||
from vllm.platforms.interface import Platform
|
|
||||||
from vllm.sampling_params import SamplingParams
|
|
||||||
from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
|
|
||||||
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
|
|
||||||
from vllm.v1.request import RequestStatus
|
|
||||||
from vllm.v1.structured_output import StructuredOutputManager
|
|
||||||
|
|
||||||
from .utils import create_request, create_scheduler
|
|
||||||
from vllm.config import (
|
from vllm.config import (
|
||||||
CacheConfig,
|
CacheConfig,
|
||||||
DeviceConfig,
|
DeviceConfig,
|
||||||
@ -56,13 +17,68 @@ from vllm.config import (
|
|||||||
SchedulerConfig,
|
SchedulerConfig,
|
||||||
VllmConfig,
|
VllmConfig,
|
||||||
)
|
)
|
||||||
|
from vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_common import (
|
||||||
|
MoRIIOAgentMetadata,
|
||||||
|
MoRIIOConnectorMetadata,
|
||||||
|
MoRIIOConstants,
|
||||||
|
zmq_ctx,
|
||||||
|
)
|
||||||
|
from vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector import (
|
||||||
|
KVConnectorRole,
|
||||||
|
MoRIIOConnector,
|
||||||
|
MoRIIOConnectorWorker,
|
||||||
|
)
|
||||||
|
from vllm.utils.network_utils import (
|
||||||
|
get_ip,
|
||||||
|
make_zmq_path,
|
||||||
|
)
|
||||||
|
|
||||||
class FakeMorIIOWrapper():
|
from .utils import create_request, create_scheduler
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_parallel_groups():
|
||||||
|
"""Mock parallel group functions."""
|
||||||
|
mock_group = MagicMock()
|
||||||
|
mock_group.rank = 0
|
||||||
|
mock_group.local_rank = 0
|
||||||
|
mock_group.world_size = 1
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch.multiple(
|
||||||
|
"vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_common",
|
||||||
|
get_tensor_model_parallel_rank=MagicMock(return_value=0),
|
||||||
|
get_tensor_model_parallel_world_size=MagicMock(return_value=0),
|
||||||
|
),
|
||||||
|
patch.multiple(
|
||||||
|
"vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector",
|
||||||
|
get_tensor_model_parallel_world_size=MagicMock(return_value=0),
|
||||||
|
get_world_group=MagicMock(return_value=mock_group),
|
||||||
|
get_tp_group=MagicMock(return_value=mock_group),
|
||||||
|
),
|
||||||
|
):
|
||||||
|
yield mock_group
|
||||||
|
|
||||||
|
|
||||||
|
def _setup_kv_transfer_request(request, remote_host="127.0.0.1", fake_port=4789):
|
||||||
|
"""Setup KV transfer parameters for a request."""
|
||||||
|
request.kv_transfer_params.update(
|
||||||
|
{
|
||||||
|
"remote_notify_port": fake_port,
|
||||||
|
"remote_block_ids": None,
|
||||||
|
"remote_host": remote_host,
|
||||||
|
"remote_port": fake_port,
|
||||||
|
"remote_handshake_port": fake_port,
|
||||||
|
"remote_engine_id": "test_engine",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return request
|
||||||
|
|
||||||
|
|
||||||
|
class FakeMorIIOWrapper:
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def set_moriio_engine(self, moriio_engine):
|
def set_moriio_engine(self, moriio_engine):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -76,19 +92,7 @@ class FakeMorIIOWrapper():
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
def register_local_tensor(self, tensor: torch.Tensor):
|
def register_local_tensor(self, tensor: torch.Tensor):
|
||||||
assert self.moriio_engine is not None, "MoRIIO engine must be set first"
|
pass
|
||||||
try:
|
|
||||||
self.local_memory_metadata = self.moriio_engine.register_torch_tensor(
|
|
||||||
tensor
|
|
||||||
)
|
|
||||||
assert self.local_memory_metadata is not None, (
|
|
||||||
"register_torch_tensor returned None"
|
|
||||||
)
|
|
||||||
local_memory_metadata_packed = self.local_memory_metadata.pack()
|
|
||||||
except Exception as e:
|
|
||||||
raise MoRIIOError(f"Failed to register local memory: {e}") from e
|
|
||||||
self.local_memory_registered = True
|
|
||||||
return local_memory_metadata_packed
|
|
||||||
|
|
||||||
def get_unpack_memory_metadata(self, packed_memory_metadata):
|
def get_unpack_memory_metadata(self, packed_memory_metadata):
|
||||||
pass
|
pass
|
||||||
@ -106,7 +110,6 @@ class FakeMorIIOWrapper():
|
|||||||
):
|
):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def write_remote_data_single(
|
def write_remote_data_single(
|
||||||
self, transfer_size_byte, local_offset=0, remote_offset=0, sess_idx=0
|
self, transfer_size_byte, local_offset=0, remote_offset=0, sess_idx=0
|
||||||
):
|
):
|
||||||
@ -140,19 +143,15 @@ class FakeMorIIOWrapper():
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class FakeMoriIOConnectorWorker(MoRIIOConnectorWorker):
|
class FakeMorIIOConnectorWorker(MoRIIOConnectorWorker):
|
||||||
REMOTE_ENGINE_ID = "remote_engine"
|
REMOTE_ENGINE_ID = "remote_engine"
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, *args, hand_shake_latency: float = 1.8, kv_cache_layout="HND", **kwargs
|
self, *args, hand_shake_latency: float = 1.8, kv_cache_layout="HND", **kwargs
|
||||||
):
|
):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self._hand_shake_latency = hand_shake_latency
|
|
||||||
self.kv_cache_layout = kv_cache_layout
|
|
||||||
|
|
||||||
|
|
||||||
from vllm.v1.core.sched.scheduler import Scheduler, SchedulerOutput
|
|
||||||
|
|
||||||
def create_vllm_config(
|
def create_vllm_config(
|
||||||
model: str = "facebook/opt-125m",
|
model: str = "facebook/opt-125m",
|
||||||
max_num_seqs: int = 16,
|
max_num_seqs: int = 16,
|
||||||
@ -161,9 +160,7 @@ def create_vllm_config(
|
|||||||
max_model_len: int = 10000,
|
max_model_len: int = 10000,
|
||||||
enable_chunked_prefill: bool = True,
|
enable_chunked_prefill: bool = True,
|
||||||
enable_permute_local_kv: bool = False,
|
enable_permute_local_kv: bool = False,
|
||||||
role="kv_consumer"
|
role="kv_consumer",
|
||||||
# role="kv_producer"
|
|
||||||
|
|
||||||
) -> VllmConfig:
|
) -> VllmConfig:
|
||||||
"""Initialize VllmConfig For Testing."""
|
"""Initialize VllmConfig For Testing."""
|
||||||
scheduler_config = SchedulerConfig(
|
scheduler_config = SchedulerConfig(
|
||||||
@ -198,42 +195,16 @@ def create_vllm_config(
|
|||||||
kv_transfer_config=kv_transfer_config,
|
kv_transfer_config=kv_transfer_config,
|
||||||
device_config=DeviceConfig("cpu"),
|
device_config=DeviceConfig("cpu"),
|
||||||
)
|
)
|
||||||
from vllm.v1.kv_cache_interface import (
|
|
||||||
FullAttentionSpec,
|
|
||||||
KVCacheConfig,
|
|
||||||
KVCacheGroupSpec,
|
|
||||||
)
|
|
||||||
def create_scheduler(
|
|
||||||
vllm_config: VllmConfig,
|
|
||||||
num_blocks: int = 10000,
|
|
||||||
) -> Scheduler:
|
|
||||||
"""Initialize Scheduler For Testing."""
|
|
||||||
block_size = vllm_config.cache_config.block_size
|
|
||||||
kv_cache_config = KVCacheConfig(
|
|
||||||
num_blocks=num_blocks, # A large number of blocks to hold all requests
|
|
||||||
kv_cache_tensors=[],
|
|
||||||
kv_cache_groups=[
|
|
||||||
KVCacheGroupSpec(
|
|
||||||
["layer"], FullAttentionSpec(block_size, 1, 1, torch.float32, False)
|
|
||||||
)
|
|
||||||
],
|
|
||||||
)
|
|
||||||
vllm_config.cache_config.num_gpu_blocks = num_blocks
|
|
||||||
return Scheduler(
|
|
||||||
vllm_config=vllm_config,
|
|
||||||
kv_cache_config=kv_cache_config,
|
|
||||||
log_stats=True,
|
|
||||||
structured_output_manager=StructuredOutputManager(vllm_config),
|
|
||||||
block_size=block_size,
|
|
||||||
)
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def moriio_read_mode():
|
def moriio_read_mode():
|
||||||
"""Set VLLM_MORIIO_CONNECTOR_READ_MODE environment variable for all tests."""
|
"""Set VLLM_MORIIO_CONNECTOR_READ_MODE environment variable for all tests."""
|
||||||
os.environ['VLLM_MORIIO_CONNECTOR_READ_MODE'] = 'True'
|
os.environ["VLLM_MORIIO_CONNECTOR_READ_MODE"] = "True"
|
||||||
yield
|
yield
|
||||||
# Cleanup after test
|
# Cleanup after test
|
||||||
os.environ.pop('VLLM_MORIIO_CONNECTOR_READ_MODE', None)
|
os.environ.pop("VLLM_MORIIO_CONNECTOR_READ_MODE", None)
|
||||||
|
|
||||||
|
|
||||||
def test_write_mode_basic_interface():
|
def test_write_mode_basic_interface():
|
||||||
"""Unit test for basic MoriioConnector interface functionality."""
|
"""Unit test for basic MoriioConnector interface functionality."""
|
||||||
@ -252,19 +223,15 @@ def test_write_mode_basic_interface():
|
|||||||
block_size=BLOCK_SIZE,
|
block_size=BLOCK_SIZE,
|
||||||
num_tokens=NUM_TOKENS,
|
num_tokens=NUM_TOKENS,
|
||||||
do_remote_decode=True,
|
do_remote_decode=True,
|
||||||
do_remote_prefill=False
|
do_remote_prefill=False,
|
||||||
)
|
)
|
||||||
request_id = request.request_id
|
request_id = request.request_id
|
||||||
|
|
||||||
scheduler.add_request(request)
|
scheduler.add_request(request)
|
||||||
|
|
||||||
# Fake
|
# Fake Config
|
||||||
request.kv_transfer_params['remote_notify_port']=4789
|
request = _setup_kv_transfer_request(request)
|
||||||
request.kv_transfer_params['remote_block_ids']=None
|
|
||||||
request.kv_transfer_params["remote_host"]="127.0.0.1"
|
|
||||||
request.kv_transfer_params["remote_port"]=4789
|
|
||||||
request.kv_transfer_params["remote_handshake_port"]=4789
|
|
||||||
request.kv_transfer_params["remote_engine_id"]="test_engine"
|
|
||||||
# Remote Prefill, triggers NixlConnectorMetadata.
|
# Remote Prefill, triggers NixlConnectorMetadata.
|
||||||
scheduler_output = scheduler.schedule()
|
scheduler_output = scheduler.schedule()
|
||||||
kv_connector_metadata = scheduler_output.kv_connector_metadata
|
kv_connector_metadata = scheduler_output.kv_connector_metadata
|
||||||
@ -288,12 +255,14 @@ def test_write_mode_basic_interface():
|
|||||||
|
|
||||||
def test_write_mode_chunk_prefill():
|
def test_write_mode_chunk_prefill():
|
||||||
"""Unit test for basic MoriioConnector interface functionality."""
|
"""Unit test for basic MoriioConnector interface functionality."""
|
||||||
MAX_NUM_BATCHED_TOKENS=64
|
MAX_NUM_BATCHED_TOKENS = 64
|
||||||
|
|
||||||
NUM_TOKENS = MAX_NUM_BATCHED_TOKENS*2+MAX_NUM_BATCHED_TOKENS//2
|
NUM_TOKENS = MAX_NUM_BATCHED_TOKENS * 2 + MAX_NUM_BATCHED_TOKENS // 2
|
||||||
|
|
||||||
# Test Prefill wirte metadata
|
# Test Prefill wirte metadata
|
||||||
vllm_config = create_vllm_config(max_num_batched_tokens=MAX_NUM_BATCHED_TOKENS, role="kv_consumer")
|
vllm_config = create_vllm_config(
|
||||||
|
max_num_batched_tokens=MAX_NUM_BATCHED_TOKENS, role="kv_consumer"
|
||||||
|
)
|
||||||
BLOCK_SIZE = vllm_config.cache_config.block_size
|
BLOCK_SIZE = vllm_config.cache_config.block_size
|
||||||
|
|
||||||
scheduler = create_scheduler(vllm_config)
|
scheduler = create_scheduler(vllm_config)
|
||||||
@ -305,19 +274,15 @@ def test_write_mode_chunk_prefill():
|
|||||||
block_size=BLOCK_SIZE,
|
block_size=BLOCK_SIZE,
|
||||||
num_tokens=NUM_TOKENS,
|
num_tokens=NUM_TOKENS,
|
||||||
do_remote_decode=True,
|
do_remote_decode=True,
|
||||||
do_remote_prefill=False
|
do_remote_prefill=False,
|
||||||
)
|
)
|
||||||
request_id = request.request_id
|
request_id = request.request_id
|
||||||
|
|
||||||
scheduler.add_request(request)
|
scheduler.add_request(request)
|
||||||
|
|
||||||
# Fake
|
# Fake Config
|
||||||
request.kv_transfer_params['remote_notify_port']=4789
|
|
||||||
request.kv_transfer_params['remote_block_ids']=None
|
request = _setup_kv_transfer_request(request)
|
||||||
request.kv_transfer_params["remote_host"]="127.0.0.1"
|
|
||||||
request.kv_transfer_params["remote_port"]=4789
|
|
||||||
request.kv_transfer_params["remote_handshake_port"]=4789
|
|
||||||
request.kv_transfer_params["remote_engine_id"]="test_engine"
|
|
||||||
# Remote Prefill, triggers NixlConnectorMetadata.
|
# Remote Prefill, triggers NixlConnectorMetadata.
|
||||||
scheduler_output = scheduler.schedule()
|
scheduler_output = scheduler.schedule()
|
||||||
kv_connector_metadata = scheduler_output.kv_connector_metadata
|
kv_connector_metadata = scheduler_output.kv_connector_metadata
|
||||||
@ -338,8 +303,8 @@ def test_write_mode_chunk_prefill():
|
|||||||
):
|
):
|
||||||
assert block_id == block.block_id
|
assert block_id == block.block_id
|
||||||
|
|
||||||
def test_read_mode_basic_interface(moriio_read_mode):
|
|
||||||
|
|
||||||
|
def test_read_mode_basic_interface(moriio_read_mode):
|
||||||
# test decode read
|
# test decode read
|
||||||
vllm_config = create_vllm_config(role="kv_consumer")
|
vllm_config = create_vllm_config(role="kv_consumer")
|
||||||
scheduler = create_scheduler(vllm_config)
|
scheduler = create_scheduler(vllm_config)
|
||||||
@ -354,20 +319,20 @@ def test_read_mode_basic_interface(moriio_read_mode):
|
|||||||
block_size=BLOCK_SIZE,
|
block_size=BLOCK_SIZE,
|
||||||
num_tokens=NUM_TOKENS,
|
num_tokens=NUM_TOKENS,
|
||||||
do_remote_decode=False,
|
do_remote_decode=False,
|
||||||
do_remote_prefill=True
|
do_remote_prefill=True,
|
||||||
)
|
)
|
||||||
request_id = request.request_id
|
request_id = request.request_id
|
||||||
|
|
||||||
scheduler.add_request(request)
|
scheduler.add_request(request)
|
||||||
block_list = scheduler.kv_cache_manager.coordinator.single_type_managers[0].req_to_blocks[request_id]
|
block_list = scheduler.kv_cache_manager.coordinator.single_type_managers[
|
||||||
# Fake
|
0
|
||||||
request.kv_transfer_params['remote_notify_port']=4789
|
].req_to_blocks[request_id]
|
||||||
request.kv_transfer_params['remote_block_ids']=block_list
|
# Fake kv config
|
||||||
request.kv_transfer_params["remote_host"]="127.0.0.1"
|
request = _setup_kv_transfer_request(request)
|
||||||
request.kv_transfer_params["remote_port"]=4789
|
request.kv_transfer_params["remote_block_ids"] = block_list
|
||||||
request.kv_transfer_params["remote_handshake_port"]=4789
|
|
||||||
request.kv_transfer_params["remote_engine_id"]="test_engine"
|
|
||||||
# Remote Prefill, triggers MorIIOConnectorMetadata.
|
# Remote Prefill, triggers MorIIOConnectorMetadata.
|
||||||
|
|
||||||
scheduler_output = scheduler.schedule()
|
scheduler_output = scheduler.schedule()
|
||||||
kv_connector_metadata = scheduler_output.kv_connector_metadata
|
kv_connector_metadata = scheduler_output.kv_connector_metadata
|
||||||
assert kv_connector_metadata is not None
|
assert kv_connector_metadata is not None
|
||||||
@ -375,7 +340,6 @@ def test_read_mode_basic_interface(moriio_read_mode):
|
|||||||
assert len(kv_connector_metadata.reqs_to_save) == 0
|
assert len(kv_connector_metadata.reqs_to_save) == 0
|
||||||
assert len(kv_connector_metadata.reqs_to_recv) == 1
|
assert len(kv_connector_metadata.reqs_to_recv) == 1
|
||||||
assert len(kv_connector_metadata.reqs_to_send) == 0
|
assert len(kv_connector_metadata.reqs_to_send) == 0
|
||||||
# assert len(kv_connector_metadata.reqs_to_save) == 1
|
|
||||||
assert request_id in kv_connector_metadata.reqs_to_recv
|
assert request_id in kv_connector_metadata.reqs_to_recv
|
||||||
req_meta = kv_connector_metadata.reqs_to_recv[request_id]
|
req_meta = kv_connector_metadata.reqs_to_recv[request_id]
|
||||||
|
|
||||||
@ -388,16 +352,15 @@ def test_read_mode_basic_interface(moriio_read_mode):
|
|||||||
assert block_id == block.block_id
|
assert block_id == block.block_id
|
||||||
|
|
||||||
|
|
||||||
def test_register_kv_caches():
|
def test_register_kv_caches(mock_parallel_groups):
|
||||||
from vllm.utils.network_utils import get_ip
|
ROLE = "kv_consumer"
|
||||||
|
IP = get_ip()
|
||||||
ROLE="kv_consumer"
|
|
||||||
IP=get_ip()
|
|
||||||
DEFAULT_PORT=6301
|
|
||||||
vllm_config = create_vllm_config(role=ROLE)
|
vllm_config = create_vllm_config(role=ROLE)
|
||||||
TP_RANK=0
|
DEFAULT_PORT = 6301
|
||||||
DP_RANK=0
|
TP_RANK = 0
|
||||||
from vllm.v1.attention.backends.rocm_aiter_fa import AiterFlashAttentionBackend
|
DP_RANK = 0
|
||||||
|
from vllm.v1.attention.backends.rocm_aiter_fa import AiterFlashAttentionBackend
|
||||||
|
|
||||||
backend_cls = AiterFlashAttentionBackend
|
backend_cls = AiterFlashAttentionBackend
|
||||||
|
|
||||||
# Create test kv cache tensors using proper backend shape
|
# Create test kv cache tensors using proper backend shape
|
||||||
@ -412,96 +375,79 @@ def test_register_kv_caches():
|
|||||||
"layer2": shared_tensor,
|
"layer2": shared_tensor,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Store tensor info for validation
|
|
||||||
expected_tensor_size = shared_tensor[0].element_size() * shared_tensor[0].numel()
|
|
||||||
expected_base_addrs = [
|
|
||||||
shared_tensor[0].data_ptr(),
|
|
||||||
unique_tensor[1].data_ptr(),
|
|
||||||
shared_tensor[0].data_ptr(),
|
|
||||||
|
|
||||||
]
|
|
||||||
mock_group = MagicMock()
|
|
||||||
mock_group.rank = TP_RANK # 设置 rank
|
|
||||||
mock_group.local_rank = TP_RANK
|
|
||||||
mock_group.world_size = 1 # 设置 world_size
|
|
||||||
with (
|
with (
|
||||||
patch(
|
patch(
|
||||||
"vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_engine.MoRIIOWrapper"
|
"vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector.threading.Event"
|
||||||
) as mock_moriio_wrapper,
|
|
||||||
patch(
|
|
||||||
"vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_common.get_tensor_model_parallel_rank",
|
|
||||||
return_value=0
|
|
||||||
),
|
|
||||||
patch(
|
|
||||||
"vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_common.get_tensor_model_parallel_world_size",
|
|
||||||
return_value=0
|
|
||||||
),
|
|
||||||
patch(
|
|
||||||
"vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector.get_tensor_model_parallel_world_size",
|
|
||||||
return_value=0
|
|
||||||
),
|
),
|
||||||
patch(
|
patch(
|
||||||
"vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector.get_world_group",
|
"vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector.threading.Thread"
|
||||||
return_value=mock_group
|
|
||||||
),
|
),
|
||||||
patch(
|
|
||||||
"vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector.get_tp_group",
|
|
||||||
return_value=mock_group
|
|
||||||
),
|
|
||||||
patch(
|
|
||||||
"vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_engine.MoRIIOWrapper",
|
|
||||||
FakeMorIIOWrapper,
|
|
||||||
)
|
|
||||||
|
|
||||||
):
|
):
|
||||||
# Create connector
|
# Create connector
|
||||||
vllm_config.kv_transfer_config.kv_connector_extra_config.update({
|
vllm_config.kv_transfer_config.kv_connector_extra_config.update(
|
||||||
"proxy_ip": "127.0.0.1",
|
{
|
||||||
"proxy_ping_port": 12345,
|
"proxy_ip": "127.0.0.1",
|
||||||
"http_port": 12346,
|
"proxy_ping_port": 12345,
|
||||||
})
|
"http_port": 12346,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
connector = MoRIIOConnector(vllm_config, KVConnectorRole.WORKER)
|
connector = MoRIIOConnector(vllm_config, KVConnectorRole.WORKER)
|
||||||
connector.connector_worker = FakeMoriIOConnectorWorker(
|
connector.connector_worker = FakeMorIIOConnectorWorker(
|
||||||
vllm_config, connector.engine_id, hand_shake_latency=0
|
vllm_config, connector.engine_id, hand_shake_latency=0
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get the mock instance
|
|
||||||
mock_wrapper_instance = mock_moriio_wrapper.return_value
|
|
||||||
# connector.connector_worker.moriio_wrapper = mock_wrapper_instance
|
|
||||||
|
|
||||||
# Reassure the shutdown() check that the thread is terminated
|
|
||||||
# mock_thread.return_value.is_alive.return_value = False
|
|
||||||
from mori.io import (
|
from mori.io import (
|
||||||
EngineDesc,
|
|
||||||
IOEngine,
|
|
||||||
MemoryDesc,
|
MemoryDesc,
|
||||||
PollCqMode,
|
|
||||||
RdmaBackendConfig,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Execute register_kv_caches
|
# Execute register_kv_caches
|
||||||
connector.register_kv_caches(kv_caches)
|
connector.register_kv_caches(kv_caches)
|
||||||
shared_tensor[0].data_ptr()
|
shared_tensor[0].data_ptr()
|
||||||
unique_tensor[1].data_ptr()
|
unique_tensor[1].data_ptr()
|
||||||
shared_tensor[0].data_ptr()
|
shared_tensor[0].data_ptr()
|
||||||
|
|
||||||
assert shared_tensor.data_ptr()==MemoryDesc.unpack(connector.connector_worker.layer_name_to_local_kv_cache_metadata["layer0"][0]).data
|
assert (
|
||||||
assert unique_tensor.data_ptr()==MemoryDesc.unpack(connector.connector_worker.layer_name_to_local_kv_cache_metadata["layer1"][0]).data
|
shared_tensor.data_ptr()
|
||||||
assert shared_tensor.data_ptr()==MemoryDesc.unpack(connector.connector_worker.layer_name_to_local_kv_cache_metadata["layer2"][0]).data
|
== MemoryDesc.unpack(
|
||||||
|
connector.connector_worker.layer_name_to_local_kv_cache_metadata[
|
||||||
|
"layer0"
|
||||||
|
][0]
|
||||||
|
).data
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
unique_tensor.data_ptr()
|
||||||
|
== MemoryDesc.unpack(
|
||||||
|
connector.connector_worker.layer_name_to_local_kv_cache_metadata[
|
||||||
|
"layer1"
|
||||||
|
][0]
|
||||||
|
).data
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
shared_tensor.data_ptr()
|
||||||
|
== MemoryDesc.unpack(
|
||||||
|
connector.connector_worker.layer_name_to_local_kv_cache_metadata[
|
||||||
|
"layer2"
|
||||||
|
][0]
|
||||||
|
).data
|
||||||
|
)
|
||||||
expected_engine_key = f"{ROLE[3:]}:{IP}:{DEFAULT_PORT}:tp{TP_RANK}:dp{DP_RANK}"
|
expected_engine_key = f"{ROLE[3:]}:{IP}:{DEFAULT_PORT}:tp{TP_RANK}:dp{DP_RANK}"
|
||||||
|
|
||||||
assert MemoryDesc.unpack(connector.connector_worker.layer_name_to_local_kv_cache_metadata["layer0"][0]).engine_key ==expected_engine_key
|
assert (
|
||||||
|
MemoryDesc.unpack(
|
||||||
|
connector.connector_worker.layer_name_to_local_kv_cache_metadata[
|
||||||
|
"layer0"
|
||||||
|
][0]
|
||||||
|
).engine_key
|
||||||
|
== expected_engine_key
|
||||||
|
)
|
||||||
|
|
||||||
def test_moriio_handshake():
|
|
||||||
from vllm.utils.network_utils import get_ip
|
|
||||||
|
|
||||||
ROLE="kv_consumer"
|
def test_moriio_handshake(mock_parallel_groups):
|
||||||
IP=get_ip()
|
ROLE = "kv_consumer"
|
||||||
DEFAULT_PORT=6301
|
|
||||||
vllm_config = create_vllm_config(role=ROLE)
|
vllm_config = create_vllm_config(role=ROLE)
|
||||||
TP_RANK=0
|
from vllm.v1.attention.backends.rocm_aiter_fa import AiterFlashAttentionBackend
|
||||||
DP_RANK=0
|
|
||||||
from vllm.v1.attention.backends.rocm_aiter_fa import AiterFlashAttentionBackend
|
|
||||||
backend_cls = AiterFlashAttentionBackend
|
backend_cls = AiterFlashAttentionBackend
|
||||||
|
|
||||||
# Create test kv cache tensors using proper backend shape
|
# Create test kv cache tensors using proper backend shape
|
||||||
@ -516,90 +462,36 @@ def test_moriio_handshake():
|
|||||||
"layer2": shared_tensor,
|
"layer2": shared_tensor,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Store tensor info for validation
|
|
||||||
expected_tensor_size = shared_tensor[0].element_size() * shared_tensor[0].numel()
|
|
||||||
expected_base_addrs = [
|
|
||||||
shared_tensor[0].data_ptr(),
|
|
||||||
unique_tensor[1].data_ptr(),
|
|
||||||
shared_tensor[0].data_ptr(),
|
|
||||||
|
|
||||||
]
|
|
||||||
mock_group = MagicMock()
|
|
||||||
mock_group.rank = TP_RANK
|
|
||||||
mock_group.local_rank = TP_RANK
|
|
||||||
mock_group.world_size = 1
|
|
||||||
with (
|
with (
|
||||||
patch(
|
patch(
|
||||||
"vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_common.get_tensor_model_parallel_rank",
|
"vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_engine.MoRIIOWrapper",
|
||||||
return_value=0
|
FakeMorIIOWrapper,
|
||||||
),
|
),
|
||||||
patch(
|
|
||||||
"vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_common.get_tensor_model_parallel_world_size",
|
|
||||||
return_value=0
|
|
||||||
),
|
|
||||||
patch(
|
|
||||||
"vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector.get_tensor_model_parallel_world_size",
|
|
||||||
return_value=0
|
|
||||||
),
|
|
||||||
patch(
|
|
||||||
"vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector.get_world_group",
|
|
||||||
return_value=mock_group
|
|
||||||
),
|
|
||||||
patch(
|
|
||||||
"vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector.get_tp_group",
|
|
||||||
return_value=mock_group
|
|
||||||
),
|
|
||||||
patch(
|
|
||||||
"vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_engine.MoRIIOWrapper",
|
|
||||||
FakeMorIIOWrapper,
|
|
||||||
)
|
|
||||||
|
|
||||||
):
|
):
|
||||||
|
handshake_port = _find_free_port()
|
||||||
# Create connector
|
# Create connector
|
||||||
vllm_config.kv_transfer_config.kv_connector_extra_config.update({
|
vllm_config.kv_transfer_config.kv_connector_extra_config.update(
|
||||||
"proxy_ip": "127.0.0.1",
|
{
|
||||||
"proxy_ping_port": 12345,
|
"proxy_ip": "127.0.0.1",
|
||||||
"http_port": 12346,
|
"proxy_ping_port": 12345,
|
||||||
"handshake_port":5670
|
"http_port": 12346,
|
||||||
})
|
"handshake_port": handshake_port,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
connector = MoRIIOConnector(vllm_config, KVConnectorRole.WORKER)
|
connector = MoRIIOConnector(vllm_config, KVConnectorRole.WORKER)
|
||||||
|
|
||||||
from vllm.utils.network_utils import (
|
|
||||||
get_ip,
|
|
||||||
make_zmq_path,
|
|
||||||
make_zmq_socket,
|
|
||||||
)
|
|
||||||
from vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_common import zmq_ctx
|
|
||||||
import zmq
|
|
||||||
|
|
||||||
|
|
||||||
# Reassure the shutdown() check that the thread is terminated
|
|
||||||
# mock_thread.return_value.is_alive.return_value = False
|
|
||||||
from mori.io import (
|
|
||||||
EngineDesc,
|
|
||||||
IOEngine,
|
|
||||||
MemoryDesc,
|
|
||||||
PollCqMode,
|
|
||||||
RdmaBackendConfig,
|
|
||||||
)
|
|
||||||
# Execute register_kv_caches
|
# Execute register_kv_caches
|
||||||
connector.register_kv_caches(kv_caches)
|
connector.register_kv_caches(kv_caches)
|
||||||
# connector.layer_name_to_local_kv_cache_metadata["layer0"] expected_base_addrs = [
|
path = make_zmq_path("tcp", "127.0.0.1", handshake_port)
|
||||||
|
|
||||||
path = make_zmq_path("tcp", "127.0.0.1", 5670)
|
|
||||||
with zmq_ctx(zmq.DEALER, path) as sock:
|
with zmq_ctx(zmq.DEALER, path) as sock:
|
||||||
sock.send(MoRIIOConstants.GET_META_MSG)
|
sock.send(MoRIIOConstants.GET_META_MSG)
|
||||||
received_frame = sock.recv_multipart()
|
received_frame = sock.recv_multipart()
|
||||||
|
|
||||||
if len(received_frame) != 2 or received_frame[0] != b"":
|
if len(received_frame) != 2 or received_frame[0] != b"":
|
||||||
raise HandshakeError(f"Unexpected frame! {received_frame = }")
|
raise ValueError(f"Unexpected frame! {received_frame = }")
|
||||||
|
|
||||||
metadata_bytes = received_frame[1]
|
metadata_bytes = received_frame[1]
|
||||||
decoder = msgspec.msgpack.Decoder(MoRIIOAgentMetadata)
|
decoder = msgspec.msgpack.Decoder(MoRIIOAgentMetadata)
|
||||||
metadata = decoder.decode(metadata_bytes)
|
metadata = decoder.decode(metadata_bytes)
|
||||||
assert isinstance(metadata, MoRIIOAgentMetadata)
|
assert isinstance(metadata, MoRIIOAgentMetadata)
|
||||||
|
|
||||||
|
|
||||||
Loading…
x
Reference in New Issue
Block a user