diff --git a/docker/Dockerfile.rocm_base b/docker/Dockerfile.rocm_base index c5e94ee1f6928..c820761b6b215 100644 --- a/docker/Dockerfile.rocm_base +++ b/docker/Dockerfile.rocm_base @@ -11,6 +11,8 @@ ARG FA_BRANCH="0e60e394" ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git" ARG AITER_BRANCH="6af8b687" ARG AITER_REPO="https://github.com/ROCm/aiter.git" +ARG MORI_BRANCH="2d02c6a9" +ARG MORI_REPO="https://github.com/ROCm/mori.git" FROM ${BASE_IMAGE} AS base @@ -20,6 +22,7 @@ ENV LD_LIBRARY_PATH=/opt/rocm/lib:/usr/local/lib: ARG PYTORCH_ROCM_ARCH=gfx90a;gfx942;gfx950;gfx1100;gfx1101;gfx1200;gfx1201;gfx1150;gfx1151 ENV PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH} ENV AITER_ROCM_ARCH=gfx942;gfx950 +ENV MORI_GPU_ARCHS=gfx942;gfx950 # Required for RCCL in ROCm7.1 ENV HSA_NO_SCRATCH_RECLAIM=1 @@ -33,7 +36,7 @@ ENV DEBIAN_FRONTEND=noninteractive # Install Python and other dependencies RUN apt-get update -y \ - && apt-get install -y software-properties-common git curl sudo vim less libgfortran5 \ + && apt-get install -y software-properties-common git curl sudo vim less libgfortran5 libopenmpi-dev libpci-dev \ && for i in 1 2 3; do \ add-apt-repository -y ppa:deadsnakes/ppa && break || \ { echo "Attempt $i failed, retrying in 5s..."; sleep 5; }; \ @@ -67,6 +70,18 @@ RUN cd /opt/rocm/share/amd_smi \ && pip wheel . --wheel-dir=dist RUN mkdir -p /app/install && cp /opt/rocm/share/amd_smi/dist/*.whl /app/install +FROM base AS build_mori +ARG MORI_BRANCH +ARG MORI_REPO +RUN --mount=type=bind,from=build_pytorch,src=/app/install/,target=/install \ + pip install /install/*.whl +RUN git clone ${MORI_REPO} +RUN cd mori \ + && git checkout ${MORI_BRANCH} \ + && git submodule update --init --recursive \ + && python3 setup.py bdist_wheel --dist-dir=dist && ls /app/mori/dist/*.whl +RUN mkdir -p /app/install && cp /app/mori/dist/*.whl /app/install + FROM base AS build_pytorch ARG PYTORCH_BRANCH ARG PYTORCH_VISION_BRANCH @@ -132,6 +147,8 @@ RUN --mount=type=bind,from=build_pytorch,src=/app/install/,target=/install \ cp /install/*.whl /app/debs RUN --mount=type=bind,from=build_aiter,src=/app/install/,target=/install \ cp /install/*.whl /app/debs +RUN --mount=type=bind,from=build_mori,src=/app/install/,target=/install \ + cp /install/*.whl /app/debs FROM base AS final RUN --mount=type=bind,from=debs,src=/app/debs,target=/install \ diff --git a/docs/getting_started/installation/gpu.rocm.inc.md b/docs/getting_started/installation/gpu.rocm.inc.md index 21120cc6fcd98..f2b8a7a811f9b 100644 --- a/docs/getting_started/installation/gpu.rocm.inc.md +++ b/docs/getting_started/installation/gpu.rocm.inc.md @@ -98,9 +98,24 @@ Currently, there are no pre-built ROCm wheels. !!! note - You will need to config the `$AITER_BRANCH_OR_COMMIT` for your purpose. - The validated `$AITER_BRANCH_OR_COMMIT` can be found in the [docker/Dockerfile.rocm_base](https://github.com/vllm-project/vllm/blob/main/docker/Dockerfile.rocm_base). - -4. Build vLLM. For example, vLLM on ROCM 7.0 can be built with the following steps: + +4. If you want to use MORI for EP or PD disaggregation, you can install [MORI](https://github.com/ROCm/mori) using the following steps: + + ```bash + git clone https://github.com/ROCm/mori.git + cd mori + git checkout $MORI_BRANCH_OR_COMMIT + git submodule sync; git submodule update --init --recursive + MORI_GPU_ARCHS="gfx942;gfx950" python3 install . + ``` + + !!! note + - You will need to config the `$MORI_BRANCH_OR_COMMIT` for your purpose. + - The validated `$MORI_BRANCH_OR_COMMIT` can be found in the [docker/Dockerfile.rocm_base](https://github.com/vllm-project/vllm/blob/main/docker/Dockerfile.rocm_base). + + +5. Build vLLM. For example, vLLM on ROCM 7.0 can be built with the following steps: ???+ console "Commands" diff --git a/examples/online_serving/disaggregated_serving/moriio_toy_proxy_server.py b/examples/online_serving/disaggregated_serving/moriio_toy_proxy_server.py new file mode 100644 index 0000000000000..a9feb82267f04 --- /dev/null +++ b/examples/online_serving/disaggregated_serving/moriio_toy_proxy_server.py @@ -0,0 +1,320 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import asyncio +import copy +import logging +import os +import re +import socket +import threading +import uuid + +import aiohttp +import msgpack +import zmq +from quart import Quart, make_response, request + +logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) +prefill_instances: list[dict] = [] +decode_instances: list[dict] = [] +request_nums = 0 +app = Quart(__name__) + +IP_PORT_PATTERN = re.compile(r"//(\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}):(\d+)") + + +TRANSFER_TYPE = None + + +def _append_whole_dict_unique(target_list, data_dict): + new_filtered = {k: v for k, v in data_dict.items() if k != "index"} + for existed in target_list: + existed_filtered = {k: v for k, v in existed.items() if k != "index"} + if existed_filtered == new_filtered: + return False + print("!!APPEND!!", data_dict) + target_list.append(data_dict) + transfer_mode = data_dict.get("transfer_mode", "unknown") + global TRANSFER_TYPE + + if TRANSFER_TYPE is None: + TRANSFER_TYPE = transfer_mode + logger.info("SET TRANSFER TYPE TO %s", TRANSFER_TYPE) + elif transfer_mode != TRANSFER_TYPE: + raise ValueError(f"mismatched transfer mode {TRANSFER_TYPE} vs {transfer_mode}") + + return True + + +_list_lock = threading.RLock() + + +def _listen_for_register(hostname, port): + context = zmq.Context() + router_socket = context.socket(zmq.ROUTER) + router_socket.bind(f"tcp://{hostname}:{port}") + poller = zmq.Poller() + poller.register(router_socket, zmq.POLLIN) + global prefill_instances + global decode_instances + + while True: + socks = dict(poller.poll()) + if router_socket in socks: + remote_addr, msg = router_socket.recv_multipart() + data = msgpack.loads(msg) + if data["type"] == "HELLO": + pass + elif ( + data["type"] == "register" + and data["role"] == "P" + and data["request_address"] not in prefill_instances + ): + with _list_lock: + _append_whole_dict_unique(prefill_instances, data) + + elif ( + data["type"] == "register" + and data["role"] == "D" + and data["request_address"] not in decode_instances + ): + with _list_lock: + _append_whole_dict_unique(decode_instances, data) + + +def start_service_discovery(hostname, port): + if not hostname: + hostname = socket.gethostname() + if port == 0: + raise ValueError("Port cannot be 0") + + _listener_thread = threading.Thread( + target=_listen_for_register, args=(hostname, port), daemon=True + ) + _listener_thread.start() + return _listener_thread + + +async def send_request_to_prefill( + endpoint, req_data, request_id, p_endpoint, pip, pports, selected_prefill_dp_rank +): + req_data_copy = req_data + + req_data_copy["kv_transfer_params"].update( + { + "do_remote_decode": True, + "do_remote_prefill": False, + "remote_handshake_port": p_endpoint["handshake_port"], + "remote_notify_port": p_endpoint["notify_port"], + "remote_engine_id": None, + "remote_block_ids": None, + "remote_host": pip, + "remote_port": pports, + } + ) + req_data_copy["stream"] = False + req_data_copy["max_tokens"] = 1 + if "max_completion_tokens" in req_data_copy: + req_data_copy["max_completion_tokens"] = 1 + if "stream_options" in req_data_copy: + del req_data_copy["stream_options"] + async with aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout(total=6 * 6000 * 6000) + ) as session: + headers = { + "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", + "X-Request-Id": request_id, + } + if selected_prefill_dp_rank is not None: + headers["X-data-parallel-rank"] = str(selected_prefill_dp_rank) + async with session.post( + url=endpoint, json=req_data_copy, headers=headers + ) as response: + if response.status == 200: + return await response.json() + + else: + raise RuntimeError( + "send_request_to_prefill response.status != 200response.status = ", + response.status, + ) + + +async def start_decode_request(endpoint, req_data, request_id): + session = aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout(total=6 * 6000 * 6000) + ) + headers = { + "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", + "X-Request-Id": request_id, + } + response = await session.post(url=endpoint, json=req_data, headers=headers) + return session, response + + +async def stream_decode_response(session, response, request_id): + try: + if response.status == 200: + async for chunk_bytes in response.content.iter_chunked(1024): + yield chunk_bytes + else: + raise RuntimeError( + f"decode response.status != 200, status = {response.status}" + ) + finally: + await session.close() + + +async def send_request_to_decode(endpoint, req_data, request_id): + async with aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout(total=6 * 6000 * 6000) + ) as session: + headers = { + "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", + "X-Request-Id": request_id, + } + async with session.post( + url=endpoint, json=req_data, headers=headers + ) as response: + if response.status == 200: + async for chunk_bytes in response.content.iter_chunked(1024): + yield chunk_bytes + else: + raise RuntimeError( + "send_request_to_decode response.status != 200,response.statuus = ", + response.status, + ) + + +def example_round_robin_dp_loader(request_number, dp_size): + return request_nums % dp_size + + +@app.route("/v1/completions", methods=["POST"]) +@app.route("/v1/chat/completions", methods=["POST"]) +async def handle_request(): + try: + with _list_lock: + global request_nums + request_nums += 1 + + def extract_ip_port_fast(url): + match = IP_PORT_PATTERN.search(url) + if not match: + raise ValueError(f"Invalid URL format: {url}") + return match.groups() + + req_data = await request.get_json() + request_id = str(uuid.uuid4()) + + prefill_instance_endpoint = None + decode_instance_endpoint = None + error_msg = ( + "Service Unavailable: No prefill or decode instances are registered." + ) + if not prefill_instances or not decode_instances: + return await make_response( + ( + error_msg, + 503, + ) + ) + pid = request_nums % len(prefill_instances) + did = request_nums % len(decode_instances) + prefill_instance_endpoint = prefill_instances[pid] + decode_instance_endpoint = decode_instances[did] + + selected_prefill_dp_rank = None + if prefill_instance_endpoint["dp_size"] > 1: + selected_prefill_dp_rank = example_round_robin_dp_loader( + request_nums // len(prefill_instance_endpoint), + prefill_instance_endpoint["dp_size"], + ) + + dip, dport = extract_ip_port_fast(decode_instance_endpoint["request_address"]) + ip, port = extract_ip_port_fast(prefill_instance_endpoint["request_address"]) + + req_data_to_prefill = copy.deepcopy(req_data) + req_data_to_prefill["kv_transfer_params"] = {} + req_data["kv_transfer_params"] = {} + req_data_to_prefill["kv_transfer_params"]["remote_dp_size"] = ( + decode_instance_endpoint["dp_size"] + ) + req_data_to_prefill["kv_transfer_params"]["remote_tp_size"] = ( + decode_instance_endpoint["tp_size"] + ) + + send_prefill_task = asyncio.create_task( + send_request_to_prefill( + prefill_instance_endpoint["request_address"], + req_data_to_prefill, + request_id, + decode_instance_endpoint, + dip, + dport, + selected_prefill_dp_rank, + ) + ) + ip, port = extract_ip_port_fast(prefill_instance_endpoint["request_address"]) + + req_data["max_tokens"] -= 1 + + req_data["kv_transfer_params"] = { + "do_remote_decode": False, + "do_remote_prefill": True, + "remote_handshake_port": prefill_instance_endpoint["handshake_port"], + "remote_notify_port": prefill_instance_endpoint["notify_port"], + "remote_engine_id": None, + "remote_block_ids": None, + "remote_host": ip, + "remote_port": port, + } + if TRANSFER_TYPE == "READ": + # In read mode, prefill and decode are executed serially. + prefill_response = await send_prefill_task + req_data["kv_transfer_params"]["remote_engine_id"] = prefill_response[ + "kv_transfer_params" + ]["remote_engine_id"] + req_data["kv_transfer_params"]["remote_block_ids"] = prefill_response[ + "kv_transfer_params" + ]["remote_block_ids"] + + req_data["kv_transfer_params"]["remote_dp_size"] = prefill_instance_endpoint[ + "dp_size" + ] + req_data["kv_transfer_params"]["remote_tp_size"] = prefill_instance_endpoint[ + "tp_size" + ] + + if selected_prefill_dp_rank is not None: + req_data["kv_transfer_params"]["remote_dp_rank"] = selected_prefill_dp_rank + + decode_request_task = asyncio.create_task( + start_decode_request( + decode_instance_endpoint["request_address"], req_data, request_id + ) + ) + + session, decode_response = await decode_request_task + stream_generator = stream_decode_response(session, decode_response, request_id) + response = await make_response(stream_generator) + return response + except Exception as e: + logger.exception("An error occurred while handling the request: %s", e) + return await make_response( + ( + f"Internal Server Error: {e!s}", + 500, + ) + ) + + +if __name__ == "__main__": + t = start_service_discovery("0.0.0.0", 36367) + app.debug = True + app.config["BODY_TIMEOUT"] = 360000 + app.config["RESPONSE_TIMEOUT"] = 360000 + + app.run(host="0.0.0.0", port=10001) + t.join() diff --git a/tests/v1/kv_connector/unit/test_moriio_connector.py b/tests/v1/kv_connector/unit/test_moriio_connector.py new file mode 100644 index 0000000000000..1cc6988635d8d --- /dev/null +++ b/tests/v1/kv_connector/unit/test_moriio_connector.py @@ -0,0 +1,545 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import importlib.util +import os +from unittest.mock import MagicMock, patch + +import msgspec +import pytest +import torch +import zmq + +from tests.conftest import _find_free_port +from vllm.config import ( + CacheConfig, + DeviceConfig, + KVTransferConfig, + ModelConfig, + SchedulerConfig, + VllmConfig, +) +from vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_common import ( + MoRIIOAgentMetadata, + MoRIIOConnectorMetadata, + MoRIIOConstants, + zmq_ctx, +) +from vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector import ( + KVConnectorRole, + MoRIIOConnector, + MoRIIOConnectorWorker, +) +from vllm.platforms import current_platform +from vllm.utils.network_utils import ( + get_ip, + make_zmq_path, +) + +from .utils import create_request, create_scheduler + +aiter_available = importlib.util.find_spec("aiter") is not None +mori_available = importlib.util.find_spec("mori") is not None +pytestmark = pytest.mark.skipif( + not (current_platform.is_rocm() and mori_available), + reason="MoRIIOs are only available on ROCm with aiter package installed", +) + + +@pytest.fixture +def mock_parallel_groups(): + """Mock tensor/data parallel group functions for single-rank tests.""" + mock_group = MagicMock() + mock_group.rank = 0 + mock_group.local_rank = 0 + mock_group.world_size = 1 + + with ( + patch.multiple( + "vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_common", + get_tensor_model_parallel_rank=MagicMock(return_value=0), + get_tensor_model_parallel_world_size=MagicMock(return_value=0), + ), + patch.multiple( + "vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector", + get_tensor_model_parallel_world_size=MagicMock(return_value=0), + get_world_group=MagicMock(return_value=mock_group), + get_tp_group=MagicMock(return_value=mock_group), + ), + ): + yield mock_group + + +def _setup_kv_transfer_request(request, remote_host="127.0.0.1", fake_port=4789): + """Setup KV transfer parameters for a request.""" + request.kv_transfer_params.update( + { + "remote_notify_port": fake_port, + "remote_block_ids": None, + "remote_host": remote_host, + "remote_port": fake_port, + "remote_handshake_port": fake_port, + "remote_engine_id": "test_engine", + } + ) + return request + + +class FakeMorIIOWrapper: + # A fake MoRIIOWrapper for testing purposes + def __init__(self, *args, **kwargs): + pass + + def set_moriio_engine(self, moriio_engine): + pass + + def set_backend_type(self, backend_type): + pass + + def get_agent_metadata(self): + pass + + def register_remote_engine(self, remote_packed_engine_metadata): + pass + + def register_local_tensor(self, tensor: torch.Tensor): + pass + + def get_unpack_memory_metadata(self, packed_memory_metadata): + pass + + def build_session(self, local_memory_metadata, remote_memory_metadata): + pass + + def read_remote_data( + self, transfer_size_byte, local_offset=0, remote_offset=0, session=None + ): + pass + + def write_remote_data( + self, transfer_size_byte, local_offset=0, remote_offset=0, session=None + ): + pass + + def write_remote_data_single( + self, transfer_size_byte, local_offset=0, remote_offset=0, sess_idx=0 + ): + pass + + def waiting_for_transfer_complete(self): + pass + + def async_wait_reqid(self): + pass + + def _handle_message(self, msg: bytes): + pass + + def _handle_structured_message(self, data: dict): + pass + + def _handle_completion_message(self, msg: str): + pass + + def send_notify(self, req_ids, remote_ip, remote_port): + pass + + def pop_finished_req_ids(self): + pass + + def pop_finished_write_req_ids(self): + pass + + def shutdown(self): + pass + + +class FakeMorIIOConnectorWorker(MoRIIOConnectorWorker): + # Define a fake remote engine id for testing + REMOTE_ENGINE_ID = "remote_engine" + + def __init__( + self, *args, hand_shake_latency: float = 1.8, kv_cache_layout="HND", **kwargs + ): + super().__init__(*args, **kwargs) + + +def create_vllm_config( + model: str = "facebook/opt-125m", + max_num_seqs: int = 16, + max_num_batched_tokens: int = 64, + block_size: int = 16, + max_model_len: int = 10000, + enable_chunked_prefill: bool = True, + enable_permute_local_kv: bool = False, + role="kv_consumer", +) -> VllmConfig: + """Initialize VllmConfig for testing.""" + scheduler_config = SchedulerConfig( + max_num_seqs=max_num_seqs, + max_num_batched_tokens=max_num_batched_tokens, + max_model_len=max_model_len, + enable_chunked_prefill=enable_chunked_prefill, + is_encoder_decoder=False, + ) + model_config = ModelConfig( + model=model, + trust_remote_code=True, + dtype="bfloat16", + seed=42, + ) + # Cache config, optionally force APC + cache_config = CacheConfig( + block_size=block_size, + gpu_memory_utilization=0.9, + swap_space=0, + cache_dtype="auto", + enable_prefix_caching=True, + ) + kv_transfer_config = KVTransferConfig( + kv_connector="MoRIIOConnector", + kv_role=role, + enable_permute_local_kv=enable_permute_local_kv, + ) + return VllmConfig( + scheduler_config=scheduler_config, + model_config=model_config, + cache_config=cache_config, + kv_transfer_config=kv_transfer_config, + device_config=DeviceConfig("cpu"), + ) + + +@pytest.fixture +def moriio_read_mode(): + """Force the connector into read mode via env for tests.""" + os.environ["VLLM_MORIIO_CONNECTOR_READ_MODE"] = "True" + yield + # Cleanup after test + os.environ.pop("VLLM_MORIIO_CONNECTOR_READ_MODE", None) + + +def test_write_mode_saves_local_block_ids(): + """Write mode records local block ids in MoRIIOConnectorMetadata.reqs_to_save.""" + + # Setup Scheduler and Request + vllm_config = create_vllm_config(role="kv_producer") + scheduler = create_scheduler(vllm_config) + + # 2 Full Blocks and 1 Half Block. + BLOCK_SIZE = vllm_config.cache_config.block_size + NUM_EXTERNAL_FULL_BLOCKS = 2 + NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5)) + + request = create_request( + request_id=1, + block_size=BLOCK_SIZE, + num_tokens=NUM_TOKENS, + do_remote_decode=True, + do_remote_prefill=False, + ) + request_id = request.request_id + + scheduler.add_request(request) + + # Fake Config + request = _setup_kv_transfer_request(request) + + # Remote Prefill, triggers MoRIIOConnectorMetadata. + scheduler_output = scheduler.schedule() + kv_connector_metadata = scheduler_output.kv_connector_metadata + assert kv_connector_metadata is not None, "kv_connector_metadata is None" + assert isinstance(kv_connector_metadata, MoRIIOConnectorMetadata) + + assert len(kv_connector_metadata.reqs_to_save) == 1, ( + "Unexpected number of reqs_to_save" + ) + assert len(kv_connector_metadata.reqs_to_recv) == 0, ( + "Unexpected number of reqs_to_recv" + ) + assert len(kv_connector_metadata.reqs_to_send) == 0, ( + "Unexpected number of reqs_to_send" + ) + assert request_id in kv_connector_metadata.reqs_to_save, ( + "Request ID not in reqs_to_save" + ) + req_meta = kv_connector_metadata.reqs_to_save[request_id] + + for block_id, block in zip( + req_meta.local_block_ids, + scheduler.kv_cache_manager.coordinator.single_type_managers[0].req_to_blocks[ + request_id + ], + ): + assert block_id == block.block_id, f"{block_id} != {block.block_id}" + + +def test_write_mode_with_chunked_prefill_saves_local_block_ids(): + """Write mode with chunked prefill still records correct local block ids.""" + # Setup Scheduler and Request + MAX_NUM_BATCHED_TOKENS = 64 + NUM_TOKENS = MAX_NUM_BATCHED_TOKENS * 2 + MAX_NUM_BATCHED_TOKENS // 2 + + vllm_config = create_vllm_config( + max_num_batched_tokens=MAX_NUM_BATCHED_TOKENS, role="kv_producer" + ) + BLOCK_SIZE = vllm_config.cache_config.block_size + + scheduler = create_scheduler(vllm_config) + + # 2 Full Blocks and 1 Half Block. + + request = create_request( + request_id=1, + block_size=BLOCK_SIZE, + num_tokens=NUM_TOKENS, + do_remote_decode=True, + do_remote_prefill=False, + ) + request_id = request.request_id + + scheduler.add_request(request) + + # Fake Config + request = _setup_kv_transfer_request(request) + + # Remote Prefill with chunked prefill, triggers multiple schedules. + expected_counts = [(0, 0, 0), (0, 0, 0), (1, 0, 0)] + kv_connector_metadata = None + for _, (expected_save, expected_recv, expected_send) in enumerate(expected_counts): + scheduler_output = scheduler.schedule() + kv_connector_metadata = scheduler_output.kv_connector_metadata + + assert len(kv_connector_metadata.reqs_to_save) == expected_save + assert len(kv_connector_metadata.reqs_to_recv) == expected_recv + assert len(kv_connector_metadata.reqs_to_send) == expected_send + assert kv_connector_metadata is not None, "kv_connector_metadata is None" + assert request_id in kv_connector_metadata.reqs_to_save, ( + "Request ID not in reqs_to_save" + ) + req_meta = kv_connector_metadata.reqs_to_save[request_id] + + for block_id, block in zip( + req_meta.local_block_ids, + scheduler.kv_cache_manager.coordinator.single_type_managers[0].req_to_blocks[ + request_id + ], + ): + assert block_id == block.block_id, f"{block_id} != {block.block_id}" + + +def test_read_mode_loads_remote_block_ids(moriio_read_mode): + """Read mode loads remote block ids into local cache mapping.""" + + # Setup Scheduler and Request + vllm_config = create_vllm_config(role="kv_consumer") + scheduler = create_scheduler(vllm_config) + + # 2 Full Blocks and 1 Half Block. + BLOCK_SIZE = vllm_config.cache_config.block_size + NUM_EXTERNAL_FULL_BLOCKS = 2 + NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5)) + + request = create_request( + request_id=1, + block_size=BLOCK_SIZE, + num_tokens=NUM_TOKENS, + do_remote_decode=False, + do_remote_prefill=True, + ) + request_id = request.request_id + + scheduler.add_request(request) + block_list = scheduler.kv_cache_manager.coordinator.single_type_managers[ + 0 + ].req_to_blocks[request_id] + + request = _setup_kv_transfer_request(request) + + # Set remote block ids to be fetched. + request.kv_transfer_params["remote_block_ids"] = block_list + + # Remote Prefill, triggers MorIIOConnectorMetadata. + + scheduler_output = scheduler.schedule() + kv_connector_metadata = scheduler_output.kv_connector_metadata + assert kv_connector_metadata is not None, "kv_connector_metadata is None" + assert isinstance(kv_connector_metadata, MoRIIOConnectorMetadata), ( + "kv_connector_metadata is not MoRIIOConnectorMetadata" + ) + assert len(kv_connector_metadata.reqs_to_save) == 0, ( + "Unexpected number of reqs_to_save" + ) + assert len(kv_connector_metadata.reqs_to_recv) == 1, ( + "Unexpected number of reqs_to_recv" + ) + assert len(kv_connector_metadata.reqs_to_send) == 0, ( + "Unexpected number of reqs_to_send" + ) + assert request_id in kv_connector_metadata.reqs_to_recv, ( + "Request ID not in reqs_to_recv" + ) + req_meta = kv_connector_metadata.reqs_to_recv[request_id] + + for block_id, block in zip( + req_meta.local_block_ids, + scheduler.kv_cache_manager.coordinator.single_type_managers[0].req_to_blocks[ + request_id + ], + ): + assert block_id == block.block_id, f"{block_id} != {block.block_id}" + + +@pytest.mark.skipif( + not aiter_available, reason="Requires aiter package for ROCm FlashAttention backend" +) +def test_register_kv_caches(mock_parallel_groups): + """Test that MoRIIOConnector.register_kv_caches correctly registers kv caches.""" + ROLE = "kv_consumer" + IP = get_ip() + vllm_config = create_vllm_config(role=ROLE) + DEFAULT_PORT = 6301 + TP_RANK = 0 + DP_RANK = 0 + from vllm.v1.attention.backends.rocm_aiter_fa import AiterFlashAttentionBackend + + backend_cls = AiterFlashAttentionBackend + + # Create test kv cache tensors using proper backend shape + kv_cache_shape = backend_cls.get_kv_cache_shape( + num_blocks=2, block_size=16, num_kv_heads=4, head_size=64 + ) + shared_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16) + unique_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16) + kv_caches = { + "layer0": shared_tensor, + "layer1": unique_tensor, + "layer2": shared_tensor, + } + + with ( + patch( + "vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector.threading.Event" + ), + patch( + "vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector.threading.Thread" + ), + ): + # Create connector + vllm_config.kv_transfer_config.kv_connector_extra_config.update( + { + "proxy_ip": "127.0.0.1", + "proxy_ping_port": 12345, + "http_port": 12346, + } + ) + + connector = MoRIIOConnector(vllm_config, KVConnectorRole.WORKER) + connector.connector_worker = FakeMorIIOConnectorWorker( + vllm_config, connector.engine_id, hand_shake_latency=0 + ) + + from mori.io import ( + MemoryDesc, + ) + + # Execute register_kv_caches + connector.register_kv_caches(kv_caches) + + # Verify that the MemoryDesc stored in layer_name_to_local_kv_cache_metadata + assert ( + shared_tensor.data_ptr() + == MemoryDesc.unpack( + connector.connector_worker.layer_name_to_local_kv_cache_metadata[ + "layer0" + ][0] + ).data + ) + assert ( + unique_tensor.data_ptr() + == MemoryDesc.unpack( + connector.connector_worker.layer_name_to_local_kv_cache_metadata[ + "layer1" + ][0] + ).data + ) + assert ( + shared_tensor.data_ptr() + == MemoryDesc.unpack( + connector.connector_worker.layer_name_to_local_kv_cache_metadata[ + "layer2" + ][0] + ).data + ) + + # Verify engine keys + expected_engine_key = f"{ROLE[3:]}:{IP}:{DEFAULT_PORT}:tp{TP_RANK}:dp{DP_RANK}" + assert ( + MemoryDesc.unpack( + connector.connector_worker.layer_name_to_local_kv_cache_metadata[ + "layer0" + ][0] + ).engine_key + == expected_engine_key + ) + + +@pytest.mark.skipif( + not aiter_available, reason="Requires aiter package for ROCm FlashAttention backend" +) +def test_moriio_handshake_returns_metadata(mock_parallel_groups): + """MoRIIO handshake socket returns valid agent metadata over ZMQ.""" + + ROLE = "kv_consumer" + vllm_config = create_vllm_config(role=ROLE) + from vllm.v1.attention.backends.rocm_aiter_fa import AiterFlashAttentionBackend + + backend_cls = AiterFlashAttentionBackend + + # Create test kv cache tensors using proper backend shape + kv_cache_shape = backend_cls.get_kv_cache_shape( + num_blocks=2, block_size=16, num_kv_heads=4, head_size=64 + ) + shared_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16) + unique_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16) + kv_caches = { + "layer0": shared_tensor, + "layer1": unique_tensor, + "layer2": shared_tensor, + } + + with ( + patch( + "vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_engine.MoRIIOWrapper", + FakeMorIIOWrapper, + ), + ): + handshake_port = _find_free_port() + # Create connector + vllm_config.kv_transfer_config.kv_connector_extra_config.update( + { + "proxy_ip": "127.0.0.1", + "proxy_ping_port": 12345, + "http_port": 12346, + "handshake_port": handshake_port, + } + ) + connector = MoRIIOConnector(vllm_config, KVConnectorRole.WORKER) + + # Execute register_kv_caches + connector.register_kv_caches(kv_caches) + + # Connect to handshake socket and request metadata + path = make_zmq_path("tcp", "127.0.0.1", handshake_port) + with zmq_ctx(zmq.DEALER, path) as sock: + sock.send(MoRIIOConstants.GET_META_MSG) + received_frame = sock.recv_multipart() + + if len(received_frame) != 2 or received_frame[0] != b"": + raise ValueError(f"Unexpected frame! {received_frame = }") + + metadata_bytes = received_frame[1] + decoder = msgspec.msgpack.Decoder(MoRIIOAgentMetadata) + metadata = decoder.decode(metadata_bytes) + assert isinstance(metadata, MoRIIOAgentMetadata), ( + "Decoded metadata is not MoRIIOAgentMetadata" + ) diff --git a/vllm/distributed/kv_transfer/kv_connector/factory.py b/vllm/distributed/kv_transfer/kv_connector/factory.py index 02d9a1ec9599e..f4113c91b60f6 100644 --- a/vllm/distributed/kv_transfer/kv_connector/factory.py +++ b/vllm/distributed/kv_transfer/kv_connector/factory.py @@ -179,6 +179,12 @@ KVConnectorFactory.register_connector( "MultiConnector", ) +KVConnectorFactory.register_connector( + "MoRIIOConnector", + "vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector", + "MoRIIOConnector", +) + KVConnectorFactory.register_connector( "OffloadingConnector", "vllm.distributed.kv_transfer.kv_connector.v1.offloading_connector", diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/moriio/__init__.py b/vllm/distributed/kv_transfer/kv_connector/v1/moriio/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_common.py b/vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_common.py new file mode 100644 index 0000000000000..026e7faf57616 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_common.py @@ -0,0 +1,321 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import contextlib +import threading +import time +from collections.abc import Iterator +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Optional + +import msgspec +import torch +import zmq + +from vllm import envs +from vllm.config import VllmConfig +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorMetadata, +) +from vllm.distributed.parallel_state import ( + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) +from vllm.logger import init_logger +from vllm.utils.network_utils import ( + get_ip, + get_open_port, + make_zmq_socket, +) + +if TYPE_CHECKING: + pass + +from dataclasses import field +from enum import Enum + +logger = init_logger(__name__) + + +Transfer = tuple[int, float] +EngineId = str +ReqId = str + + +@dataclass +class WriteTask: + request_id: str + dst_engine_id: str + local_block_ids: list[int] + remote_block_ids_hint: list[int] | None + layer_name: str + event: torch.cuda.Event + remote_notify_port: int + remote_ip: str + enqueue_time: float = field(default_factory=time.perf_counter) + retried: int = 0 + + +@dataclass +class LayerTransferPlan: + """Plan for transferring a single layer.""" + + request_id: str + layer_name: str + sess_idx: int + transfer_local_offsets: list[int] + transfer_remote_offsets: list[int] + transfer_sizes: list[int] + use_batch: bool = True + + +@dataclass +class RemoteAllocInfo: + """Information about remote block allocation.""" + + block_ids: list[int] + writes_done: int = 0 + decode_dp_rank: int = 0 + transfer_offset: tuple[list[int], list[int], list[int]] | None = None + + +class ROLE(Enum): + PRODUCER = "producer" + CONSUMER = "consumer" + NOTINIT = "notinit" + + +class MoRIIOAgentMetadata( + msgspec.Struct, + omit_defaults=True, # type: ignore[call-arg] + # required for @cached_property.d + dict=True, +): + engine_id: str + agent_metadata: bytes + kv_caches_base_addr: list[int] + num_blocks: int + block_len: int + attn_backend_name: str + + +class RoleManager: + """Manages role state across the connector.""" + + _instance: Optional["RoleManager"] = None + _lock = threading.Lock() + + def __init__(self) -> None: + self._role: ROLE = ROLE.NOTINIT + + @classmethod + def get_instance(cls) -> "RoleManager": + if cls._instance is None: + with cls._lock: + if cls._instance is None: + cls._instance = cls() + return cls._instance + + def set_role(self, role: ROLE) -> None: + """Set the current role.""" + with self._lock: + self._role = role + + def get_role(self) -> ROLE: + """Get the current role.""" + return self._role + + +def set_role(role: ROLE): + """Set the global role.""" + RoleManager.get_instance().set_role(role) + + +def get_role() -> ROLE: + """Get the global role.""" + return RoleManager.get_instance().get_role() + + +class MoRIIOMode(Enum): + READ = "read" + WRITE = "write" + + +class MoRIIOError(Exception): + """Base exception for MoRIIO operations.""" + + pass + + +class HandshakeError(MoRIIOError): + """Exception raised when handshake fails.""" + + pass + + +class TransferError(MoRIIOError): + """Exception raised when transfer fails.""" + + pass + + +def get_moriio_mode() -> MoRIIOMode: + read_mode = envs.VLLM_MORIIO_CONNECTOR_READ_MODE + logger.debug("MoRIIO Connector read_mode: %s", read_mode) + if read_mode: + return MoRIIOMode.READ + else: + return MoRIIOMode.WRITE + + +def get_port_offset(dp_rank: int, tp_rank: int, tp_size: int = 1) -> int: + return (dp_rank) * tp_size + tp_rank + + +@dataclass +class MoRIIOConfig: + local_ip: str + local_kv_port: int + proxy_ip: str + local_ping_port: int + proxy_ping_port: int + http_port: int + handshake_port: int + notify_port: int + tp_rank: int + dp_rank: int + dp_size: int + tp_size: int + + @classmethod + def from_vllm_config(cls, vllm_config: VllmConfig) -> "MoRIIOConfig": + # Port Configuration: + # local_ping_port -> Outgoing heartbeat to proxy + # proxy_ping_port -> Remote proxy's heartbeat ingress port + # http_port -> Instance's HTTP service endpoint + # local_kv_port -> service port for mori engine + # notify_port -> For synchronizing stages between prefill and decode + # handshake_port -> For initial handshake between mori engine + + # TODO : merge notify_port and handshake_port to simplify port management + # supports non-contiguous ports + assert vllm_config.kv_transfer_config is not None, ( + "kv_transfer_config must be set for MoRIIOConnector" + ) + kv_transfer_config = vllm_config.kv_transfer_config + extra_config = kv_transfer_config.kv_connector_extra_config + tp_rank = get_tensor_model_parallel_rank() + dp_rank = vllm_config.parallel_config.data_parallel_rank + base_notify_port = int(extra_config["notify_port"]) + dp_size = vllm_config.parallel_config.data_parallel_size + tp_size = get_tensor_model_parallel_world_size() + port_offset = get_port_offset(dp_rank, tp_rank) + + return cls( + local_ip=get_ip(), + local_kv_port=get_open_port(), + proxy_ip=extra_config["proxy_ip"], + local_ping_port=get_open_port(), + proxy_ping_port=int(extra_config["proxy_ping_port"]), + http_port=int(extra_config["http_port"]), + handshake_port=int(extra_config["handshake_port"]), + notify_port=base_notify_port + port_offset, + tp_rank=tp_rank, + dp_rank=dp_rank, + dp_size=dp_size, + tp_size=tp_size, + ) + + +class MoRIIOConstants: + """Constants for MoRIIO connector.""" + + # ZMQ message types + GET_META_MSG = b"get_meta_msg" + POP_DONE_RECV = b"pop_done_recv" + OVER = b"OVER" + COMPLETION_PREFIX = "cmpl" + + PING_INTERVAL = 5 + MAX_PING_RETRIES = 100 + DEFAULT_HANDSHAKE_PORT = "6301" + DEFAULT_NOTIFY_PORT = "61005" + + VLLM_MORI_READ_ABORT_REQUEST_TIMEOUT = 3600 + + +@dataclass +class ReqMeta: + """Metadata for a single request.""" + + local_block_ids: list[int] + remote_block_ids: list[int] + remote_host: str + remote_port: int + remote_handshake_port: int + remote_notify_port: int + remote_engine_id: str + tp_size: int + remote_dp_size: int + + +class MoRIIOConnectorMetadata(KVConnectorMetadata): + def __init__(self): + self.reqs_to_recv: dict[ReqId, ReqMeta] = {} + self.reqs_to_save: dict[ReqId, ReqMeta] = {} + self.reqs_to_send: dict[ReqId, float] = {} + + def __repr__(self): + return_str = "" + for req_id, req_meta in self.reqs_to_recv.items(): + return_str += ( + f"{req_id = },{req_meta.local_block_ids = }," + f"{req_meta.remote_host = },{req_meta.remote_port = }" + f"{req_meta.remote_engine_id = },{req_meta.tp_size = }" + ) + return_str = f"MoRIIOConnectorMetadata:reqs_to_recv:{return_str}," + + for req_id, expiry in self.reqs_to_send.items(): + return_str += f"{req_id = },{expiry = }" + return_str = f"MoRIIOConnectorMetadata:reqs_to_send:{return_str}," + return return_str + + def add_new_req( + self, + request_id: ReqId, + local_block_ids: list[int], + kv_transfer_params: dict[str, Any], + write_mode=False, + ): + _req = ReqMeta( + local_block_ids=local_block_ids, + remote_block_ids=kv_transfer_params["remote_block_ids"], + remote_engine_id=kv_transfer_params["remote_engine_id"], + remote_host=kv_transfer_params["remote_host"], + remote_port=kv_transfer_params["remote_port"], + remote_handshake_port=kv_transfer_params["remote_handshake_port"], + remote_notify_port=kv_transfer_params["remote_notify_port"], + tp_size=kv_transfer_params.get("tp_size", 1), + remote_dp_size=kv_transfer_params.get("remote_dp_size", 1), + ) + if write_mode: + self.reqs_to_save[request_id] = _req + else: + self.reqs_to_recv[request_id] = _req + + +@contextlib.contextmanager +def zmq_ctx(socket_type: Any, addr: str) -> Iterator[zmq.Socket]: + """Context manager for a ZMQ socket""" + + if socket_type not in (zmq.ROUTER, zmq.REQ, zmq.DEALER): + raise ValueError(f"Unexpected socket type: {socket_type}") + + ctx: zmq.Context | None = None + try: + ctx = zmq.Context() # type: ignore[attr-defined] + yield make_zmq_socket( + ctx=ctx, path=addr, socket_type=socket_type, bind=socket_type == zmq.ROUTER + ) + finally: + if ctx is not None: + ctx.destroy(linger=0) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_connector.py new file mode 100644 index 0000000000000..4b6bd906d5d44 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_connector.py @@ -0,0 +1,1515 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import logging +import math +import queue +import threading +import time +from collections import defaultdict +from concurrent.futures import Future, ThreadPoolExecutor +from typing import TYPE_CHECKING, Any, Optional + +import msgpack +import msgspec +import numpy as np +import torch +import zmq + +from vllm.attention.selector import get_attn_backend +from vllm.config import VllmConfig +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorBase_V1, + KVConnectorMetadata, + KVConnectorRole, +) +from vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_common import ( + ROLE, + EngineId, + HandshakeError, + MoRIIOAgentMetadata, + MoRIIOConfig, + MoRIIOConnectorMetadata, + MoRIIOConstants, + MoRIIOMode, + ReqId, + ReqMeta, + WriteTask, + get_moriio_mode, + get_port_offset, + get_role, + set_role, + zmq_ctx, +) +from vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_engine import ( + MoRIIOWrapper, + MoRIIOWriter, +) +from vllm.distributed.parallel_state import ( + get_tensor_model_parallel_world_size, + get_tp_group, + get_world_group, +) +from vllm.forward_context import ForwardContext +from vllm.logger import init_logger +from vllm.utils.network_utils import ( + get_ip, + make_zmq_path, + make_zmq_socket, +) +from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.request import RequestStatus + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionMetadata + from vllm.v1.core.kv_cache_manager import KVCacheBlocks + from vllm.v1.kv_cache_interface import KVCacheConfig + from vllm.v1.request import Request + +logger = init_logger(__name__) + +try: + from mori.io import ( + BackendType, + IOEngine, + IOEngineConfig, + ) + + logger.info("MoRIIO is available") + MoRIIO_enabled = True +except ImportError: + logger.error("MoRIIO is not available") + MoRIIO_enabled = False + + +def is_moriio_available() -> bool: + return MoRIIO_enabled + + +class MoRIIOConnector(KVConnectorBase_V1): + def __init__( + self, + vllm_config: VllmConfig, + role: KVConnectorRole, + kv_cache_config: Optional["KVCacheConfig"] = None, + ): + super().__init__(vllm_config, role) + assert vllm_config.kv_transfer_config is not None, ( + "kv_transfer_config must be set for MoRIIOConnector" + ) + + self.kv_transfer_config = vllm_config.kv_transfer_config + self._set_port_defaults(vllm_config) + + self.engine_id = ( + str(get_ip()) + + ":" + + str(self.kv_transfer_config.kv_connector_extra_config["handshake_port"]) + ) + self.mode = get_moriio_mode() + if role == KVConnectorRole.SCHEDULER: + self.connector_scheduler: MoRIIOConnectorScheduler | None = ( + MoRIIOConnectorScheduler(vllm_config, self.engine_id) + ) + self.connector_worker: MoRIIOConnectorWorker | None = None + elif role == KVConnectorRole.WORKER: + self.connector_scheduler = None + self.connector_worker = MoRIIOConnectorWorker(vllm_config, self.engine_id) + logger.info( + "Initialized MoRIIO Connector,engine_id:%s,role: %s", + self.engine_id, + role.value, + ) + + ############################################################ + # Scheduler Side Methods + ############################################################ + + def _set_port_defaults(self, vllm_config: VllmConfig): + assert vllm_config.kv_transfer_config is not None, ( + "kv_transfer_config must be set for MoRIIOConnector" + ) + kv_transfer_config = vllm_config.kv_transfer_config + extra_config = kv_transfer_config.kv_connector_extra_config + + if "handshake_port" not in extra_config or not extra_config["handshake_port"]: + extra_config["handshake_port"] = MoRIIOConstants.DEFAULT_HANDSHAKE_PORT + + if "notify_port" not in extra_config or not extra_config["notify_port"]: + extra_config["notify_port"] = MoRIIOConstants.DEFAULT_NOTIFY_PORT + + def get_num_new_matched_tokens( + self, request: "Request", num_computed_tokens: int + ) -> tuple[int, bool]: + assert self.connector_scheduler is not None + return self.connector_scheduler.get_num_new_matched_tokens( + request, num_computed_tokens + ) + + def update_state_after_alloc( + self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int + ): + assert self.connector_scheduler is not None + return self.connector_scheduler.update_state_after_alloc( + request, blocks, num_external_tokens, self.connector_worker + ) + + def build_connector_meta( + self, + scheduler_output: SchedulerOutput, + ) -> KVConnectorMetadata: + assert self.connector_scheduler is not None + return self.connector_scheduler.build_connector_meta(scheduler_output) + + def request_finished( + self, + request: "Request", + block_ids: list[int], + ) -> tuple[bool, dict[str, Any] | None]: + assert self.connector_scheduler is not None + return self.connector_scheduler.request_finished(request, block_ids) + + ############################################################ + # Worker Side Methods + ############################################################ + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): + assert self.connector_worker is not None + self.connector_worker.register_kv_caches(kv_caches) + + def get_finished(self, finished_req_ids: set[str]) -> tuple[set[str], set[str]]: + """Get the finished recving and sending requests.""" + assert self.connector_worker is not None + return self.connector_worker.get_finished() + + def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None: + assert self.connector_worker is not None + if self.mode == MoRIIOMode.WRITE and get_role() == ROLE.CONSUMER: + self.connector_worker.moriio_wrapper.async_wait_reqid() + + assert isinstance(self._connector_metadata, MoRIIOConnectorMetadata) + self.connector_worker.start_load_kv(self._connector_metadata) + + def wait_for_layer_load(self, layer_name: str) -> None: + pass + + def save_kv_layer( + self, + layer_name: str, + kv_layer: torch.Tensor, + attn_metadata: "AttentionMetadata", + **kwargs, + ) -> None: + # Only producer/prefill saves KV Cache + if get_role() == ROLE.CONSUMER: + return + assert self.connector_worker is not None, ( + "save_kv_layer called on scheduler role" + ) + + assert isinstance(self._connector_metadata, MoRIIOConnectorMetadata), ( + "Connector metadata not initialized yet" + ) + self.connector_worker.save_kv_layer( + self._connector_metadata, layer_name, kv_layer, attn_metadata, **kwargs + ) + + return None + + def wait_for_save(self): + pass + + def shutdown(self): + if self.connector_worker is not None: + self.connector_worker.shutdown() + if self.connector_scheduler is not None: + self.connector_scheduler.shutdown() + + def has_connector_metadata(self) -> bool: + """Check whether the connector metadata is currently set. + + Returns: + bool: True if connector metadata exists, False otherwise. + """ + try: + return self._connector_metadata is not None + except AttributeError: + return False + + +class MoRIIOConnectorScheduler: + """Implementation of Scheduler side methods""" + + def __init__(self, vllm_config: VllmConfig, engine_id: str): + self.vllm_config = vllm_config + + assert vllm_config.kv_transfer_config is not None, ( + "kv_transfer_config must be set for MoRIIOConnector" + ) + self.kv_transfer_config = vllm_config.kv_transfer_config + self.block_size = vllm_config.cache_config.block_size + self.engine_id: EngineId = engine_id + self.mode = get_moriio_mode() + self.host_ip = get_ip() + self.handshake_port = self.kv_transfer_config.kv_connector_extra_config[ + "handshake_port" + ] + logger.info("Initializing MoRIIO Scheduler engine_id = %s", engine_id) + + self.side_notify_port = self.kv_transfer_config.kv_connector_extra_config[ + "notify_port" + ] + self.tp_size = self.vllm_config.parallel_config.tensor_parallel_size + self.dp_rank = self.vllm_config.parallel_config.data_parallel_rank + self.is_producer = self.kv_transfer_config.kv_role == "kv_producer" + # Requests that need to start recv/send. + # New requests are added by update_state_after_alloc in + # the scheduler. Used to make metadata passed to Worker. + self._reqs_need_recv: dict[ReqId, tuple[Request, list[int]]] = {} + self._reqs_need_save: dict[ReqId, tuple[Request, list[int]]] = {} + + # For chunked prefill, we perform layer-wise access within the final chunk. + # TODO: Perform transfer at end chunk. + self._reqs_need_pending_save: dict[ReqId, tuple[Request, list[int]]] = {} + + if self.is_producer: + set_role(ROLE.PRODUCER) + else: + set_role(ROLE.CONSUMER) + # Reqs to send and their expiration time + self._reqs_need_send: dict[ReqId, float] = {} + self.paths: dict[str, zmq.Socket] = {} + + def get_num_new_matched_tokens( + self, + request: "Request", + num_computed_tokens: int, + ) -> tuple[int, bool]: + """ + For remote prefill, pull all prompt blocks from remote + asynchronously relative to engine execution. + + Args: + request (Request): the request object. + num_computed_tokens (int): the number of locally + computed tokens for this request + Returns: + * the number of tokens that can be loaded from the + external KV cache beyond what is already computed. + * true if the external KV cache tokens will be loaded + asynchronously (between scheduler steps). + """ + if self.is_producer: + return 0, False + + token_ids = request.prompt_token_ids or [] + if self.mode == MoRIIOMode.WRITE: + # MoriiO in write mode, no remote prefill + + return len(token_ids) - num_computed_tokens, True + + return len(token_ids) - 1 - num_computed_tokens, False + + def send_notify_block( + self, req_id: str, block_notify_list: list[int], host=None, port=None + ): + path = make_zmq_path("tcp", host, port) + if path not in self.paths: + ctx = zmq.Context.instance() + sock = make_zmq_socket( + ctx=ctx, path=path, socket_type=zmq.DEALER, bind=False + ) + self.paths[path] = sock + + data = { + "req_id": req_id, + "block_notify_list": block_notify_list or [], + "decode_rank": self.dp_rank, + "type": "remote_blocks", + } + serialized_data = msgpack.dumps(data) + self.paths[path].send(serialized_data) + + def update_state_after_alloc( + self, + request: "Request", + blocks: "KVCacheBlocks", + num_external_tokens: int, + connector_worker: Optional["MoRIIOConnectorWorker"] = None, + ): + params = request.kv_transfer_params + if not params: + return + if params.get("do_remote_decode"): + local_block_ids = blocks.get_block_ids()[0] + self._reqs_need_save[request.request_id] = (request, local_block_ids) + + if params is not None and params.get("do_remote_prefill"): + if self.mode == MoRIIOMode.READ: + if remote_block_ids := params.get("remote_block_ids"): + if all( + p in params + for p in ("remote_engine_id", "remote_host", "remote_port") + ): + # If remote_blocks and num_external_tokens = 0, we + # a full prefix cache hit on the D worker. We need to call + # send_notif in _read_blocks to free the memory on the P. + + # Get unhashed blocks to pull from remote. + local_block_ids = blocks.get_block_ids()[0] + assert len(local_block_ids) <= len(remote_block_ids) + if len(local_block_ids) == len(remote_block_ids): + pass + else: + local_block_ids = remote_block_ids[-len(local_block_ids) :] + + self._reqs_need_recv[request.request_id] = ( + request, + local_block_ids, + ) + else: + logger.warning( + "Got invalid KVTransferParams: %s. This " + "request will not utilize KVTransfer", + params, + ) + + else: + assert request.kv_transfer_params is not None, ( + "kv_transfer_params should not be None" + ) + + remote_dp_rank = request.kv_transfer_params.get("remote_dp_rank", 0) + + for tp_index in range(self.tp_size): + target_port = request.kv_transfer_params[ + "remote_notify_port" + ] + get_port_offset(remote_dp_rank, tp_index) + + self.send_notify_block( + req_id=request.request_id, + block_notify_list=blocks.get_block_ids()[0], + host=params.get("remote_host"), + port=target_port, + ) + + # Only trigger 1 KV transfer per request. + + params["do_remote_prefill"] = False + + def build_connector_meta( + self, + scheduler_output: SchedulerOutput, + ) -> KVConnectorMetadata: + meta = MoRIIOConnectorMetadata() + + if self.mode == MoRIIOMode.WRITE: + # when async_load_kv finished, + # new reqs will be added to scheduler_output.scheduled_new_reqs + + if get_role() == ROLE.CONSUMER: + for new_req in scheduler_output.scheduled_new_reqs: + red_id = new_req.req_id + local_block_ids = list(new_req.block_ids)[0] + assert new_req.sampling_params is not None, ( + f"sampling_params is None for req {new_req.req_id}" + ) + assert hasattr(new_req.sampling_params, "extra_args"), ( + f"sampling_params missing extra_args for req {new_req.req_id}" + ) + kv_transfer_params = ( + new_req.sampling_params.extra_args.get("kv_transfer_params", {}) + if new_req.sampling_params.extra_args + else {} + ) + meta.add_new_req( + red_id, + local_block_ids, + kv_transfer_params, + ) + if get_role() == ROLE.PRODUCER: + # This is the logic for checking against chunked prefill. + # When the last chunk is identified, + # It places the request metadata into the saving queue. + + for i, req_id in enumerate( + scheduler_output.scheduled_cached_reqs.req_ids + ): + new_block_ids = ( + scheduler_output.scheduled_cached_reqs.new_block_ids[i] + ) + + if new_block_ids is not None: + block_ids = new_block_ids[0] + # TODO : hybrid attn, etc + req, existing_blocks = self._reqs_need_pending_save[req_id] + updated_blocks = list(existing_blocks) + (block_ids) + self._reqs_need_pending_save[req_id] = (req, updated_blocks) + if ( + len(self._reqs_need_pending_save[req_id][1]) + * self.block_size + >= req.num_prompt_tokens + ): + meta.add_new_req( + request_id=req_id, + local_block_ids=self._reqs_need_pending_save[req_id][1], + kv_transfer_params=req.kv_transfer_params or {}, + write_mode=True, + ) + del self._reqs_need_pending_save[req_id] + + # Loop through scheduled reqs and convert to ReqMeta. + for req_id, (req, block_ids) in self._reqs_need_recv.items(): + assert req.kv_transfer_params is not None + meta.add_new_req( + request_id=req_id, + local_block_ids=block_ids, + kv_transfer_params=req.kv_transfer_params, + ) + + for req_id, (req, block_ids) in self._reqs_need_save.items(): + assert req.kv_transfer_params is not None + if req.num_prompt_tokens > len(block_ids) * self.block_size: + # not last chunk prefill + self._reqs_need_pending_save[req_id] = (req, block_ids) + continue + meta.add_new_req( + request_id=req_id, + local_block_ids=block_ids, + kv_transfer_params=req.kv_transfer_params, + write_mode=True, + ) + # Clear the list once workers start the transfers + + meta.reqs_to_send = self._reqs_need_send + + self._reqs_need_recv.clear() + self._reqs_need_save.clear() + self._reqs_need_send = {} + + return meta + + def shutdown(self): + for path, sock in self.paths.items(): + try: + sock.close(linger=0) + logger.debug("Closed ZMQ socket for path: %s", path) + except Exception as e: + logger.warning("Error closing ZMQ socket for path %s: %s", path, e) + self.paths.clear() + + def request_finished( + self, + request: "Request", + block_ids: list[int], + ) -> tuple[bool, dict[str, Any] | None]: + """ + Once a request is finished, determine whether request blocks + should be freed now or will be sent asynchronously and freed later. + """ + + params = request.kv_transfer_params + logger.debug( + "MoriioConnector request_finished, request_status=%s, " + "kv_transfer_params=%s", + request.status, + params, + ) + if not params: + return False, None + + if params.get("do_remote_prefill"): + # If do_remote_prefill is still True when the request is finished, + # update_state_after_alloc must not have been called (the request + # must have been aborted before it was scheduled). + # To avoid stranding the prefill blocks in the prefill instance, + # we must add empty block_ids to _reqs_need_recv so that our + # worker side will notify and free blocks in the prefill instance. + self._reqs_need_recv[request.request_id] = (request, []) + params["do_remote_prefill"] = False + return False, None + + if ( + not params.get("do_remote_decode") + or request.status != RequestStatus.FINISHED_LENGTH_CAPPED + ): + return False, None + + # computed_block_ids = block_ids if all_full else block_ids[:-1] + computed_block_ids = block_ids + # If prompt < block_size, no xfer so free blocks immediately. + delay_free_blocks = len(computed_block_ids) > 0 + + if delay_free_blocks: + # Prefill request on remote. It will be read from D upon completion + self._reqs_need_send[request.request_id] = ( + time.perf_counter() + + MoRIIOConstants.VLLM_MORI_READ_ABORT_REQUEST_TIMEOUT + ) + + # If we execute in P-D serial mode, no notification port is needed. + return delay_free_blocks, dict( + do_remote_prefill=True, + do_remote_decode=False, + remote_block_ids=computed_block_ids, + remote_engine_id=self.engine_id, + remote_host=self.host_ip, + remote_port=self.handshake_port, + tp_size=self.vllm_config.parallel_config.tensor_parallel_size, + ) + + +class MoRIIOConnectorWorker: + """Implementation of Worker side methods""" + + def __init__(self, vllm_config: VllmConfig, engine_id: str): + if not is_moriio_available(): + raise RuntimeError( + "MoRIIO is not available. Please ensure the 'mori' package " + "is installed and properly configured." + ) + + self.moriio_config = MoRIIOConfig.from_vllm_config(vllm_config) + self.mode = get_moriio_mode() + + logger.info("Initializing MoRIIO worker %s", engine_id) + + logging.getLogger("aiter").disabled = True + + # Config. + self.vllm_config = vllm_config + assert vllm_config.kv_transfer_config is not None, ( + "kv_transfer_config must be set for MoRIIOConnector" + ) + self.kv_transfer_config = vllm_config.kv_transfer_config + self.is_producer = self.kv_transfer_config.is_kv_producer + + if self.is_producer: + set_role(ROLE.PRODUCER) + else: + set_role(ROLE.CONSUMER) + # mori engine + self._rank = get_world_group().rank + self._local_rank = get_world_group().local_rank + self.tp_rank = self.moriio_config.tp_rank + self.dp_rank = self.moriio_config.dp_rank + + self.local_ip = self.moriio_config.local_ip + self.local_kv_port = self.moriio_config.local_kv_port + self.proxy_ip = self.moriio_config.proxy_ip + self.local_ping_port = self.moriio_config.local_ping_port + self.proxy_ping_port = self.moriio_config.proxy_ping_port + self.http_port = self.moriio_config.http_port + self.handshake_port = self.moriio_config.handshake_port + self.notify_port = self.moriio_config.notify_port + + self.zmq_context = zmq.Context() + self.metadata_address = ( + f"{self.moriio_config.local_ip}:{self.moriio_config.local_ping_port}" + ) + self.request_address = ( + f"{self.moriio_config.local_ip}:{self.moriio_config.http_port}" + ) + + self.moriio_engine = None + self._handle_request_thread = None + self._ping_thread = None + self._writer = MoRIIOWriter(self) + + role = "producer" if self.is_producer else "consumer" + engine_suffix = ( + f"{self.moriio_config.local_ip}:{self.moriio_config.handshake_port}:" + f"tp{self.tp_rank}:dp{self.dp_rank}" + ) + self.moriio_engine = IOEngine( + f"{role}:{engine_suffix}", + IOEngineConfig( + self.moriio_config.local_ip, self.moriio_config.local_kv_port + ), + ) + logger.debug( + "build MORI IOEngine %s (ip=%s port=%s)", + f"{role}:{engine_suffix}", + self.moriio_config.local_ip, + self.moriio_config.local_kv_port, + ) + + if self._rank == 0 and self.moriio_config.proxy_ip: + self._ping_thread = threading.Thread( + target=self._ping, args=(self.zmq_context,), daemon=True + ) + self._ping_thread.start() + + logger.info( + "Initializing MoRIIO Engine, engine = %s, role = %s", + self.moriio_engine, + "producer" if self.is_producer else "consumer", + ) + + # Agent. + self.moriio_wrapper = MoRIIOWrapper(tp_rank=self.tp_rank, dp_rank=self.dp_rank) + self.moriio_wrapper.set_moriio_engine(self.moriio_engine) + self.moriio_wrapper.set_backend_type(BackendType.RDMA) + self.moriio_wrapper.notify_port = self.moriio_config.notify_port + self.local_kv_cache_metadata: list[bytes] = [] + self.local_kv_cache_size: list[int] = [] + self.layer_name_to_local_kv_cache_metadata: dict[str, list[bytes]] = {} + + self.remote_kv_cache_metadata: list[bytes] = [] + self.remote_kv_cache_size: list[int] = [] + self.layer_name_to_remote_kv_cache_metadata: dict[str, dict[str, list[Any]]] = ( + dict() + ) + self.remote_moriio_metadata: dict[EngineId, MoRIIOAgentMetadata] = {} + self.slot_size_bytes = 0 + + self.load_ready_flag: dict[str, bool] = {} + self.write_ready_flags: dict[str, bool] = {} + self.kv_cache_shape = None + self.block_shape = None + self.kv_element_size = 0 + + # Map of engine_id -> {agent_name0, agent_name1..}. + self._remote_agents: dict[EngineId, set[str]] = {} + + self.side_channel_port: int = ( + self.moriio_config.handshake_port + + get_port_offset(self.dp_rank, self.tp_rank) + ) + self.engine_id: EngineId = engine_id + + self.world_size = get_tensor_model_parallel_world_size() + self.tp_group = get_tp_group() + + # KV Caches and moriio tracking data. + self.kv_caches: dict[str, torch.Tensor] = {} + + # Map of engine_id -> kv_caches_base_addr. For TP case, each local + # rank will still only pull from a single remote TP worker. + self.kv_caches_base_addr: dict[EngineId, list[int]] = {} + + # Number of MoRIIO regions. Currently one region per cache + # (so 1 per layer for MLA, otherwise 2 per layer) + self.num_regions = 0 + self.num_layers = 0 + + # Map of engine_id -> num_blocks. All ranks in the same deployment will + # have the same number of blocks. + self.dst_num_blocks: dict[EngineId, int] = {} + # In progress transfers. + self._recving_transfers: defaultdict[ReqId, list] = defaultdict(list) + self._recving_transfers_callback_addr: dict[ReqId, tuple[str, str]] = {} + + # Track the expiration time of requests that are waiting to be sent. + self._reqs_to_send: dict[ReqId, float] = {} + + # Background thread for handling new handshake requests. + self._moriio_handshake_listener_t: threading.Thread | None = None + # Background thread for initializing new MoRIIO handshakes. + self._handshake_initiation_executor = ThreadPoolExecutor( + # MoRIIO is not guaranteed to be thread-safe, limit 1 worker. + max_workers=1, + thread_name_prefix="vllm-moriio-handshake-initiator", + ) + self._ready_requests = queue.Queue[tuple[ReqId, ReqMeta]]() + self._handshake_futures: dict[EngineId, Future[set[str]]] = {} + # Protects _handshake_futures and _remote_agents. + self._handshake_lock = threading.RLock() + + self.block_size = vllm_config.cache_config.block_size + self.model_config = vllm_config.model_config + self.cache_config = vllm_config.cache_config + + self.block_window_per_layer: list[int | None] = [] + self.use_mla = self.model_config.use_mla + self.built_session = False + self.built_write_session: defaultdict[str, list] = defaultdict(list) + backend = get_attn_backend( + self.model_config.get_head_size(), + self.model_config.dtype, + self.cache_config.cache_dtype, + self.block_size, + use_mla=self.use_mla, + ) + + # TODO: consider the integration of flashinfer or other backends. + self.backend_name = backend.get_name() + logger.debug("Detected attention backend %s", self.backend_name) + + def schedule_write_blocks( + self, + request_id: str, + dst_engine_id: str, + local_block_ids: list[int], + remote_block_ids: list[int] | None, + layer_name: str, + kv_layer: torch.Tensor, + remote_notify_port: int, + remote_ip: str, + ) -> None: + """Schedule a block write operation. + + Args: + request_id: Unique identifier for the request + dst_engine_id: Destination engine ID + local_block_ids: Local block IDs to transfer + remote_block_ids: Hint for remote block IDs + layer_name: Name of the layer + kv_layer: KV cache tensor + remote_notify_port: Port for completion notification + remote_ip: IP address of remote node + """ + + # synchronization to prevent dirty reads between + # transfer and attention operations + # we can consider removing this synchronization after ibgda is enabled. + # when mori-io supports ibgda functionality + + stream = torch.cuda.current_stream() + event = torch.cuda.Event() + event.record(stream) + + task = WriteTask( + request_id=request_id, + dst_engine_id=dst_engine_id, + local_block_ids=local_block_ids, + remote_block_ids_hint=remote_block_ids, + layer_name=layer_name, + event=event, + remote_notify_port=remote_notify_port, + remote_ip=remote_ip, + ) + self._writer.schedule_write(task) + + def _get_built_session(self, remote_engine_id): + if remote_engine_id not in self.built_write_session: + cur_remote_engine_sessions = [] + for ln, local_meta in self.layer_name_to_local_kv_cache_metadata.items(): + unpacked_local_memory_meta = ( + self.moriio_wrapper.get_unpack_memory_metadata(local_meta[0]) + ) + unpacked_remote_memory_meta = ( + self.moriio_wrapper.get_unpack_memory_metadata( + self.layer_name_to_remote_kv_cache_metadata[remote_engine_id][ + ln + ][0] + ) + ) + cur_remote_engine_sessions.append( + self.moriio_wrapper.build_session( + unpacked_local_memory_meta, unpacked_remote_memory_meta + ) + ) + self.built_write_session[remote_engine_id] = cur_remote_engine_sessions + return self.built_write_session[remote_engine_id], self.remote_moriio_metadata[ + remote_engine_id + ] + + def _ping(self, zmq_context): + http_request_address = f"http://{self.request_address}/v1/completions" + role = "P" if self.is_producer else "D" + + retry_count = 0 + index = 1 + with zmq_context.socket(zmq.DEALER) as sock: + sock.connect(f"tcp://{self.proxy_ip}:{self.proxy_ping_port}") + + while True: + try: + data = { + "type": "register", + "role": role, + "index": str(index), + "request_address": http_request_address, + "handshake_port": self.handshake_port, + "notify_port": self.notify_port, + "dp_size": self.moriio_config.dp_size, + "tp_size": self.moriio_config.tp_size, + "transfer_mode": self.mode.name, + } + + sock.send(msgpack.dumps(data)) + # logger.debug(f"Successfully sent ping message #{index}") + retry_count = 0 + + except ConnectionRefusedError: + logger.info( + "Connection refused: %s:%s -> %s:%s", + self.local_ip, + self.local_ping_port, + self.proxy_ip, + self.proxy_ping_port, + ) + retry_count += 1 + + except OSError as e: + logger.info("OS error when sending ping: %s", e) + retry_count += 1 + + except Exception as e: + logger.info("Unexpected error when sending ping: %s", e) + retry_count += 1 + if retry_count >= MoRIIOConstants.MAX_PING_RETRIES: + logger.error( + "Max retries (%s) exceeded. Stopping ping loop.", + MoRIIOConstants.MAX_PING_RETRIES, + ) + raise RuntimeError( + f"Ping failed after {retry_count} retries" + ) from e + + finally: + time.sleep(MoRIIOConstants.PING_INTERVAL) + index += 1 + + def shutdown(self): + if hasattr(self, "moriio_wrapper") and self.moriio_wrapper: + self.moriio_wrapper.shutdown() + + if hasattr(self, "_handshake_initiation_executor"): + self._handshake_initiation_executor.shutdown(wait=False) + + if ( + hasattr(self, "_moriio_handshake_listener_t") + and self._moriio_handshake_listener_t + ): + self._moriio_handshake_listener_t.join(timeout=0) + + if hasattr(self, "zmq_context") and self.zmq_context: + self.zmq_context.destroy(linger=0) + self.zmq_context = None + + def __del__(self): + self.shutdown() + + @staticmethod + def _moriio_handshake_listener( + metadata: MoRIIOAgentMetadata, + ready_event: threading.Event, + base_port: int, + tp_rank: int, + dp_rank: int, + layer_name_to_local_kv_cache_metadata: dict, + ): + """Background thread for getting new MoRIIO handshakes.""" + + encoder = msgspec.msgpack.Encoder() + encoded_data = encoder.encode(metadata) + size_in_bytes = len(encoded_data) + logger.debug( + "Size of encoded MoRIIOAgentMetadata: %s bytes", str(size_in_bytes) + ) + + # Listen for new requests for metadata. + host = "*" + + path = make_zmq_path("tcp", host, base_port) + logger.debug("mori handshake starting listening on path: %s", path) + + with zmq_ctx(zmq.ROUTER, path) as sock: + ready_event.set() + while True: + identity, msg = sock.recv_multipart() + if ( + msg != MoRIIOConstants.GET_META_MSG + and msg != MoRIIOConstants.POP_DONE_RECV + ): + logger.error("Connection listener got unexpected message") + raise HandshakeError("handshake failed, unexpected msg type") + elif msg == MoRIIOConstants.GET_META_MSG: + sock.send_multipart( + (identity, b"", encoded_data) + ) # send local mori io engine meta data + logger.debug("MoRIIO handshake listener sent metadata") + # now we send tensor meta data for each block + buf = msgpack.dumps(layer_name_to_local_kv_cache_metadata) + sock.send_multipart((identity, b"", buf)) + elif msg == MoRIIOConstants.POP_DONE_RECV: + _, req_id = sock.recv_multipart() + logger.debug( + "MoRIIO handshake listener received done recv for req", + req_id.decode(), + ) + + def _moriio_handshake( + self, + host: str, + port: int, + remote_tp_size: int, + expected_engine_id: str, + remote_dp_rank: int = 0, + ) -> set[str]: + """Do a MoRIIO handshake with a remote instance.""" + + start_time = time.perf_counter() + + # NOTE(rob): we need each rank to have a unique port. This is + # a hack to keep us moving. We will switch when moving to etcd + # or where we have a single ZMQ socket in the scheduler. + + port_offset = get_port_offset(remote_dp_rank, self.tp_rank) + path = make_zmq_path("tcp", host, port + port_offset) + logger.debug("handshake Querying metadata on path: %s", path) + + # Send query for the request. + with zmq_ctx(zmq.DEALER, path) as sock: + logger.debug("prepare send msg INSTAZNCE: %s", path) + sock.send(MoRIIOConstants.GET_META_MSG) + received_frame = sock.recv_multipart() + if len(received_frame) != 2 or received_frame[0] != b"": + raise HandshakeError(f"Unexpected frame! {received_frame = }") + + metadata_bytes = received_frame[1] + decoder = msgspec.msgpack.Decoder(MoRIIOAgentMetadata) + metadata = decoder.decode(metadata_bytes) + got_metadata_time = time.perf_counter() + logger.info( + "MoRIIO handshake: get metadata took: %s", + got_metadata_time - start_time, + ) + + self.moriio_wrapper.remote_engine_ip = host + remote_agent_name = self.moriio_wrapper.register_remote_engine( + metadata.agent_metadata + ) + + logger.debug( + "MoRIIO handshake: registered" + "remote agent %s for engine ID %s, path = %s", + remote_agent_name, + expected_engine_id, + path, + ) + + if len(self.local_kv_cache_metadata) > 0: + logger.warning( + "len(self.local_kv_cache_metadata) = %s," + "maybe you didnt clear this buffer correctly", + len(self.local_kv_cache_metadata), + ) + self.local_kv_cache_metadata = [] + if len(self.remote_kv_cache_metadata) > 0: + logger.warning( + "len(self.remote_kv_cache_metadata) = %s," + "maybe you didnt clear this buffer correctly", + len(self.remote_kv_cache_metadata), + ) + self.remote_kv_cache_metadata = [] + + received_frame = sock.recv_multipart() + if len(received_frame) != 2 or received_frame[0] != b"": + raise HandshakeError(f"unexpected frame! {received_frame = }") + buf = received_frame[1] + self.layer_name_to_remote_kv_cache_metadata[expected_engine_id] = ( + msgpack.loads(buf) + ) + self.remote_moriio_metadata[expected_engine_id] = metadata + setup_agent_time = time.perf_counter() + logger.debug( + "MoRIIO handshake: add agent took: %s", + setup_agent_time - got_metadata_time, + ) + + return {remote_agent_name} + + def _background_moriio_handshake( + self, req_id: str, remote_engine_id: EngineId, meta: ReqMeta + ): + # Do MoRIIO handshake in background and add to _ready_requests when done. + fut = None + if remote_engine_id is not None: + fut = self._handshake_futures.get(remote_engine_id) + if fut is None: + host = meta.remote_host + port = int(meta.remote_handshake_port) + tp_size = int(meta.tp_size) + remote_dp_size = int(meta.remote_dp_size) + + def request_ready(_f: Future[Any], entry=(req_id, meta)): + logger.info("MoRIIO handshake done for request %s", req_id) + self._ready_requests.put(entry) + self.load_ready_flag[remote_engine_id] = True + self.write_ready_flags[remote_engine_id] = True + + fut_list = [] + + # In dp(prefill)<->dp(decode) communication, we require an all-to-all handshake. + + for cur_dp_rank in range(remote_dp_size): + dp_engine_id = self.get_engine_name_with_dp(remote_engine_id, cur_dp_rank) + future = self._handshake_initiation_executor.submit( + self._moriio_handshake, host, port, tp_size, dp_engine_id, cur_dp_rank + ) + fut_list.append(future) + + def done_callback(f: Future[set[str]], eid=dp_engine_id): + with self._handshake_lock: + self._handshake_futures.pop(eid, None) + try: + self._remote_agents[eid] = f.result() + except Exception: + logger.exception("Handshake with %s failed", eid) + + future.add_done_callback(done_callback) + self._handshake_futures[dp_engine_id] = future + + # fut = fut_list + def wait_all_dp(): + for future in fut_list: + future.result() + return True + + all_done_future = self._handshake_initiation_executor.submit(wait_all_dp) + all_done_future.add_done_callback(request_ready) + + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): + """Register the KV Cache data in moriio.""" + + _, first_kv_cache = next(iter(kv_caches.items())) + kv_elem_size = first_kv_cache.element_size() + + use_mla = len(first_kv_cache.shape) == 3 + assert use_mla == self.use_mla + + if use_mla: + # MLA case. + self.num_blocks = first_kv_cache.shape[0] + block_rank = 2 # [block_size, latent_dim] + block_shape = first_kv_cache.shape[-block_rank:] + block_size, kv_latent_dim = block_shape + self.slot_size_bytes = kv_elem_size * kv_latent_dim + else: + # [2 (k and v), num_blocks, ...] + self.num_blocks = first_kv_cache.shape[1] + block_rank = 3 # [block_size, kv_heads, head_dim] + block_shape = first_kv_cache.shape[-block_rank:] + block_size, n_kv_heads, head_dim = block_shape[-3:] + # head size in bytes. + self.slot_size_bytes = ( + kv_elem_size * n_kv_heads * head_dim + ) # 1 token 1 layer size , slot size + assert block_size == self.block_size + # TODO(tms): self.block_len needs to be per-layer for sliding window, + # hybrid attn, etc + # block size in bytes + self.block_len = kv_elem_size * math.prod(block_shape) + self.kv_cache_shape = first_kv_cache.shape + self.block_shape = block_shape + self.kv_element_size = kv_elem_size + + self.dst_num_blocks[self.engine_id] = self.num_blocks + self.kv_caches = kv_caches # layer name to kv cache + kv_caches_base_addr = [] + caches_data = [] + + for cache_or_caches in kv_caches.values(): + cache_list = [cache_or_caches] if use_mla else cache_or_caches + for cache in cache_list: + base_addr = cache.data_ptr() + region_len = self.num_blocks * self.block_len + caches_data.append((base_addr, region_len, cache.device.index, "")) + kv_caches_base_addr.append(base_addr) + + for layer_name, kv_cache in kv_caches.items(): + if layer_name not in self.layer_name_to_local_kv_cache_metadata: + self.layer_name_to_local_kv_cache_metadata[layer_name] = [] + + moriio_mem_metadata = self.moriio_wrapper.register_local_tensor(kv_cache) + self.layer_name_to_local_kv_cache_metadata[layer_name].append( + moriio_mem_metadata + ) + + self.local_kv_cache_size.append(cache.nelement() * cache.element_size()) + + self.kv_caches_base_addr[self.engine_id] = kv_caches_base_addr + self.num_regions = len(caches_data) + self.num_layers = len(self.kv_caches.keys()) + + # Optimization for models with local attention (Llama 4) + if self.vllm_config.model_config.hf_config.model_type == "llama4": + from transformers import Llama4TextConfig + + assert isinstance( + self.vllm_config.model_config.hf_text_config, Llama4TextConfig + ) + llama4_config = self.vllm_config.model_config.hf_text_config + no_rope_layers = llama4_config.no_rope_layers + chunk_size = llama4_config.attention_chunk_size + chunk_block_size = math.ceil(chunk_size / self.block_size) + for layer_idx in range(self.num_layers): + # no_rope_layers[layer_idx] == 0 means NoPE (global) + # Any other value means RoPE (local chunked) + is_local_attention = no_rope_layers[layer_idx] != 0 + block_window = chunk_block_size if is_local_attention else None + self.block_window_per_layer.append(block_window) + logger.debug( + "Llama 4 block window per layer mapping: %s", + self.block_window_per_layer, + ) + assert len(self.block_window_per_layer) == self.num_layers + + metadata = MoRIIOAgentMetadata( + engine_id=self.engine_id, + agent_metadata=self.moriio_wrapper.get_agent_metadata(), + kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id], + num_blocks=self.num_blocks, + block_len=self.block_len, + attn_backend_name=self.backend_name, + ) + ready_event = threading.Event() + self._moriio_handshake_listener_t = threading.Thread( + target=self._moriio_handshake_listener, + args=( + metadata, + ready_event, + self.side_channel_port, + self.tp_rank, + self.dp_rank, + self.layer_name_to_local_kv_cache_metadata, + ), + daemon=True, + name="moriio_handshake_listener", + ) + self._moriio_handshake_listener_t.start() + ready_event.wait() # Wait for listener ZMQ socket to be ready. + self.moriio_wrapper.async_wait_reqid() + + def get_finished(self) -> tuple[set[str], set[str]]: + """ + Get requests that are done sending or recving on this specific worker. + The scheduler process (via the MultiprocExecutor) will use this output + to track which workers are done. + """ + + done_sending, done_recving = set(), set() + + if self.is_producer: + done_sending = self.moriio_wrapper.pop_finished_req_ids() + + else: + if self.mode == MoRIIOMode.WRITE: + done_recving = self.moriio_wrapper.pop_finished_write_req_ids() + else: + done_recving = self._pop_done_transfers() + + return done_sending, done_recving + + def _pop_done_transfers(self) -> set[str]: + done_req_ids: set[str] = set() + with self.moriio_wrapper.lock: + to_remove = [] + for req_id, status_list in self._recving_transfers.items(): + if status_list[-1].Succeeded(): + done_req_ids.add(req_id) + + self.moriio_wrapper.send_notify( + req_id, + self._recving_transfers_callback_addr[req_id][0], + self._recving_transfers_callback_addr[req_id][1], + ) + to_remove.append(req_id) + for req_id in to_remove: + del self._recving_transfers[req_id] + del self._recving_transfers_callback_addr[req_id] + + return done_req_ids + + def save_kv_layer( + self, + metadata: MoRIIOConnectorMetadata, + layer_name: str, + kv_layer: torch.Tensor, + attn_metadata: "AttentionMetadata", + **kwargs, + ): + if not self.is_producer: + return + if self.mode == MoRIIOMode.READ: + return + remote_engine_id = None + + for req_id, meta in metadata.reqs_to_save.items(): + # we only need to check if dp0 in rank + remote_engine_id = ( + str(meta.remote_host) + ":" + str(meta.remote_handshake_port) + ) + + meta.remote_engine_id = remote_engine_id + + dp0_remote_engine_id = self.get_engine_name_with_dp(remote_engine_id, 0) + if dp0_remote_engine_id not in self._remote_agents: + # Initiate handshake with remote engine to exchange metadata. + with self._handshake_lock: + if remote_engine_id not in self._remote_agents: + self._background_moriio_handshake( + req_id, remote_engine_id, meta + ) + + continue + self._write_blocks_for_req(req_id, meta, layer_name, kv_layer) + + while True: + if ( + self._ready_requests.empty() + and remote_engine_id not in self.write_ready_flags + ): + continue + elif not self._ready_requests.empty() and ( + remote_engine_id in self.write_ready_flags + ): + self._write_blocks_for_req( + *self._ready_requests.get_nowait(), layer_name, kv_layer + ) + break + else: + break + + def get_engine_name_with_dp(self, engine_name, dp_rank): + return f"{engine_name}_dp{dp_rank}" + + def start_load_kv(self, metadata: MoRIIOConnectorMetadata): + """ + Start loading by triggering non-blocking moriio_xfer. + We check for these trnxs to complete in each step(). + """ + if self.is_producer: + self.moriio_wrapper.async_wait_reqid() + return + if self.mode == MoRIIOMode.WRITE: + return + + wait_handshake_readd_req = False + remote_engine_id = None + + for req_id, meta in metadata.reqs_to_recv.items(): + remote_engine_id = ( + str(meta.remote_host) + ":" + str(meta.remote_handshake_port) + ) + meta.remote_engine_id = remote_engine_id + dp0_remote_engine_id = self.get_engine_name_with_dp(remote_engine_id, 0) + if dp0_remote_engine_id not in self._remote_agents: + # Initiate handshake with remote engine to exchange metadata. + with self._handshake_lock: + if remote_engine_id not in self._remote_agents: + self._background_moriio_handshake( + req_id, remote_engine_id, meta + ) + wait_handshake_readd_req = True + + continue + + # Handshake already completed, start async read xfer. + self._read_blocks_for_req(req_id, meta) + # Start transfers for requests whose handshakes have now finished. + + while True: + if ( + self._ready_requests.empty() + and remote_engine_id not in self.load_ready_flag + and wait_handshake_readd_req + ): + continue + elif ( + not self._ready_requests.empty() + and remote_engine_id in self.load_ready_flag + ): + self._read_blocks_for_req(*self._ready_requests.get_nowait()) + break + else: + break + + self._reqs_to_send.update(metadata.reqs_to_send) + + def _read_blocks_for_req(self, req_id: str, meta: ReqMeta): + logger.debug( + "Remote agent %s available, calling _read_blocks for req %s", + meta.remote_engine_id, + req_id, + ) + self._read_blocks( + request_id=req_id, + dst_engine_id=meta.remote_engine_id, + local_block_ids=meta.local_block_ids, + remote_block_ids=meta.remote_block_ids, + remote_host=meta.remote_host, + remote_notify_port=meta.remote_notify_port, + ) + + def _write_blocks_for_req(self, req_id: str, meta: ReqMeta, layer_name, kv_layer): + self.schedule_write_blocks( + request_id=req_id, + dst_engine_id=meta.remote_engine_id, + local_block_ids=meta.local_block_ids, + remote_block_ids=meta.remote_block_ids, + layer_name=layer_name, + kv_layer=kv_layer, + remote_notify_port=meta.remote_notify_port, + remote_ip=meta.remote_host, + ) + + def _is_last_layer(self, layer_name): + return layer_name == list(self.kv_caches.keys())[-1] + + def merge_contiguous_blocks( + self, + offsets_local: list[int], + offsets_remote: list[int], + sizes: list[int], + assume_sorted: bool = False, + ) -> tuple[list[int], list[int], list[int]]: + n = len(offsets_local) + if n == 0: + return [], [], [] + if not (n == len(offsets_remote) == len(sizes)): + raise ValueError("Input list lengths mismatch") + local_arr = np.fromiter(offsets_local, dtype=np.int64, count=n) + remote_arr = np.fromiter(offsets_remote, dtype=np.int64, count=n) + sizes_arr = np.fromiter(sizes, dtype=np.int64, count=n) + + if assume_sorted: + local_sorted = local_arr + remote_sorted = remote_arr + sizes_sorted = sizes_arr + else: + if np.all(local_arr[:-1] <= local_arr[1:]): + local_sorted = local_arr + remote_sorted = remote_arr + sizes_sorted = sizes_arr + else: + sort_idx = np.argsort(local_arr, kind="stable") + local_sorted = local_arr[sort_idx] + remote_sorted = remote_arr[sort_idx] + sizes_sorted = sizes_arr[sort_idx] + + if n == 1: + return ( + [int(local_sorted[0])], + [int(remote_sorted[0])], + [int(sizes_sorted[0])], + ) + + diff_local = local_sorted[1:] - local_sorted[:-1] + diff_remote = remote_sorted[1:] - remote_sorted[:-1] + prev_size = sizes_sorted[:-1] + + contiguous = (diff_local == prev_size) & (diff_remote == prev_size) + + if not contiguous.any(): + return local_sorted.tolist(), remote_sorted.tolist(), sizes_sorted.tolist() + + if contiguous.all(): + total_size = int(sizes_sorted.sum()) + return [int(local_sorted[0])], [int(remote_sorted[0])], [total_size] + + break_positions = np.flatnonzero(~contiguous) + 1 + segment_starts = np.concatenate(([0], break_positions)) + segment_ends = np.concatenate((break_positions, [n])) + + seg_count = len(segment_starts) + merged_local = [0] * seg_count + merged_remote = [0] * seg_count + merged_sizes = [0] * seg_count + + for si in range(seg_count): + s = segment_starts[si] + e = segment_ends[si] + merged_local[si] = int(local_sorted[s]) + merged_remote[si] = int(remote_sorted[s]) + + merged_sizes[si] = int( + local_sorted[e - 1] + sizes_sorted[e - 1] - local_sorted[s] + ) + + return merged_local, merged_remote, merged_sizes + + def _compute_block_transfer_offsets( + self, + layer_name: str, + local_block_ids: list[int], + remote_block_ids: list[int], + remote_moriio_meta: MoRIIOAgentMetadata, + ) -> tuple[list[int], list[int], list[int]]: + """Compute transfer offsets for block data. + + Args: + layer_name: Name of the layer to transfer + local_block_ids: IDs of local blocks + remote_block_ids: IDs of remote blocks + remote_moriio_meta: Metadata of the remote MoRIIO agent + Returns: + Tuple of (local_offsets, remote_offsets, transfer_sizes) + """ + assert self.kv_cache_shape is not None, "KV caches shape not initialized" + is_mla = len(self.kv_cache_shape) == 3 + stride = self.kv_caches[layer_name].stride() + sz = self.kv_caches[layer_name].element_size() + if is_mla: + blknum, blksize, hs = self.kv_cache_shape + hn = 1 + block_stride = stride[0] + else: + _, blknum, blksize, hn, hs = self.kv_cache_shape + local_ktov_stride = stride[0] + block_stride = stride[1] + remote_ktov_stride = block_stride * remote_moriio_meta.num_blocks + + transfer_size_byte = blksize * hn * hs * sz + per_block = 1 if is_mla else 2 + total = len(local_block_ids) * per_block + offset_local = [0] * total + offset_remote = [0] * total + sizes = [transfer_size_byte] * total + + w = 0 + for i, lb in enumerate(local_block_ids): + rb = remote_block_ids[i] + # K + offset_local[w] = sz * (lb * block_stride) + offset_remote[w] = sz * (rb * block_stride) + w += 1 + if not is_mla: + # V + # Handle num_block variations originating from PD (different kv strides) + # TODO: address block_sz differences in heterogeneous TP scenarios + # In MLA, we don't need to consider these two cases. + offset_local[w] = sz * (1 * local_ktov_stride + lb * block_stride) + offset_remote[w] = sz * (1 * remote_ktov_stride + rb * block_stride) + w += 1 + + merged_l, merged_r, merged_s = self.merge_contiguous_blocks( + offset_local, offset_remote, sizes, assume_sorted=False + ) + return merged_l, merged_r, merged_s + + def _read_blocks( + self, + local_block_ids: list[int], + remote_block_ids: list[int], + dst_engine_id: str, + request_id: str, + remote_host: str, + remote_notify_port: int, + ) -> None: + if self.mode == MoRIIOMode.WRITE: + return + + dp0_engine_id = self.get_engine_name_with_dp(dst_engine_id, 0) + sessions, remote_moriio_meta = self._get_built_session(dp0_engine_id) + + first_layer = list(self.layer_name_to_local_kv_cache_metadata.keys())[0] + offs = self._compute_block_transfer_offsets( + first_layer, local_block_ids, remote_block_ids, remote_moriio_meta + ) + + for layer_name in self.layer_name_to_local_kv_cache_metadata: + sess_idx = list(self.layer_name_to_local_kv_cache_metadata.keys()).index( + layer_name + ) + # TODO : apply multi-session batch-read when moriio support it + transfer_status = self.moriio_wrapper.read_remote_data( + offs[2], offs[0], offs[1], sessions[sess_idx] + ) + with self.moriio_wrapper.lock: + self._recving_transfers[request_id].append(transfer_status) + self._recving_transfers_callback_addr[request_id] = ( + remote_host, + str(remote_notify_port + self.tp_rank), + ) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_engine.py b/vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_engine.py new file mode 100644 index 0000000000000..3a35c22622b89 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_engine.py @@ -0,0 +1,609 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import threading +from typing import TYPE_CHECKING, Any, Optional +from weakref import ref as weakref_ref + +import msgpack +import torch +import zmq + +from vllm import envs +from vllm.logger import init_logger +from vllm.utils.network_utils import ( + make_zmq_path, + make_zmq_socket, +) + +if TYPE_CHECKING: + pass + +from queue import Empty, Queue + +from vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_common import ( + ROLE, + HandshakeError, + LayerTransferPlan, + MoRIIOAgentMetadata, + MoRIIOConstants, + MoRIIOError, + RemoteAllocInfo, + TransferError, + WriteTask, + get_port_offset, + get_role, + zmq_ctx, +) + +if TYPE_CHECKING: + from vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector import ( + MoRIIOConnectorWorker, + ) + +logger = init_logger(__name__) +try: + from mori.io import ( + EngineDesc, + IOEngine, + MemoryDesc, + PollCqMode, + RdmaBackendConfig, + ) + + logger.info("MoRIIO is available") +except ImportError: + logger.error("MoRIIO is not available") + + +"""Write task execution logic for MoRIIO connector.""" + + +class MoRIIOWriter: + """Handles write operations for KV cache transfers. + Implements distributed KV cache transfer using the MoRIIO library + for RDMA-based communication between prefill and decode instances.""" + + def __init__(self, worker: "MoRIIOConnectorWorker"): + """Initialize the writer. + + Args: + worker: Reference to the parent worker + """ + self._worker_ref: weakref_ref[MoRIIOConnectorWorker] = weakref_ref(worker) + self._write_task_q: Queue[WriteTask] = Queue() + self._write_worker_started = False + self._write_worker_lock = threading.Lock() + self._deferred_tasks: list[WriteTask] = [] + + @property + def worker(self) -> "MoRIIOConnectorWorker": + """Get the worker instance. + + Returns: + The parent worker instance + + Raises: + RuntimeError: If worker has been garbage collected + """ + worker = self._worker_ref() + if worker is None: + raise RuntimeError("Parent worker has been garbage collected") + return worker + + def ensure_worker_started(self) -> None: + """Ensure the background write worker is running.""" + if self._write_worker_started: + return + self._write_worker_started = True + with self._write_worker_lock: + thread = threading.Thread( + target=self._write_worker_loop, daemon=True, name="moriio-write-worker" + ) + thread.start() + logger.info("Started MoRIIO write worker thread") + + def schedule_write(self, task: WriteTask) -> None: + """Schedule a write task. + + Args: + task: The write task to schedule + """ + self.ensure_worker_started() + self._write_task_q.put(task) + + def _write_worker_loop(self) -> None: + """Main loop for the write worker thread.""" + + while True: + # Process deferred tasks first + self._process_deferred_tasks() + + # Get new task + try: + task = self._write_task_q.get(timeout=0.01) + except Empty: + continue + + # Check if remote blocks are ready + if not self._is_remote_ready(task): + # task.retry_count += 1 + self._deferred_tasks.append(task) + # logger.debug( + # "Deferred task for request %s (retry %d)", + # task.request_id, task.retry_count + # ) + continue + + # Execute the task + + self._execute_write_task(task) + + def _process_deferred_tasks(self) -> None: + """Process tasks that were previously deferred.""" + if not self._deferred_tasks: + return + + still_deferred: list[WriteTask] = [] + for task in self._deferred_tasks: + if self._is_remote_ready(task): + self._execute_write_task(task) + else: + still_deferred.append(task) + + self._deferred_tasks = still_deferred + + def _is_remote_ready(self, task: WriteTask) -> bool: + """Check if remote blocks are allocated for this task. + + Args: + task: The write task + + Returns: + True if remote blocks are ready + """ + return ( + task.request_id in self.worker.moriio_wrapper.done_remote_allocate_req_dict + ) + + def _get_remote_alloc_info(self, request_id: str) -> RemoteAllocInfo: + """Get remote allocation info for a request. + + Args: + request_id: The request ID + + Returns: + Remote allocation information + + Raises: + KeyError: If allocation info is missing + """ + try: + return self.worker.moriio_wrapper.done_remote_allocate_req_dict[request_id] + except KeyError as e: + raise KeyError( + f"Remote allocation info missing for request {request_id}" + ) from e + + def _execute_write_task(self, task: WriteTask) -> None: + """Execute a single write task. + + Args: + task: The write task to execute + + """ + # Get remote allocation info + request_info = self._get_remote_alloc_info(task.request_id) + + if request_info.block_ids is None: + logger.debug("Request %s remote block IDs not ready", task.request_id) + return + + # Wait for CUDA event + # The attention computation of the current layer cannot + # overlap with the kv transfer task, + # otherwise it will cause precision issues. + # This event is used to synchronize the kv transfer and computation tasks. + task.event.synchronize() + + # Update engine ID with DP rank + task.dst_engine_id = self.worker.get_engine_name_with_dp( + task.dst_engine_id, request_info.decode_dp_rank + ) + + # Get or create sessions + sessions, remote_moriio_meta = self.worker._get_built_session( + task.dst_engine_id + ) + + # Prepare transfer plan + plan = self._prepare_transfer_plan(task, request_info, remote_moriio_meta) + + # Execute transfer + self._do_layer_write(plan, sessions) + + # Finalize if all layers complete + self._finalize_if_complete(task, request_info) + + def _prepare_transfer_plan( + self, + task: WriteTask, + request_info: RemoteAllocInfo, + remote_moriio_meta: MoRIIOAgentMetadata, + ) -> LayerTransferPlan: + """Prepare the transfer plan for a layer. + + Args: + task: The write task + request_info: Remote allocation information + + Returns: + The transfer plan + """ + # Compute offsets if not cached + if request_info.transfer_offset is None: + offsets = self.worker._compute_block_transfer_offsets( + task.layer_name, + task.local_block_ids, + request_info.block_ids, + remote_moriio_meta, + ) + request_info.transfer_offset = offsets + + # Get session index + layer_names = list(self.worker.layer_name_to_local_kv_cache_metadata.keys()) + sess_idx = layer_names.index(task.layer_name) + + local_off, remote_off, sizes = request_info.transfer_offset + + return LayerTransferPlan( + request_id=task.request_id, + layer_name=task.layer_name, + sess_idx=sess_idx, + transfer_local_offsets=local_off, + transfer_remote_offsets=remote_off, + transfer_sizes=sizes, + use_batch=True, + ) + + def _do_layer_write(self, plan: LayerTransferPlan, sessions: list) -> None: + """Perform the actual layer write. + + Args: + plan: The transfer plan + sessions: List of transfer sessions + """ + if plan.use_batch: + self.worker.moriio_wrapper.write_remote_data( + plan.transfer_sizes, + plan.transfer_local_offsets, + plan.transfer_remote_offsets, + sessions[plan.sess_idx], + ) + else: + for i in range(len(plan.transfer_local_offsets)): + self.worker.moriio_wrapper.write_remote_data_single( + plan.transfer_sizes[i], + plan.transfer_local_offsets[i], + plan.transfer_remote_offsets[i], + plan.sess_idx, + ) + + def _finalize_if_complete( + self, task: WriteTask, request_info: RemoteAllocInfo + ) -> None: + """Finalize transfer if all layers are complete. + + Args: + task: The write task + request_info: Remote allocation information + """ + request_info.writes_done += 1 + + if request_info.writes_done >= self.worker.num_layers: + # Wait for transfer to complete + self.worker.moriio_wrapper.waiting_for_transfer_complete() + + remote_port = task.remote_notify_port + get_port_offset( + request_info.decode_dp_rank, self.worker.tp_rank + ) + # Consider using RDMA immediate data in decode side + # to eliminate the need for this notification. + # Consider including the first gen token from prefill in the notification + + # Send completion notification + self.worker.moriio_wrapper.send_notify( + task.request_id, task.remote_ip, remote_port + ) + # mark request as done, then we can free the blocks + with self.worker.moriio_wrapper.lock: + self.worker.moriio_wrapper.done_req_ids.append(task.request_id) + del self.worker.moriio_wrapper.done_remote_allocate_req_dict[ + task.request_id + ] + logger.debug( + "Completed transfer for request %s, notified port %d", + task.request_id, + remote_port, + ) + + +class MoRIIOWrapper: + """Wrapper for MoRIIO engine operations. + + Handles both producer and consumer roles for KV cache transfers. + + Args: + moriio_engine: MoRIIO engine instance + tp_rank: Tensor parallel rank + dp_rank: Data parallel rank + """ + + def __init__( + self, + moriio_engine: Optional["IOEngine"] = None, + tp_rank: int = 0, + dp_rank: int = 0, + ): + self.tp_rank = tp_rank + self.dp_rank = dp_rank + self.moriio_engine = moriio_engine + self.remote_memory_metadata = None + self.local_memory_registered = False + self.local_memory_metadata = None + self.transfer_status: list[Any] = [] + self.remote_engine_ip: str | None = None + self.notify_port: int | None = None + self.lock = threading.Lock() + self.done_req_ids: list[str] = [] + self.done_remote_allocate_req_dict: dict[str, RemoteAllocInfo] = {} + self.done_write_cache_req_ids: list[str] = [] + self.notify_thread: threading.Thread | None = None + self.sessions: list[IOEngine.Session] = [] + self.paths: dict[str, zmq.Socket] = {} + + def set_moriio_engine(self, moriio_engine): + assert moriio_engine is not None, ( + "You Cannot pass None engine to MoRIIOWrapper!" + ) + self.moriio_engine = moriio_engine + + def set_backend_type(self, backend_type): + assert self.moriio_engine is not None, "MoRIIO engine must be set first" + qp_per_transfer = envs.VLLM_MORIIO_QP_PER_TRANSFER + post_batch_size = envs.VLLM_MORIIO_POST_BATCH_SIZE + num_worker_threads = envs.VLLM_MORIIO_NUM_WORKERS + poll_mode = PollCqMode.POLLING + rdma_cfg = RdmaBackendConfig( + qp_per_transfer, + post_batch_size, + num_worker_threads, + poll_mode, + ) + self.moriio_engine.create_backend(backend_type, rdma_cfg) + + def get_agent_metadata(self): + assert self.moriio_engine is not None, "MoRIIO engine must be set first" + engine_metadata = self.moriio_engine.get_engine_desc() + engine_metadata_packed = engine_metadata.pack() + return engine_metadata_packed + + def register_remote_engine(self, remote_packed_engine_metadata): + assert self.moriio_engine is not None, "MoRIIO engine must be set first" + consumer_engine_metadata = EngineDesc.unpack(remote_packed_engine_metadata) + self.moriio_engine.register_remote_engine(consumer_engine_metadata) + return consumer_engine_metadata.key + + def register_local_tensor(self, tensor: torch.Tensor): + assert self.moriio_engine is not None, "MoRIIO engine must be set first" + try: + self.local_memory_metadata = self.moriio_engine.register_torch_tensor( + tensor + ) + assert self.local_memory_metadata is not None, ( + "register_torch_tensor returned None" + ) + local_memory_metadata_packed = self.local_memory_metadata.pack() + except Exception as e: + raise MoRIIOError(f"Failed to register local memory: {e}") from e + self.local_memory_registered = True + return local_memory_metadata_packed + + def get_unpack_memory_metadata(self, packed_memory_metadata): + return MemoryDesc.unpack(packed_memory_metadata) + + def build_session(self, local_memory_metadata, remote_memory_metadata): + assert self.moriio_engine is not None, "MoRIIO engine must be set first" + return self.moriio_engine.create_session( + local_memory_metadata, remote_memory_metadata + ) + + def read_remote_data( + self, transfer_size_byte, local_offset=0, remote_offset=0, session=None + ): + assert self.local_memory_registered, "You have not register local memory data!" + assert self.moriio_engine is not None, "MoRIIO engine must be set first" + transfer_status = session.batch_read( + local_offset, + remote_offset, + transfer_size_byte, + self.moriio_engine.allocate_transfer_uid(), + ) + + return transfer_status + + def write_remote_data( + self, transfer_size_byte, local_offset=0, remote_offset=0, session=None + ): + assert self.local_memory_registered, "You have not register local memory data!" + assert self.moriio_engine is not None, "MoRIIO engine must be set first" + write_uid = self.moriio_engine.allocate_transfer_uid() + + transfer_status = session.batch_write( + local_offset, remote_offset, transfer_size_byte, write_uid + ) + with self.lock: + self.transfer_status.append(transfer_status) + + def write_remote_data_single( + self, transfer_size_byte, local_offset=0, remote_offset=0, sess_idx=0 + ): + assert self.local_memory_registered, "You have not register local memory data!" + assert self.moriio_engine is not None, "MoRIIO engine must be set first" + transfer_status = self.sessions[sess_idx].write( + local_offset, + remote_offset, + transfer_size_byte, + self.moriio_engine.allocate_transfer_uid(), + ) + with self.lock: + self.transfer_status.append(transfer_status) + + def waiting_for_transfer_complete(self): + if not self.transfer_status: + return + + transfers_to_wait = [] + with self.lock: + transfers_to_wait = self.transfer_status[:] + self.transfer_status.clear() + + for status in transfers_to_wait: + try: + status.Wait() + if not status.Succeeded(): + logger.error( + "Transfer failed: %s, Code: %s", status.Message(), status.Code() + ) + raise TransferError("MoRIIO transfer failed!") + except Exception as e: + logger.error("Transfer %s failed: %s", status, e) + raise + + def async_wait_reqid(self): + assert self.notify_port is not None, "Notify port cannot be None" + + if self.notify_thread is not None: + return + + def _async_wait(): + host = "*" + path = make_zmq_path("tcp", host, self.notify_port) + logger.info("Node starting to listen notify from path = %s", path) + + with zmq_ctx(zmq.ROUTER, path) as sock: + while True: + try: + identity, msg = sock.recv_multipart() + self._handle_message(msg) + except Exception as e: + logger.error("Error processing message: %s", e) + raise HandshakeError(f"Error processing message: {e}") from e + + self.notify_thread = threading.Thread( + target=_async_wait, daemon=True, name="moriio-notify-listener" + ) + self.notify_thread.start() + + def _handle_message(self, msg: bytes): + """Handles incoming messages from remote nodes.""" + # Handles incoming remote messages: + # Prefill Role: + # [write] mode: receives block information (allocation) + # [read] mode: receives block release messages from decode side + # Decode Role: + # [write] mode: receives KV cache write completion notifications + handled = False + try: + data = msgpack.loads(msg) + if isinstance(data, dict) and "req_id" in data: + self._handle_structured_message(data) + + return + except (msgpack.exceptions.ExtraData, msgpack.exceptions.UnpackException): + logger.debug("Failed to decode msgpack message, will try as string") + pass + + try: + msg_str = msg.decode("UTF-8") + if msg_str.startswith(MoRIIOConstants.COMPLETION_PREFIX): + self._handle_completion_message(msg_str) + handled = True + except UnicodeDecodeError: + logger.warning("Received non-UTF8 message: %s", msg_str) + if not handled: + raise MoRIIOError(f"Unhandled message format: {msg_str}") + + def _handle_structured_message(self, data: dict): + assert get_role() == ROLE.PRODUCER, "Only prefill can get block messages" + req_id = data["req_id"] + block_notify_list = data.get("block_notify_list", []) + decode_dp_rank = data.get("decode_rank", 0) + assert len(block_notify_list) > 0, ( + "block_notify_list cannot be empty in remote allocate message" + ) + + with self.lock: + self.done_remote_allocate_req_dict[req_id] = RemoteAllocInfo( + block_ids=block_notify_list, decode_dp_rank=decode_dp_rank + ) + + def _handle_completion_message(self, msg: str): + with self.lock: + if get_role() == ROLE.PRODUCER: + self.done_req_ids.append(msg) + else: + self.done_write_cache_req_ids.append(msg) + + def send_notify(self, req_ids, remote_ip, remote_port): + if not remote_ip or not remote_port: + logger.warning("Missing remote_ip or remote_port for notification") + return + + path = make_zmq_path("tcp", remote_ip, remote_port) + + if path not in self.paths: + ctx = zmq.Context.instance() + sock = make_zmq_socket( + ctx=ctx, path=path, socket_type=zmq.DEALER, bind=False + ) + self.paths[path] = sock + + req_list = req_ids if isinstance(req_ids, list) else [req_ids] + + sock = self.paths[path] + try: + for req_id in req_list: + if not isinstance(req_id, str): + logger.warning( + "Invalid req_id type: %s, expected str", type(req_id) + ) + continue + sock.send(req_id.encode("utf-8")) + except Exception as e: + logger.error("Failed to send notification to %s: %s", path, e) + self.paths.pop(path, None) + raise + + def pop_finished_req_ids(self): + # producer invocation: get the set of completed requests at the decode + with self.lock: + done_send = set(self.done_req_ids) + self.done_req_ids = [] + return done_send + + def pop_finished_write_req_ids(self): + # Call the consumer in write mode to get the collection after write completion + with self.lock: + done_write_cache = set(self.done_write_cache_req_ids) + self.done_write_cache_req_ids = [] + return done_write_cache + + def shutdown(self): + logger.debug("Closing MoRIIOWrapper and cleaning up ZMQ sockets") + for path, sock in self.paths.items(): + try: + sock.close(linger=0) + logger.debug("Closed ZMQ socket for path: %s", path) + except Exception as e: + logger.warning("Error closing ZMQ socket for path %s: %s", path, e) + self.paths.clear() diff --git a/vllm/envs.py b/vllm/envs.py index 1d4128d74b95c..bc084a1ef1431 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -203,6 +203,10 @@ if TYPE_CHECKING: VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16: bool = True VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB: int | None = None VLLM_NIXL_ABORT_REQUEST_TIMEOUT: int = 480 + VLLM_MORIIO_CONNECTOR_READ_MODE: bool = False + VLLM_MORIIO_QP_PER_TRANSFER: int = 1 + VLLM_MORIIO_POST_BATCH_SIZE: int = -1 + VLLM_MORIIO_NUM_WORKERS: int = 1 VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT: int = 480 VLLM_USE_CUDNN_PREFILL: bool = False VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL: bool = False @@ -1379,6 +1383,20 @@ environment_variables: dict[str, Callable[[], Any]] = { "VLLM_NIXL_ABORT_REQUEST_TIMEOUT": lambda: int( os.getenv("VLLM_NIXL_ABORT_REQUEST_TIMEOUT", "480") ), + # Controls the read mode for the Mori-IO connector + "VLLM_MORIIO_CONNECTOR_READ_MODE": lambda: ( + os.getenv("VLLM_MORIIO_CONNECTOR_READ_MODE", "False").lower() in ("true", "1") + ), + # Controls the QP (Queue Pair) per transfer configuration for the Mori-IO connector + "VLLM_MORIIO_QP_PER_TRANSFER": lambda: int( + os.getenv("VLLM_MORIIO_QP_PER_TRANSFER", "1") + ), + # Controls the post-processing batch size for the Mori-IO connector + "VLLM_MORIIO_POST_BATCH_SIZE": lambda: int( + os.getenv("VLLM_MORIIO_POST_BATCH_SIZE", "-1") + ), + # Controls the number of workers for Mori operations for the Mori-IO connector + "VLLM_MORIIO_NUM_WORKERS": lambda: int(os.getenv("VLLM_MORIIO_NUM_WORKERS", "1")), # Timeout (in seconds) for MooncakeConnector in PD disaggregated setup. "VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT": lambda: int( os.getenv("VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT", "480")