mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-03 11:57:02 +08:00
add basic test
Signed-off-by: inkcherry <mingzhi.liu@amd.com>
This commit is contained in:
parent
a0330452d5
commit
c4dcb3475e
605
tests/v1/kv_connector/unit/test_moriio_connector.py
Normal file
605
tests/v1/kv_connector/unit/test_moriio_connector.py
Normal file
@ -0,0 +1,605 @@
|
||||
import pytest
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
import contextlib
|
||||
import inspect
|
||||
import os
|
||||
import tempfile
|
||||
import textwrap
|
||||
import time
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import MagicMock
|
||||
import pytest
|
||||
import ray
|
||||
import torch
|
||||
import msgspec
|
||||
from vllm import LLM
|
||||
from vllm.config import KVTransferConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.multi_connector import (
|
||||
MultiKVConnectorStats,
|
||||
)
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector import (
|
||||
KVConnectorRole,
|
||||
)
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_common import MoRIIOConstants
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_common import MoRIIOAgentMetadata
|
||||
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector import (
|
||||
MoRIIOConnector,
|
||||
MoRIIOConnectorScheduler,
|
||||
MoRIIOConnectorWorker
|
||||
)
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_common import MoRIIOConnectorMetadata
|
||||
|
||||
from vllm.distributed.kv_transfer.kv_transfer_state import (
|
||||
ensure_kv_transfer_shutdown,
|
||||
has_kv_transfer_group,
|
||||
)
|
||||
from vllm.forward_context import ForwardContext
|
||||
from vllm.platforms.interface import Platform
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
|
||||
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
|
||||
from vllm.v1.request import RequestStatus
|
||||
from vllm.v1.structured_output import StructuredOutputManager
|
||||
|
||||
from .utils import create_request, create_scheduler
|
||||
from vllm.config import (
|
||||
CacheConfig,
|
||||
DeviceConfig,
|
||||
KVTransferConfig,
|
||||
ModelConfig,
|
||||
SchedulerConfig,
|
||||
VllmConfig,
|
||||
)
|
||||
|
||||
class FakeMorIIOWrapper:
    """Test double that stands in for ``MoRIIOWrapper``.

    Every transfer, session and notification method is a no-op returning
    ``None``. Only ``set_moriio_engine`` and ``register_local_tensor`` carry
    behaviour: the former remembers the engine it is given and the latter
    registers a tensor with that engine and returns the packed metadata.
    """

    def __init__(self, *args, **kwargs):
        """Accept and ignore any constructor arguments; start with no engine."""
        # Initialised here so register_local_tensor() fails with its own
        # assertion message (not an AttributeError) when no engine was set.
        self.moriio_engine = None
        self.local_memory_metadata = None
        self.local_memory_registered = False

    def set_moriio_engine(self, moriio_engine):
        """Remember the engine used later by ``register_local_tensor``."""
        self.moriio_engine = moriio_engine

    def set_backend_type(self, backend_type):
        """No-op stand-in."""

    def get_agent_metadata(self):
        """No-op stand-in; returns ``None``."""

    def register_remote_engine(self, remote_packed_engine_metadata):
        """No-op stand-in."""

    def register_local_tensor(self, tensor: torch.Tensor):
        """Register ``tensor`` with the stored engine.

        Args:
            tensor: Tensor handed to the engine's ``register_torch_tensor``.

        Returns:
            Whatever ``pack()`` returns on the metadata object produced by
            the engine.

        Raises:
            AssertionError: If no engine has been set.
            MoRIIOError: If registration or packing fails for any reason.
        """
        assert self.moriio_engine is not None, "MoRIIO engine must be set first"
        try:
            self.local_memory_metadata = self.moriio_engine.register_torch_tensor(
                tensor
            )
            assert self.local_memory_metadata is not None, (
                "register_torch_tensor returned None"
            )
            local_memory_metadata_packed = self.local_memory_metadata.pack()
        except Exception as e:
            # Imported lazily: this module does not import MoRIIOError at
            # file level, so the bare name would otherwise be a NameError.
            from vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_common import (  # noqa: E501
                MoRIIOError,
            )

            raise MoRIIOError(f"Failed to register local memory: {e}") from e
        self.local_memory_registered = True
        return local_memory_metadata_packed

    def get_unpack_memory_metadata(self, packed_memory_metadata):
        """No-op stand-in; returns ``None``."""

    def build_session(self, local_memory_metadata, remote_memory_metadata):
        """No-op stand-in; returns ``None``."""

    def read_remote_data(
        self, transfer_size_byte, local_offset=0, remote_offset=0, session=None
    ):
        """No-op stand-in; nothing is transferred."""

    def write_remote_data(
        self, transfer_size_byte, local_offset=0, remote_offset=0, session=None
    ):
        """No-op stand-in; nothing is transferred."""

    def write_remote_data_single(
        self, transfer_size_byte, local_offset=0, remote_offset=0, sess_idx=0
    ):
        """No-op stand-in; nothing is transferred."""

    def waiting_for_transfer_complete(self):
        """No-op stand-in; returns immediately."""

    def async_wait_reqid(self):
        """No-op stand-in."""

    def _handle_message(self, msg: bytes):
        """No-op stand-in."""

    def _handle_structured_message(self, data: dict):
        """No-op stand-in."""

    def _handle_completion_message(self, msg: str):
        """No-op stand-in."""

    def send_notify(self, req_ids, remote_ip, remote_port):
        """No-op stand-in; nothing is sent."""

    def pop_finished_req_ids(self):
        """No-op stand-in; returns ``None``."""

    def pop_finished_write_req_ids(self):
        """No-op stand-in; returns ``None``."""

    def shutdown(self):
        """No-op stand-in."""
|
||||
|
||||
|
||||
class FakeMoriIOConnectorWorker(MoRIIOConnectorWorker):
    """``MoRIIOConnectorWorker`` subclass that records two test-only settings.

    Construction is delegated unchanged to the parent class; afterwards the
    handshake latency and KV cache layout passed by the test are stored on
    the instance.
    """

    REMOTE_ENGINE_ID = "remote_engine"

    def __init__(
        self,
        *args,
        hand_shake_latency: float = 1.8,
        kv_cache_layout="HND",
        **kwargs,
    ):
        """Build the real worker, then attach the test-only attributes.

        Args:
            *args: Positional arguments forwarded to the parent constructor.
            hand_shake_latency: Value stored as ``_hand_shake_latency``.
            kv_cache_layout: Value stored as ``kv_cache_layout``.
            **kwargs: Keyword arguments forwarded to the parent constructor.
        """
        super().__init__(*args, **kwargs)
        self.kv_cache_layout = kv_cache_layout
        self._hand_shake_latency = hand_shake_latency
|
||||
|
||||
|
||||
from vllm.v1.core.sched.scheduler import Scheduler, SchedulerOutput
|
||||
|
||||
def create_vllm_config(
    model: str = "facebook/opt-125m",
    max_num_seqs: int = 16,
    max_num_batched_tokens: int = 64,
    block_size: int = 16,
    max_model_len: int = 10000,
    enable_chunked_prefill: bool = True,
    enable_permute_local_kv: bool = False,
    role="kv_consumer",
) -> VllmConfig:
    """Assemble a CPU-device ``VllmConfig`` wired to the MoRIIO connector.

    Args:
        model: Model name handed to ``ModelConfig``.
        max_num_seqs: Forwarded to ``SchedulerConfig``.
        max_num_batched_tokens: Forwarded to ``SchedulerConfig``.
        block_size: KV cache block size handed to ``CacheConfig``.
        max_model_len: Forwarded to ``SchedulerConfig``.
        enable_chunked_prefill: Forwarded to ``SchedulerConfig``.
        enable_permute_local_kv: Forwarded to ``KVTransferConfig``.
        role: ``kv_role`` of the ``KVTransferConfig``.

    Returns:
        The combined configuration object.
    """
    # Keyword arguments are evaluated left to right, so the sub-configs are
    # still constructed in scheduler, model, cache, kv-transfer, device order.
    return VllmConfig(
        scheduler_config=SchedulerConfig(
            max_num_seqs=max_num_seqs,
            max_num_batched_tokens=max_num_batched_tokens,
            max_model_len=max_model_len,
            enable_chunked_prefill=enable_chunked_prefill,
        ),
        model_config=ModelConfig(
            model=model,
            trust_remote_code=True,
            dtype="bfloat16",
            seed=42,
        ),
        # Prefix caching is switched on unconditionally for these tests.
        cache_config=CacheConfig(
            block_size=block_size,
            gpu_memory_utilization=0.9,
            swap_space=0,
            cache_dtype="auto",
            enable_prefix_caching=True,
        ),
        kv_transfer_config=KVTransferConfig(
            kv_connector="MoRIIOConnector",
            kv_role=role,
            enable_permute_local_kv=enable_permute_local_kv,
        ),
        device_config=DeviceConfig("cpu"),
    )
|
||||
from vllm.v1.kv_cache_interface import (
|
||||
FullAttentionSpec,
|
||||
KVCacheConfig,
|
||||
KVCacheGroupSpec,
|
||||
)
|
||||
def create_scheduler(
    vllm_config: VllmConfig,
    num_blocks: int = 10000,
) -> Scheduler:
    """Build a ``Scheduler`` over a single-group KV cache for the tests.

    NOTE(review): this definition shadows the ``create_scheduler`` imported
    from ``.utils`` at the top of the module — confirm that is intended.

    Args:
        vllm_config: Configuration to schedule against. Its
            ``cache_config.num_gpu_blocks`` is overwritten with ``num_blocks``.
        num_blocks: Number of KV cache blocks; the default is large so every
            test request fits.

    Returns:
        A scheduler with stats logging enabled.
    """
    tokens_per_block = vllm_config.cache_config.block_size

    # One cache group containing a single layer named "layer".
    attention_spec = FullAttentionSpec(
        tokens_per_block, 1, 1, torch.float32, False
    )
    layer_group = KVCacheGroupSpec(["layer"], attention_spec)
    cache_layout = KVCacheConfig(
        num_blocks=num_blocks,
        kv_cache_tensors=[],
        kv_cache_groups=[layer_group],
    )

    vllm_config.cache_config.num_gpu_blocks = num_blocks
    return Scheduler(
        vllm_config=vllm_config,
        kv_cache_config=cache_layout,
        log_stats=True,
        structured_output_manager=StructuredOutputManager(vllm_config),
        block_size=tokens_per_block,
    )
|
||||
|
||||
@pytest.fixture
def moriio_read_mode():
    """Enable MoRIIO read mode for the test that requests this fixture.

    Sets ``VLLM_MORIIO_CONNECTOR_READ_MODE`` to ``'True'`` while the test
    runs. Afterwards the variable is put back exactly as it was: restored to
    its earlier value if it had one, removed if it was unset. The fixture is
    not autouse, so only tests that list it as an argument are affected.
    """
    env_key = 'VLLM_MORIIO_CONNECTOR_READ_MODE'
    # Remember any pre-existing value so teardown does not clobber a setting
    # made by the surrounding environment.
    previous = os.environ.get(env_key)
    os.environ[env_key] = 'True'
    try:
        yield
    finally:
        if previous is None:
            os.environ.pop(env_key, None)
        else:
            os.environ[env_key] = previous
|
||||
|
||||
def test_write_mode_basic_interface():
    """Check the connector metadata built for one write-mode request.

    A request created with ``do_remote_decode=True`` is scheduled once. The
    resulting ``MoRIIOConnectorMetadata`` must list it under ``reqs_to_save``
    only, and its local block ids must match the blocks the KV cache manager
    holds for the request.
    """
    vllm_config = create_vllm_config(role="kv_consumer")
    scheduler = create_scheduler(vllm_config)

    # Prompt length: two full blocks plus half of a third.
    block_size = vllm_config.cache_config.block_size
    full_blocks = 2
    num_tokens = int(block_size * (full_blocks + 0.5))

    request = create_request(
        request_id=1,
        block_size=block_size,
        num_tokens=num_tokens,
        do_remote_decode=True,
        do_remote_prefill=False,
    )
    request_id = request.request_id
    scheduler.add_request(request)

    # Fake remote endpoint details; nothing is actually contacted here.
    request.kv_transfer_params.update(
        {
            "remote_notify_port": 4789,
            "remote_block_ids": None,
            "remote_host": "127.0.0.1",
            "remote_port": 4789,
            "remote_handshake_port": 4789,
            "remote_engine_id": "test_engine",
        }
    )

    # Scheduling the request is what produces the MoRIIOConnectorMetadata.
    metadata = scheduler.schedule().kv_connector_metadata
    assert metadata is not None
    assert isinstance(metadata, MoRIIOConnectorMetadata)

    assert len(metadata.reqs_to_save) == 1
    assert len(metadata.reqs_to_recv) == 0
    assert len(metadata.reqs_to_send) == 0
    assert request_id in metadata.reqs_to_save
    saved = metadata.reqs_to_save[request_id]

    manager = scheduler.kv_cache_manager.coordinator.single_type_managers[0]
    held_blocks = manager.req_to_blocks[request_id]
    for local_id, held in zip(saved.local_block_ids, held_blocks):
        assert local_id == held.block_id
|
||||
|
||||
|
||||
def test_write_mode_chunk_prefill():
    """Check write-mode metadata when the prompt exceeds one token batch.

    The prompt is two and a half times ``max_num_batched_tokens`` long. After
    a single ``schedule()`` call the request must appear under
    ``reqs_to_save`` only, with local block ids matching the blocks the KV
    cache manager holds for it.
    """
    batch_tokens = 64
    # 2.5 batches worth of tokens, so one batch cannot cover the prompt.
    num_tokens = batch_tokens * 2 + batch_tokens // 2

    vllm_config = create_vllm_config(
        max_num_batched_tokens=batch_tokens, role="kv_consumer"
    )
    block_size = vllm_config.cache_config.block_size
    scheduler = create_scheduler(vllm_config)

    request = create_request(
        request_id=1,
        block_size=block_size,
        num_tokens=num_tokens,
        do_remote_decode=True,
        do_remote_prefill=False,
    )
    request_id = request.request_id
    scheduler.add_request(request)

    # Fake remote endpoint details; nothing is actually contacted here.
    request.kv_transfer_params.update(
        {
            "remote_notify_port": 4789,
            "remote_block_ids": None,
            "remote_host": "127.0.0.1",
            "remote_port": 4789,
            "remote_handshake_port": 4789,
            "remote_engine_id": "test_engine",
        }
    )

    # Scheduling the request is what produces the MoRIIOConnectorMetadata.
    metadata = scheduler.schedule().kv_connector_metadata
    assert metadata is not None
    assert isinstance(metadata, MoRIIOConnectorMetadata)

    assert len(metadata.reqs_to_save) == 1
    assert len(metadata.reqs_to_recv) == 0
    assert len(metadata.reqs_to_send) == 0
    assert request_id in metadata.reqs_to_save
    saved = metadata.reqs_to_save[request_id]

    manager = scheduler.kv_cache_manager.coordinator.single_type_managers[0]
    held_blocks = manager.req_to_blocks[request_id]
    for local_id, held in zip(saved.local_block_ids, held_blocks):
        assert local_id == held.block_id
|
||||
|
||||
def test_read_mode_basic_interface(moriio_read_mode):
    """Check the connector metadata built for one read-mode request.

    With the ``moriio_read_mode`` fixture active, a request created with
    ``do_remote_prefill=True`` is scheduled once. The resulting
    ``MoRIIOConnectorMetadata`` must list it under ``reqs_to_recv`` only,
    with local block ids matching the blocks the KV cache manager holds.
    """
    vllm_config = create_vllm_config(role="kv_consumer")
    scheduler = create_scheduler(vllm_config)

    # Prompt length: two full blocks plus half of a third.
    block_size = vllm_config.cache_config.block_size
    full_blocks = 2
    num_tokens = int(block_size * (full_blocks + 0.5))

    request = create_request(
        request_id=1,
        block_size=block_size,
        num_tokens=num_tokens,
        do_remote_decode=False,
        do_remote_prefill=True,
    )
    request_id = request.request_id
    scheduler.add_request(request)

    def held_blocks():
        # Looked up afresh on every call, exactly as the inline expression was.
        manager = scheduler.kv_cache_manager.coordinator.single_type_managers[0]
        return manager.req_to_blocks[request_id]

    # Fake remote endpoint details; the manager's current entry for this
    # request is reused as the "remote" block ids.
    request.kv_transfer_params.update(
        {
            "remote_notify_port": 4789,
            "remote_block_ids": held_blocks(),
            "remote_host": "127.0.0.1",
            "remote_port": 4789,
            "remote_handshake_port": 4789,
            "remote_engine_id": "test_engine",
        }
    )

    # Scheduling the request is what produces the MoRIIOConnectorMetadata.
    metadata = scheduler.schedule().kv_connector_metadata
    assert metadata is not None
    assert isinstance(metadata, MoRIIOConnectorMetadata)
    assert len(metadata.reqs_to_save) == 0
    assert len(metadata.reqs_to_recv) == 1
    assert len(metadata.reqs_to_send) == 0
    assert request_id in metadata.reqs_to_recv
    received = metadata.reqs_to_recv[request_id]

    for local_id, held in zip(received.local_block_ids, held_blocks()):
        assert local_id == held.block_id
|
||||
|
||||
|
||||
def test_register_kv_caches():
    """Check that ``register_kv_caches`` records metadata for every layer.

    Three layers are registered, two of which share one tensor. For each
    layer the first stored metadata entry, once unpacked with
    ``MemoryDesc.unpack``, must point at the layer tensor's data pointer, and
    the engine key must follow the ``role:ip:port:tpN:dpN`` pattern.
    """
    from mori.io import MemoryDesc

    from vllm.utils.network_utils import get_ip
    from vllm.v1.attention.backends.rocm_aiter_fa import AiterFlashAttentionBackend

    ROLE = "kv_consumer"
    IP = get_ip()
    DEFAULT_PORT = 6301
    TP_RANK = 0
    DP_RANK = 0
    MORIIO_PKG = "vllm.distributed.kv_transfer.kv_connector.v1.moriio"

    vllm_config = create_vllm_config(role=ROLE)

    # Create test kv cache tensors using the backend's own shape.
    kv_cache_shape = AiterFlashAttentionBackend.get_kv_cache_shape(
        num_blocks=2, block_size=16, num_kv_heads=4, head_size=64
    )
    shared_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
    unique_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
    kv_caches = {
        "layer0": shared_tensor,
        "layer1": unique_tensor,
        "layer2": shared_tensor,
    }

    # Stand-in for the world / TP process groups.
    mock_group = MagicMock()
    mock_group.rank = TP_RANK
    mock_group.local_rank = TP_RANK
    mock_group.world_size = 1

    # NOTE(review): the patched TP world size is 0 while mock_group.world_size
    # is 1 — kept as in the original, but confirm 0 is intended.
    with (
        patch(
            f"{MORIIO_PKG}.moriio_common.get_tensor_model_parallel_rank",
            return_value=0,
        ),
        patch(
            f"{MORIIO_PKG}.moriio_common.get_tensor_model_parallel_world_size",
            return_value=0,
        ),
        patch(
            f"{MORIIO_PKG}.moriio_connector.get_tensor_model_parallel_world_size",
            return_value=0,
        ),
        patch(
            f"{MORIIO_PKG}.moriio_connector.get_world_group",
            return_value=mock_group,
        ),
        patch(
            f"{MORIIO_PKG}.moriio_connector.get_tp_group",
            return_value=mock_group,
        ),
        # Patched once only: an earlier MagicMock patch of this same target
        # was immediately overridden by this one and its handle never used.
        patch(
            f"{MORIIO_PKG}.moriio_engine.MoRIIOWrapper",
            FakeMorIIOWrapper,
        ),
    ):
        vllm_config.kv_transfer_config.kv_connector_extra_config.update(
            {
                "proxy_ip": "127.0.0.1",
                "proxy_ping_port": 12345,
                "http_port": 12346,
            }
        )

        connector = MoRIIOConnector(vllm_config, KVConnectorRole.WORKER)
        connector.connector_worker = FakeMoriIOConnectorWorker(
            vllm_config, connector.engine_id, hand_shake_latency=0
        )

        connector.register_kv_caches(kv_caches)

        registered = connector.connector_worker.layer_name_to_local_kv_cache_metadata

        def first_desc(layer_name):
            # Unpack the first metadata entry stored for the given layer.
            return MemoryDesc.unpack(registered[layer_name][0])

        assert shared_tensor.data_ptr() == first_desc("layer0").data
        assert unique_tensor.data_ptr() == first_desc("layer1").data
        assert shared_tensor.data_ptr() == first_desc("layer2").data

        # "kv_consumer" -> "consumer"
        role_name = ROLE.removeprefix("kv_")
        expected_engine_key = (
            f"{role_name}:{IP}:{DEFAULT_PORT}:tp{TP_RANK}:dp{DP_RANK}"
        )
        assert first_desc("layer0").engine_key == expected_engine_key
|
||||
|
||||
def test_moriio_handshake():
    """Check that a registered connector answers a metadata handshake.

    After ``register_kv_caches`` the test connects a DEALER socket to the
    configured handshake port, sends ``MoRIIOConstants.GET_META_MSG`` and
    expects a two-frame reply whose second frame decodes to a
    ``MoRIIOAgentMetadata``.
    """
    import zmq

    from vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_common import (
        zmq_ctx,
    )
    from vllm.utils.network_utils import make_zmq_path
    from vllm.v1.attention.backends.rocm_aiter_fa import AiterFlashAttentionBackend

    ROLE = "kv_consumer"
    TP_RANK = 0
    HANDSHAKE_PORT = 5670
    REPLY_TIMEOUT_MS = 5000
    MORIIO_PKG = "vllm.distributed.kv_transfer.kv_connector.v1.moriio"

    vllm_config = create_vllm_config(role=ROLE)

    # Create test kv cache tensors using the backend's own shape.
    kv_cache_shape = AiterFlashAttentionBackend.get_kv_cache_shape(
        num_blocks=2, block_size=16, num_kv_heads=4, head_size=64
    )
    shared_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
    unique_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
    kv_caches = {
        "layer0": shared_tensor,
        "layer1": unique_tensor,
        "layer2": shared_tensor,
    }

    # Stand-in for the world / TP process groups.
    mock_group = MagicMock()
    mock_group.rank = TP_RANK
    mock_group.local_rank = TP_RANK
    mock_group.world_size = 1

    # NOTE(review): the patched TP world size is 0 while mock_group.world_size
    # is 1 — kept as in the original, but confirm 0 is intended.
    with (
        patch(
            f"{MORIIO_PKG}.moriio_common.get_tensor_model_parallel_rank",
            return_value=0,
        ),
        patch(
            f"{MORIIO_PKG}.moriio_common.get_tensor_model_parallel_world_size",
            return_value=0,
        ),
        patch(
            f"{MORIIO_PKG}.moriio_connector.get_tensor_model_parallel_world_size",
            return_value=0,
        ),
        patch(
            f"{MORIIO_PKG}.moriio_connector.get_world_group",
            return_value=mock_group,
        ),
        patch(
            f"{MORIIO_PKG}.moriio_connector.get_tp_group",
            return_value=mock_group,
        ),
        patch(
            f"{MORIIO_PKG}.moriio_engine.MoRIIOWrapper",
            FakeMorIIOWrapper,
        ),
    ):
        vllm_config.kv_transfer_config.kv_connector_extra_config.update(
            {
                "proxy_ip": "127.0.0.1",
                "proxy_ping_port": 12345,
                "http_port": 12346,
                "handshake_port": HANDSHAKE_PORT,
            }
        )

        connector = MoRIIOConnector(vllm_config, KVConnectorRole.WORKER)
        connector.register_kv_caches(kv_caches)

        path = make_zmq_path("tcp", "127.0.0.1", HANDSHAKE_PORT)
        with zmq_ctx(zmq.DEALER, path) as sock:
            sock.send(MoRIIOConstants.GET_META_MSG)
            # Bound the wait so a missing listener fails the test instead of
            # blocking the whole run on recv_multipart().
            if not sock.poll(timeout=REPLY_TIMEOUT_MS):
                pytest.fail(
                    f"No handshake reply within {REPLY_TIMEOUT_MS} ms on {path}"
                )
            received_frame = sock.recv_multipart()

            # Expect an empty delimiter frame followed by the payload.
            assert len(received_frame) == 2 and received_frame[0] == b"", (
                f"Unexpected frame! {received_frame = }"
            )

            metadata_bytes = received_frame[1]
            decoder = msgspec.msgpack.Decoder(MoRIIOAgentMetadata)
            metadata = decoder.decode(metadata_bytes)
            assert isinstance(metadata, MoRIIOAgentMetadata)
|
||||
|
||||
|
||||
@ -342,8 +342,8 @@ class MoRIIOConnectorScheduler:
|
||||
local_block_ids = blocks.get_block_ids()[0]
|
||||
self._reqs_need_save[request.request_id] = (request, local_block_ids)
|
||||
|
||||
if params is not None and params.get("do_remote_prefill"):
|
||||
if self.mode == MoRIIOMode.READ:
|
||||
if params is not None and params.get("do_remote_prefill"): #
|
||||
if self.mode == MoRIIOMode.READ: #read mode decode
|
||||
if remote_block_ids := params.get("remote_block_ids"):
|
||||
if all(
|
||||
p in params
|
||||
@ -373,7 +373,7 @@ class MoRIIOConnectorScheduler:
|
||||
)
|
||||
|
||||
else:
|
||||
assert request.kv_transfer_params is not None, (
|
||||
assert request.kv_transfer_params is not None, ( #write mode decode
|
||||
"kv_transfer_params should not be None"
|
||||
)
|
||||
|
||||
@ -890,7 +890,7 @@ class MoRIIOConnectorWorker:
|
||||
layer_name_to_local_kv_cache_metadata: dict,
|
||||
):
|
||||
"""Background thread for getting new MoRIIO handshakes."""
|
||||
|
||||
logger.info("tmp")
|
||||
encoder = msgspec.msgpack.Encoder()
|
||||
encoded_data = encoder.encode(metadata)
|
||||
size_in_bytes = len(encoded_data)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user