mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-02 16:57:04 +08:00
Merge 8b5e2e69fbdc791265364ecb64569bcbf962bc81 into 254f6b986720c92ddf97fbb1a6a6465da8e87e29
This commit is contained in:
commit
c09ddbeeb7
@ -11,6 +11,8 @@ ARG FA_BRANCH="0e60e394"
|
||||
ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git"
|
||||
ARG AITER_BRANCH="6af8b687"
|
||||
ARG AITER_REPO="https://github.com/ROCm/aiter.git"
|
||||
ARG MORI_BRANCH="2d02c6a9"
|
||||
ARG MORI_REPO="https://github.com/ROCm/mori.git"
|
||||
|
||||
FROM ${BASE_IMAGE} AS base
|
||||
|
||||
@ -20,6 +22,7 @@ ENV LD_LIBRARY_PATH=/opt/rocm/lib:/usr/local/lib:
|
||||
ARG PYTORCH_ROCM_ARCH=gfx90a;gfx942;gfx950;gfx1100;gfx1101;gfx1200;gfx1201;gfx1150;gfx1151
|
||||
ENV PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH}
|
||||
ENV AITER_ROCM_ARCH=gfx942;gfx950
|
||||
ENV MORI_GPU_ARCHS=gfx942;gfx950
|
||||
|
||||
# Required for RCCL in ROCm7.1
|
||||
ENV HSA_NO_SCRATCH_RECLAIM=1
|
||||
@ -33,7 +36,7 @@ ENV DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
# Install Python and other dependencies
|
||||
RUN apt-get update -y \
|
||||
&& apt-get install -y software-properties-common git curl sudo vim less libgfortran5 \
|
||||
&& apt-get install -y software-properties-common git curl sudo vim less libgfortran5 libopenmpi-dev libpci-dev \
|
||||
&& for i in 1 2 3; do \
|
||||
add-apt-repository -y ppa:deadsnakes/ppa && break || \
|
||||
{ echo "Attempt $i failed, retrying in 5s..."; sleep 5; }; \
|
||||
@ -67,6 +70,18 @@ RUN cd /opt/rocm/share/amd_smi \
|
||||
&& pip wheel . --wheel-dir=dist
|
||||
RUN mkdir -p /app/install && cp /opt/rocm/share/amd_smi/dist/*.whl /app/install
|
||||
|
||||
FROM base AS build_mori
|
||||
ARG MORI_BRANCH
|
||||
ARG MORI_REPO
|
||||
RUN --mount=type=bind,from=build_pytorch,src=/app/install/,target=/install \
|
||||
pip install /install/*.whl
|
||||
RUN git clone ${MORI_REPO}
|
||||
RUN cd mori \
|
||||
&& git checkout ${MORI_BRANCH} \
|
||||
&& git submodule update --init --recursive \
|
||||
&& python3 setup.py bdist_wheel --dist-dir=dist && ls /app/mori/dist/*.whl
|
||||
RUN mkdir -p /app/install && cp /app/mori/dist/*.whl /app/install
|
||||
|
||||
FROM base AS build_pytorch
|
||||
ARG PYTORCH_BRANCH
|
||||
ARG PYTORCH_VISION_BRANCH
|
||||
@ -132,6 +147,8 @@ RUN --mount=type=bind,from=build_pytorch,src=/app/install/,target=/install \
|
||||
cp /install/*.whl /app/debs
|
||||
RUN --mount=type=bind,from=build_aiter,src=/app/install/,target=/install \
|
||||
cp /install/*.whl /app/debs
|
||||
RUN --mount=type=bind,from=build_mori,src=/app/install/,target=/install \
|
||||
cp /install/*.whl /app/debs
|
||||
|
||||
FROM base AS final
|
||||
RUN --mount=type=bind,from=debs,src=/app/debs,target=/install \
|
||||
|
||||
@ -98,9 +98,24 @@ Currently, there are no pre-built ROCm wheels.
|
||||
!!! note
|
||||
- You will need to config the `$AITER_BRANCH_OR_COMMIT` for your purpose.
|
||||
- The validated `$AITER_BRANCH_OR_COMMIT` can be found in the [docker/Dockerfile.rocm_base](https://github.com/vllm-project/vllm/blob/main/docker/Dockerfile.rocm_base).
|
||||
|
||||
|
||||
4. Build vLLM. For example, vLLM on ROCM 7.0 can be built with the following steps:
|
||||
|
||||
4. If you want to use MORI for EP or PD disaggregation, you can install [MORI](https://github.com/ROCm/mori) using the following steps:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/ROCm/mori.git
|
||||
cd mori
|
||||
git checkout $MORI_BRANCH_OR_COMMIT
|
||||
git submodule sync; git submodule update --init --recursive
|
||||
MORI_GPU_ARCHS="gfx942;gfx950" python3 install .
|
||||
```
|
||||
|
||||
!!! note
|
||||
- You will need to config the `$MORI_BRANCH_OR_COMMIT` for your purpose.
|
||||
- The validated `$MORI_BRANCH_OR_COMMIT` can be found in the [docker/Dockerfile.rocm_base](https://github.com/vllm-project/vllm/blob/main/docker/Dockerfile.rocm_base).
|
||||
|
||||
|
||||
5. Build vLLM. For example, vLLM on ROCM 7.0 can be built with the following steps:
|
||||
|
||||
???+ console "Commands"
|
||||
|
||||
|
||||
@ -0,0 +1,320 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import asyncio
|
||||
import copy
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import socket
|
||||
import threading
|
||||
import uuid
|
||||
|
||||
import aiohttp
|
||||
import msgpack
|
||||
import zmq
|
||||
from quart import Quart, make_response, request
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.DEBUG)
|
||||
prefill_instances: list[dict] = []
|
||||
decode_instances: list[dict] = []
|
||||
request_nums = 0
|
||||
app = Quart(__name__)
|
||||
|
||||
IP_PORT_PATTERN = re.compile(r"//(\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}):(\d+)")
|
||||
|
||||
|
||||
TRANSFER_TYPE = None
|
||||
|
||||
|
||||
def _append_whole_dict_unique(target_list, data_dict):
|
||||
new_filtered = {k: v for k, v in data_dict.items() if k != "index"}
|
||||
for existed in target_list:
|
||||
existed_filtered = {k: v for k, v in existed.items() if k != "index"}
|
||||
if existed_filtered == new_filtered:
|
||||
return False
|
||||
print("!!APPEND!!", data_dict)
|
||||
target_list.append(data_dict)
|
||||
transfer_mode = data_dict.get("transfer_mode", "unknown")
|
||||
global TRANSFER_TYPE
|
||||
|
||||
if TRANSFER_TYPE is None:
|
||||
TRANSFER_TYPE = transfer_mode
|
||||
logger.info("SET TRANSFER TYPE TO %s", TRANSFER_TYPE)
|
||||
elif transfer_mode != TRANSFER_TYPE:
|
||||
raise ValueError(f"mismatched transfer mode {TRANSFER_TYPE} vs {transfer_mode}")
|
||||
|
||||
return True
|
||||
|
||||
|
||||
_list_lock = threading.RLock()
|
||||
|
||||
|
||||
def _listen_for_register(hostname, port):
|
||||
context = zmq.Context()
|
||||
router_socket = context.socket(zmq.ROUTER)
|
||||
router_socket.bind(f"tcp://{hostname}:{port}")
|
||||
poller = zmq.Poller()
|
||||
poller.register(router_socket, zmq.POLLIN)
|
||||
global prefill_instances
|
||||
global decode_instances
|
||||
|
||||
while True:
|
||||
socks = dict(poller.poll())
|
||||
if router_socket in socks:
|
||||
remote_addr, msg = router_socket.recv_multipart()
|
||||
data = msgpack.loads(msg)
|
||||
if data["type"] == "HELLO":
|
||||
pass
|
||||
elif (
|
||||
data["type"] == "register"
|
||||
and data["role"] == "P"
|
||||
and data["request_address"] not in prefill_instances
|
||||
):
|
||||
with _list_lock:
|
||||
_append_whole_dict_unique(prefill_instances, data)
|
||||
|
||||
elif (
|
||||
data["type"] == "register"
|
||||
and data["role"] == "D"
|
||||
and data["request_address"] not in decode_instances
|
||||
):
|
||||
with _list_lock:
|
||||
_append_whole_dict_unique(decode_instances, data)
|
||||
|
||||
|
||||
def start_service_discovery(hostname, port):
|
||||
if not hostname:
|
||||
hostname = socket.gethostname()
|
||||
if port == 0:
|
||||
raise ValueError("Port cannot be 0")
|
||||
|
||||
_listener_thread = threading.Thread(
|
||||
target=_listen_for_register, args=(hostname, port), daemon=True
|
||||
)
|
||||
_listener_thread.start()
|
||||
return _listener_thread
|
||||
|
||||
|
||||
async def send_request_to_prefill(
|
||||
endpoint, req_data, request_id, p_endpoint, pip, pports, selected_prefill_dp_rank
|
||||
):
|
||||
req_data_copy = req_data
|
||||
|
||||
req_data_copy["kv_transfer_params"].update(
|
||||
{
|
||||
"do_remote_decode": True,
|
||||
"do_remote_prefill": False,
|
||||
"remote_handshake_port": p_endpoint["handshake_port"],
|
||||
"remote_notify_port": p_endpoint["notify_port"],
|
||||
"remote_engine_id": None,
|
||||
"remote_block_ids": None,
|
||||
"remote_host": pip,
|
||||
"remote_port": pports,
|
||||
}
|
||||
)
|
||||
req_data_copy["stream"] = False
|
||||
req_data_copy["max_tokens"] = 1
|
||||
if "max_completion_tokens" in req_data_copy:
|
||||
req_data_copy["max_completion_tokens"] = 1
|
||||
if "stream_options" in req_data_copy:
|
||||
del req_data_copy["stream_options"]
|
||||
async with aiohttp.ClientSession(
|
||||
timeout=aiohttp.ClientTimeout(total=6 * 6000 * 6000)
|
||||
) as session:
|
||||
headers = {
|
||||
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
|
||||
"X-Request-Id": request_id,
|
||||
}
|
||||
if selected_prefill_dp_rank is not None:
|
||||
headers["X-data-parallel-rank"] = str(selected_prefill_dp_rank)
|
||||
async with session.post(
|
||||
url=endpoint, json=req_data_copy, headers=headers
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
return await response.json()
|
||||
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"send_request_to_prefill response.status != 200response.status = ",
|
||||
response.status,
|
||||
)
|
||||
|
||||
|
||||
async def start_decode_request(endpoint, req_data, request_id):
|
||||
session = aiohttp.ClientSession(
|
||||
timeout=aiohttp.ClientTimeout(total=6 * 6000 * 6000)
|
||||
)
|
||||
headers = {
|
||||
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
|
||||
"X-Request-Id": request_id,
|
||||
}
|
||||
response = await session.post(url=endpoint, json=req_data, headers=headers)
|
||||
return session, response
|
||||
|
||||
|
||||
async def stream_decode_response(session, response, request_id):
|
||||
try:
|
||||
if response.status == 200:
|
||||
async for chunk_bytes in response.content.iter_chunked(1024):
|
||||
yield chunk_bytes
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"decode response.status != 200, status = {response.status}"
|
||||
)
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
|
||||
async def send_request_to_decode(endpoint, req_data, request_id):
|
||||
async with aiohttp.ClientSession(
|
||||
timeout=aiohttp.ClientTimeout(total=6 * 6000 * 6000)
|
||||
) as session:
|
||||
headers = {
|
||||
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
|
||||
"X-Request-Id": request_id,
|
||||
}
|
||||
async with session.post(
|
||||
url=endpoint, json=req_data, headers=headers
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
async for chunk_bytes in response.content.iter_chunked(1024):
|
||||
yield chunk_bytes
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"send_request_to_decode response.status != 200,response.statuus = ",
|
||||
response.status,
|
||||
)
|
||||
|
||||
|
||||
def example_round_robin_dp_loader(request_number, dp_size):
|
||||
return request_nums % dp_size
|
||||
|
||||
|
||||
@app.route("/v1/completions", methods=["POST"])
|
||||
@app.route("/v1/chat/completions", methods=["POST"])
|
||||
async def handle_request():
|
||||
try:
|
||||
with _list_lock:
|
||||
global request_nums
|
||||
request_nums += 1
|
||||
|
||||
def extract_ip_port_fast(url):
|
||||
match = IP_PORT_PATTERN.search(url)
|
||||
if not match:
|
||||
raise ValueError(f"Invalid URL format: {url}")
|
||||
return match.groups()
|
||||
|
||||
req_data = await request.get_json()
|
||||
request_id = str(uuid.uuid4())
|
||||
|
||||
prefill_instance_endpoint = None
|
||||
decode_instance_endpoint = None
|
||||
error_msg = (
|
||||
"Service Unavailable: No prefill or decode instances are registered."
|
||||
)
|
||||
if not prefill_instances or not decode_instances:
|
||||
return await make_response(
|
||||
(
|
||||
error_msg,
|
||||
503,
|
||||
)
|
||||
)
|
||||
pid = request_nums % len(prefill_instances)
|
||||
did = request_nums % len(decode_instances)
|
||||
prefill_instance_endpoint = prefill_instances[pid]
|
||||
decode_instance_endpoint = decode_instances[did]
|
||||
|
||||
selected_prefill_dp_rank = None
|
||||
if prefill_instance_endpoint["dp_size"] > 1:
|
||||
selected_prefill_dp_rank = example_round_robin_dp_loader(
|
||||
request_nums // len(prefill_instance_endpoint),
|
||||
prefill_instance_endpoint["dp_size"],
|
||||
)
|
||||
|
||||
dip, dport = extract_ip_port_fast(decode_instance_endpoint["request_address"])
|
||||
ip, port = extract_ip_port_fast(prefill_instance_endpoint["request_address"])
|
||||
|
||||
req_data_to_prefill = copy.deepcopy(req_data)
|
||||
req_data_to_prefill["kv_transfer_params"] = {}
|
||||
req_data["kv_transfer_params"] = {}
|
||||
req_data_to_prefill["kv_transfer_params"]["remote_dp_size"] = (
|
||||
decode_instance_endpoint["dp_size"]
|
||||
)
|
||||
req_data_to_prefill["kv_transfer_params"]["remote_tp_size"] = (
|
||||
decode_instance_endpoint["tp_size"]
|
||||
)
|
||||
|
||||
send_prefill_task = asyncio.create_task(
|
||||
send_request_to_prefill(
|
||||
prefill_instance_endpoint["request_address"],
|
||||
req_data_to_prefill,
|
||||
request_id,
|
||||
decode_instance_endpoint,
|
||||
dip,
|
||||
dport,
|
||||
selected_prefill_dp_rank,
|
||||
)
|
||||
)
|
||||
ip, port = extract_ip_port_fast(prefill_instance_endpoint["request_address"])
|
||||
|
||||
req_data["max_tokens"] -= 1
|
||||
|
||||
req_data["kv_transfer_params"] = {
|
||||
"do_remote_decode": False,
|
||||
"do_remote_prefill": True,
|
||||
"remote_handshake_port": prefill_instance_endpoint["handshake_port"],
|
||||
"remote_notify_port": prefill_instance_endpoint["notify_port"],
|
||||
"remote_engine_id": None,
|
||||
"remote_block_ids": None,
|
||||
"remote_host": ip,
|
||||
"remote_port": port,
|
||||
}
|
||||
if TRANSFER_TYPE == "READ":
|
||||
# In read mode, prefill and decode are executed serially.
|
||||
prefill_response = await send_prefill_task
|
||||
req_data["kv_transfer_params"]["remote_engine_id"] = prefill_response[
|
||||
"kv_transfer_params"
|
||||
]["remote_engine_id"]
|
||||
req_data["kv_transfer_params"]["remote_block_ids"] = prefill_response[
|
||||
"kv_transfer_params"
|
||||
]["remote_block_ids"]
|
||||
|
||||
req_data["kv_transfer_params"]["remote_dp_size"] = prefill_instance_endpoint[
|
||||
"dp_size"
|
||||
]
|
||||
req_data["kv_transfer_params"]["remote_tp_size"] = prefill_instance_endpoint[
|
||||
"tp_size"
|
||||
]
|
||||
|
||||
if selected_prefill_dp_rank is not None:
|
||||
req_data["kv_transfer_params"]["remote_dp_rank"] = selected_prefill_dp_rank
|
||||
|
||||
decode_request_task = asyncio.create_task(
|
||||
start_decode_request(
|
||||
decode_instance_endpoint["request_address"], req_data, request_id
|
||||
)
|
||||
)
|
||||
|
||||
session, decode_response = await decode_request_task
|
||||
stream_generator = stream_decode_response(session, decode_response, request_id)
|
||||
response = await make_response(stream_generator)
|
||||
return response
|
||||
except Exception as e:
|
||||
logger.exception("An error occurred while handling the request: %s", e)
|
||||
return await make_response(
|
||||
(
|
||||
f"Internal Server Error: {e!s}",
|
||||
500,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
t = start_service_discovery("0.0.0.0", 36367)
|
||||
app.debug = True
|
||||
app.config["BODY_TIMEOUT"] = 360000
|
||||
app.config["RESPONSE_TIMEOUT"] = 360000
|
||||
|
||||
app.run(host="0.0.0.0", port=10001)
|
||||
t.join()
|
||||
545
tests/v1/kv_connector/unit/test_moriio_connector.py
Normal file
545
tests/v1/kv_connector/unit/test_moriio_connector.py
Normal file
@ -0,0 +1,545 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import importlib.util
|
||||
import os
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import msgspec
|
||||
import pytest
|
||||
import torch
|
||||
import zmq
|
||||
|
||||
from tests.conftest import _find_free_port
|
||||
from vllm.config import (
|
||||
CacheConfig,
|
||||
DeviceConfig,
|
||||
KVTransferConfig,
|
||||
ModelConfig,
|
||||
SchedulerConfig,
|
||||
VllmConfig,
|
||||
)
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_common import (
|
||||
MoRIIOAgentMetadata,
|
||||
MoRIIOConnectorMetadata,
|
||||
MoRIIOConstants,
|
||||
zmq_ctx,
|
||||
)
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector import (
|
||||
KVConnectorRole,
|
||||
MoRIIOConnector,
|
||||
MoRIIOConnectorWorker,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.network_utils import (
|
||||
get_ip,
|
||||
make_zmq_path,
|
||||
)
|
||||
|
||||
from .utils import create_request, create_scheduler
|
||||
|
||||
aiter_available = importlib.util.find_spec("aiter") is not None
|
||||
mori_available = importlib.util.find_spec("mori") is not None
|
||||
pytestmark = pytest.mark.skipif(
|
||||
not (current_platform.is_rocm() and mori_available),
|
||||
reason="MoRIIOs are only available on ROCm with aiter package installed",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_parallel_groups():
|
||||
"""Mock tensor/data parallel group functions for single-rank tests."""
|
||||
mock_group = MagicMock()
|
||||
mock_group.rank = 0
|
||||
mock_group.local_rank = 0
|
||||
mock_group.world_size = 1
|
||||
|
||||
with (
|
||||
patch.multiple(
|
||||
"vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_common",
|
||||
get_tensor_model_parallel_rank=MagicMock(return_value=0),
|
||||
get_tensor_model_parallel_world_size=MagicMock(return_value=0),
|
||||
),
|
||||
patch.multiple(
|
||||
"vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector",
|
||||
get_tensor_model_parallel_world_size=MagicMock(return_value=0),
|
||||
get_world_group=MagicMock(return_value=mock_group),
|
||||
get_tp_group=MagicMock(return_value=mock_group),
|
||||
),
|
||||
):
|
||||
yield mock_group
|
||||
|
||||
|
||||
def _setup_kv_transfer_request(request, remote_host="127.0.0.1", fake_port=4789):
|
||||
"""Setup KV transfer parameters for a request."""
|
||||
request.kv_transfer_params.update(
|
||||
{
|
||||
"remote_notify_port": fake_port,
|
||||
"remote_block_ids": None,
|
||||
"remote_host": remote_host,
|
||||
"remote_port": fake_port,
|
||||
"remote_handshake_port": fake_port,
|
||||
"remote_engine_id": "test_engine",
|
||||
}
|
||||
)
|
||||
return request
|
||||
|
||||
|
||||
class FakeMorIIOWrapper:
|
||||
# A fake MoRIIOWrapper for testing purposes
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def set_moriio_engine(self, moriio_engine):
|
||||
pass
|
||||
|
||||
def set_backend_type(self, backend_type):
|
||||
pass
|
||||
|
||||
def get_agent_metadata(self):
|
||||
pass
|
||||
|
||||
def register_remote_engine(self, remote_packed_engine_metadata):
|
||||
pass
|
||||
|
||||
def register_local_tensor(self, tensor: torch.Tensor):
|
||||
pass
|
||||
|
||||
def get_unpack_memory_metadata(self, packed_memory_metadata):
|
||||
pass
|
||||
|
||||
def build_session(self, local_memory_metadata, remote_memory_metadata):
|
||||
pass
|
||||
|
||||
def read_remote_data(
|
||||
self, transfer_size_byte, local_offset=0, remote_offset=0, session=None
|
||||
):
|
||||
pass
|
||||
|
||||
def write_remote_data(
|
||||
self, transfer_size_byte, local_offset=0, remote_offset=0, session=None
|
||||
):
|
||||
pass
|
||||
|
||||
def write_remote_data_single(
|
||||
self, transfer_size_byte, local_offset=0, remote_offset=0, sess_idx=0
|
||||
):
|
||||
pass
|
||||
|
||||
def waiting_for_transfer_complete(self):
|
||||
pass
|
||||
|
||||
def async_wait_reqid(self):
|
||||
pass
|
||||
|
||||
def _handle_message(self, msg: bytes):
|
||||
pass
|
||||
|
||||
def _handle_structured_message(self, data: dict):
|
||||
pass
|
||||
|
||||
def _handle_completion_message(self, msg: str):
|
||||
pass
|
||||
|
||||
def send_notify(self, req_ids, remote_ip, remote_port):
|
||||
pass
|
||||
|
||||
def pop_finished_req_ids(self):
|
||||
pass
|
||||
|
||||
def pop_finished_write_req_ids(self):
|
||||
pass
|
||||
|
||||
def shutdown(self):
|
||||
pass
|
||||
|
||||
|
||||
class FakeMorIIOConnectorWorker(MoRIIOConnectorWorker):
|
||||
# Define a fake remote engine id for testing
|
||||
REMOTE_ENGINE_ID = "remote_engine"
|
||||
|
||||
def __init__(
|
||||
self, *args, hand_shake_latency: float = 1.8, kv_cache_layout="HND", **kwargs
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
def create_vllm_config(
|
||||
model: str = "facebook/opt-125m",
|
||||
max_num_seqs: int = 16,
|
||||
max_num_batched_tokens: int = 64,
|
||||
block_size: int = 16,
|
||||
max_model_len: int = 10000,
|
||||
enable_chunked_prefill: bool = True,
|
||||
enable_permute_local_kv: bool = False,
|
||||
role="kv_consumer",
|
||||
) -> VllmConfig:
|
||||
"""Initialize VllmConfig for testing."""
|
||||
scheduler_config = SchedulerConfig(
|
||||
max_num_seqs=max_num_seqs,
|
||||
max_num_batched_tokens=max_num_batched_tokens,
|
||||
max_model_len=max_model_len,
|
||||
enable_chunked_prefill=enable_chunked_prefill,
|
||||
is_encoder_decoder=False,
|
||||
)
|
||||
model_config = ModelConfig(
|
||||
model=model,
|
||||
trust_remote_code=True,
|
||||
dtype="bfloat16",
|
||||
seed=42,
|
||||
)
|
||||
# Cache config, optionally force APC
|
||||
cache_config = CacheConfig(
|
||||
block_size=block_size,
|
||||
gpu_memory_utilization=0.9,
|
||||
swap_space=0,
|
||||
cache_dtype="auto",
|
||||
enable_prefix_caching=True,
|
||||
)
|
||||
kv_transfer_config = KVTransferConfig(
|
||||
kv_connector="MoRIIOConnector",
|
||||
kv_role=role,
|
||||
enable_permute_local_kv=enable_permute_local_kv,
|
||||
)
|
||||
return VllmConfig(
|
||||
scheduler_config=scheduler_config,
|
||||
model_config=model_config,
|
||||
cache_config=cache_config,
|
||||
kv_transfer_config=kv_transfer_config,
|
||||
device_config=DeviceConfig("cpu"),
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def moriio_read_mode():
|
||||
"""Force the connector into read mode via env for tests."""
|
||||
os.environ["VLLM_MORIIO_CONNECTOR_READ_MODE"] = "True"
|
||||
yield
|
||||
# Cleanup after test
|
||||
os.environ.pop("VLLM_MORIIO_CONNECTOR_READ_MODE", None)
|
||||
|
||||
|
||||
def test_write_mode_saves_local_block_ids():
|
||||
"""Write mode records local block ids in MoRIIOConnectorMetadata.reqs_to_save."""
|
||||
|
||||
# Setup Scheduler and Request
|
||||
vllm_config = create_vllm_config(role="kv_producer")
|
||||
scheduler = create_scheduler(vllm_config)
|
||||
|
||||
# 2 Full Blocks and 1 Half Block.
|
||||
BLOCK_SIZE = vllm_config.cache_config.block_size
|
||||
NUM_EXTERNAL_FULL_BLOCKS = 2
|
||||
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
|
||||
|
||||
request = create_request(
|
||||
request_id=1,
|
||||
block_size=BLOCK_SIZE,
|
||||
num_tokens=NUM_TOKENS,
|
||||
do_remote_decode=True,
|
||||
do_remote_prefill=False,
|
||||
)
|
||||
request_id = request.request_id
|
||||
|
||||
scheduler.add_request(request)
|
||||
|
||||
# Fake Config
|
||||
request = _setup_kv_transfer_request(request)
|
||||
|
||||
# Remote Prefill, triggers MoRIIOConnectorMetadata.
|
||||
scheduler_output = scheduler.schedule()
|
||||
kv_connector_metadata = scheduler_output.kv_connector_metadata
|
||||
assert kv_connector_metadata is not None, "kv_connector_metadata is None"
|
||||
assert isinstance(kv_connector_metadata, MoRIIOConnectorMetadata)
|
||||
|
||||
assert len(kv_connector_metadata.reqs_to_save) == 1, (
|
||||
"Unexpected number of reqs_to_save"
|
||||
)
|
||||
assert len(kv_connector_metadata.reqs_to_recv) == 0, (
|
||||
"Unexpected number of reqs_to_recv"
|
||||
)
|
||||
assert len(kv_connector_metadata.reqs_to_send) == 0, (
|
||||
"Unexpected number of reqs_to_send"
|
||||
)
|
||||
assert request_id in kv_connector_metadata.reqs_to_save, (
|
||||
"Request ID not in reqs_to_save"
|
||||
)
|
||||
req_meta = kv_connector_metadata.reqs_to_save[request_id]
|
||||
|
||||
for block_id, block in zip(
|
||||
req_meta.local_block_ids,
|
||||
scheduler.kv_cache_manager.coordinator.single_type_managers[0].req_to_blocks[
|
||||
request_id
|
||||
],
|
||||
):
|
||||
assert block_id == block.block_id, f"{block_id} != {block.block_id}"
|
||||
|
||||
|
||||
def test_write_mode_with_chunked_prefill_saves_local_block_ids():
|
||||
"""Write mode with chunked prefill still records correct local block ids."""
|
||||
# Setup Scheduler and Request
|
||||
MAX_NUM_BATCHED_TOKENS = 64
|
||||
NUM_TOKENS = MAX_NUM_BATCHED_TOKENS * 2 + MAX_NUM_BATCHED_TOKENS // 2
|
||||
|
||||
vllm_config = create_vllm_config(
|
||||
max_num_batched_tokens=MAX_NUM_BATCHED_TOKENS, role="kv_producer"
|
||||
)
|
||||
BLOCK_SIZE = vllm_config.cache_config.block_size
|
||||
|
||||
scheduler = create_scheduler(vllm_config)
|
||||
|
||||
# 2 Full Blocks and 1 Half Block.
|
||||
|
||||
request = create_request(
|
||||
request_id=1,
|
||||
block_size=BLOCK_SIZE,
|
||||
num_tokens=NUM_TOKENS,
|
||||
do_remote_decode=True,
|
||||
do_remote_prefill=False,
|
||||
)
|
||||
request_id = request.request_id
|
||||
|
||||
scheduler.add_request(request)
|
||||
|
||||
# Fake Config
|
||||
request = _setup_kv_transfer_request(request)
|
||||
|
||||
# Remote Prefill with chunked prefill, triggers multiple schedules.
|
||||
expected_counts = [(0, 0, 0), (0, 0, 0), (1, 0, 0)]
|
||||
kv_connector_metadata = None
|
||||
for _, (expected_save, expected_recv, expected_send) in enumerate(expected_counts):
|
||||
scheduler_output = scheduler.schedule()
|
||||
kv_connector_metadata = scheduler_output.kv_connector_metadata
|
||||
|
||||
assert len(kv_connector_metadata.reqs_to_save) == expected_save
|
||||
assert len(kv_connector_metadata.reqs_to_recv) == expected_recv
|
||||
assert len(kv_connector_metadata.reqs_to_send) == expected_send
|
||||
assert kv_connector_metadata is not None, "kv_connector_metadata is None"
|
||||
assert request_id in kv_connector_metadata.reqs_to_save, (
|
||||
"Request ID not in reqs_to_save"
|
||||
)
|
||||
req_meta = kv_connector_metadata.reqs_to_save[request_id]
|
||||
|
||||
for block_id, block in zip(
|
||||
req_meta.local_block_ids,
|
||||
scheduler.kv_cache_manager.coordinator.single_type_managers[0].req_to_blocks[
|
||||
request_id
|
||||
],
|
||||
):
|
||||
assert block_id == block.block_id, f"{block_id} != {block.block_id}"
|
||||
|
||||
|
||||
def test_read_mode_loads_remote_block_ids(moriio_read_mode):
|
||||
"""Read mode loads remote block ids into local cache mapping."""
|
||||
|
||||
# Setup Scheduler and Request
|
||||
vllm_config = create_vllm_config(role="kv_consumer")
|
||||
scheduler = create_scheduler(vllm_config)
|
||||
|
||||
# 2 Full Blocks and 1 Half Block.
|
||||
BLOCK_SIZE = vllm_config.cache_config.block_size
|
||||
NUM_EXTERNAL_FULL_BLOCKS = 2
|
||||
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
|
||||
|
||||
request = create_request(
|
||||
request_id=1,
|
||||
block_size=BLOCK_SIZE,
|
||||
num_tokens=NUM_TOKENS,
|
||||
do_remote_decode=False,
|
||||
do_remote_prefill=True,
|
||||
)
|
||||
request_id = request.request_id
|
||||
|
||||
scheduler.add_request(request)
|
||||
block_list = scheduler.kv_cache_manager.coordinator.single_type_managers[
|
||||
0
|
||||
].req_to_blocks[request_id]
|
||||
|
||||
request = _setup_kv_transfer_request(request)
|
||||
|
||||
# Set remote block ids to be fetched.
|
||||
request.kv_transfer_params["remote_block_ids"] = block_list
|
||||
|
||||
# Remote Prefill, triggers MorIIOConnectorMetadata.
|
||||
|
||||
scheduler_output = scheduler.schedule()
|
||||
kv_connector_metadata = scheduler_output.kv_connector_metadata
|
||||
assert kv_connector_metadata is not None, "kv_connector_metadata is None"
|
||||
assert isinstance(kv_connector_metadata, MoRIIOConnectorMetadata), (
|
||||
"kv_connector_metadata is not MoRIIOConnectorMetadata"
|
||||
)
|
||||
assert len(kv_connector_metadata.reqs_to_save) == 0, (
|
||||
"Unexpected number of reqs_to_save"
|
||||
)
|
||||
assert len(kv_connector_metadata.reqs_to_recv) == 1, (
|
||||
"Unexpected number of reqs_to_recv"
|
||||
)
|
||||
assert len(kv_connector_metadata.reqs_to_send) == 0, (
|
||||
"Unexpected number of reqs_to_send"
|
||||
)
|
||||
assert request_id in kv_connector_metadata.reqs_to_recv, (
|
||||
"Request ID not in reqs_to_recv"
|
||||
)
|
||||
req_meta = kv_connector_metadata.reqs_to_recv[request_id]
|
||||
|
||||
for block_id, block in zip(
|
||||
req_meta.local_block_ids,
|
||||
scheduler.kv_cache_manager.coordinator.single_type_managers[0].req_to_blocks[
|
||||
request_id
|
||||
],
|
||||
):
|
||||
assert block_id == block.block_id, f"{block_id} != {block.block_id}"
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not aiter_available, reason="Requires aiter package for ROCm FlashAttention backend"
|
||||
)
|
||||
def test_register_kv_caches(mock_parallel_groups):
|
||||
"""Test that MoRIIOConnector.register_kv_caches correctly registers kv caches."""
|
||||
ROLE = "kv_consumer"
|
||||
IP = get_ip()
|
||||
vllm_config = create_vllm_config(role=ROLE)
|
||||
DEFAULT_PORT = 6301
|
||||
TP_RANK = 0
|
||||
DP_RANK = 0
|
||||
from vllm.v1.attention.backends.rocm_aiter_fa import AiterFlashAttentionBackend
|
||||
|
||||
backend_cls = AiterFlashAttentionBackend
|
||||
|
||||
# Create test kv cache tensors using proper backend shape
|
||||
kv_cache_shape = backend_cls.get_kv_cache_shape(
|
||||
num_blocks=2, block_size=16, num_kv_heads=4, head_size=64
|
||||
)
|
||||
shared_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
|
||||
unique_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
|
||||
kv_caches = {
|
||||
"layer0": shared_tensor,
|
||||
"layer1": unique_tensor,
|
||||
"layer2": shared_tensor,
|
||||
}
|
||||
|
||||
with (
|
||||
patch(
|
||||
"vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector.threading.Event"
|
||||
),
|
||||
patch(
|
||||
"vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector.threading.Thread"
|
||||
),
|
||||
):
|
||||
# Create connector
|
||||
vllm_config.kv_transfer_config.kv_connector_extra_config.update(
|
||||
{
|
||||
"proxy_ip": "127.0.0.1",
|
||||
"proxy_ping_port": 12345,
|
||||
"http_port": 12346,
|
||||
}
|
||||
)
|
||||
|
||||
connector = MoRIIOConnector(vllm_config, KVConnectorRole.WORKER)
|
||||
connector.connector_worker = FakeMorIIOConnectorWorker(
|
||||
vllm_config, connector.engine_id, hand_shake_latency=0
|
||||
)
|
||||
|
||||
from mori.io import (
|
||||
MemoryDesc,
|
||||
)
|
||||
|
||||
# Execute register_kv_caches
|
||||
connector.register_kv_caches(kv_caches)
|
||||
|
||||
# Verify that the MemoryDesc stored in layer_name_to_local_kv_cache_metadata
|
||||
assert (
|
||||
shared_tensor.data_ptr()
|
||||
== MemoryDesc.unpack(
|
||||
connector.connector_worker.layer_name_to_local_kv_cache_metadata[
|
||||
"layer0"
|
||||
][0]
|
||||
).data
|
||||
)
|
||||
assert (
|
||||
unique_tensor.data_ptr()
|
||||
== MemoryDesc.unpack(
|
||||
connector.connector_worker.layer_name_to_local_kv_cache_metadata[
|
||||
"layer1"
|
||||
][0]
|
||||
).data
|
||||
)
|
||||
assert (
|
||||
shared_tensor.data_ptr()
|
||||
== MemoryDesc.unpack(
|
||||
connector.connector_worker.layer_name_to_local_kv_cache_metadata[
|
||||
"layer2"
|
||||
][0]
|
||||
).data
|
||||
)
|
||||
|
||||
# Verify engine keys
|
||||
expected_engine_key = f"{ROLE[3:]}:{IP}:{DEFAULT_PORT}:tp{TP_RANK}:dp{DP_RANK}"
|
||||
assert (
|
||||
MemoryDesc.unpack(
|
||||
connector.connector_worker.layer_name_to_local_kv_cache_metadata[
|
||||
"layer0"
|
||||
][0]
|
||||
).engine_key
|
||||
== expected_engine_key
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    not aiter_available, reason="Requires aiter package for ROCm FlashAttention backend"
)
def test_moriio_handshake_returns_metadata(mock_parallel_groups):
    """MoRIIO handshake socket returns valid agent metadata over ZMQ."""

    ROLE = "kv_consumer"
    vllm_config = create_vllm_config(role=ROLE)
    from vllm.v1.attention.backends.rocm_aiter_fa import AiterFlashAttentionBackend

    backend_cls = AiterFlashAttentionBackend

    # Create test kv cache tensors using proper backend shape
    kv_cache_shape = backend_cls.get_kv_cache_shape(
        num_blocks=2, block_size=16, num_kv_heads=4, head_size=64
    )
    shared_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
    unique_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
    # "layer0" and "layer2" deliberately share one tensor so registration of
    # duplicate storage is exercised alongside a unique tensor ("layer1").
    kv_caches = {
        "layer0": shared_tensor,
        "layer1": unique_tensor,
        "layer2": shared_tensor,
    }

    with (
        patch(
            "vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_engine.MoRIIOWrapper",
            FakeMorIIOWrapper,
        ),
    ):
        handshake_port = _find_free_port()
        # Create connector
        vllm_config.kv_transfer_config.kv_connector_extra_config.update(
            {
                "proxy_ip": "127.0.0.1",
                "proxy_ping_port": 12345,
                "http_port": 12346,
                "handshake_port": handshake_port,
            }
        )
        connector = MoRIIOConnector(vllm_config, KVConnectorRole.WORKER)

        # Execute register_kv_caches; this should start the handshake
        # listener on handshake_port as a side effect.
        connector.register_kv_caches(kv_caches)

        # Connect to handshake socket and request metadata
        path = make_zmq_path("tcp", "127.0.0.1", handshake_port)
        with zmq_ctx(zmq.DEALER, path) as sock:
            sock.send(MoRIIOConstants.GET_META_MSG)
            received_frame = sock.recv_multipart()

            # ROUTER replies arrive as [empty delimiter, payload].
            if len(received_frame) != 2 or received_frame[0] != b"":
                raise ValueError(f"Unexpected frame! {received_frame = }")

            metadata_bytes = received_frame[1]
            decoder = msgspec.msgpack.Decoder(MoRIIOAgentMetadata)
            metadata = decoder.decode(metadata_bytes)
            assert isinstance(metadata, MoRIIOAgentMetadata), (
                "Decoded metadata is not MoRIIOAgentMetadata"
            )
|
||||
@ -179,6 +179,12 @@ KVConnectorFactory.register_connector(
|
||||
"MultiConnector",
|
||||
)
|
||||
|
||||
KVConnectorFactory.register_connector(
|
||||
"MoRIIOConnector",
|
||||
"vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector",
|
||||
"MoRIIOConnector",
|
||||
)
|
||||
|
||||
KVConnectorFactory.register_connector(
|
||||
"OffloadingConnector",
|
||||
"vllm.distributed.kv_transfer.kv_connector.v1.offloading_connector",
|
||||
|
||||
@ -0,0 +1,321 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import contextlib
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Iterator
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
import msgspec
|
||||
import torch
|
||||
import zmq
|
||||
|
||||
from vllm import envs
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||
KVConnectorMetadata,
|
||||
)
|
||||
from vllm.distributed.parallel_state import (
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.network_utils import (
|
||||
get_ip,
|
||||
get_open_port,
|
||||
make_zmq_socket,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
from dataclasses import field
|
||||
from enum import Enum
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
Transfer = tuple[int, float]
|
||||
EngineId = str
|
||||
ReqId = str
|
||||
|
||||
|
||||
@dataclass
class WriteTask:
    """A single pending KV-cache write to a remote (decode) instance.

    Produced by the connector worker and consumed by the MoRIIO write
    worker thread once the remote side has allocated its blocks.
    """

    request_id: str  # vLLM request this transfer belongs to
    dst_engine_id: str  # target MoRIIO engine identifier
    local_block_ids: list[int]  # source KV blocks on this rank
    remote_block_ids_hint: list[int] | None  # remote blocks, if already known
    layer_name: str  # attention layer whose KV cache is written
    # Recorded after this layer's attention; synchronized before the
    # transfer so compute and transfer do not overlap.
    event: torch.cuda.Event
    remote_notify_port: int  # base port for completion notification
    remote_ip: str  # remote host address
    # Time the task entered the queue (perf_counter seconds).
    enqueue_time: float = field(default_factory=time.perf_counter)
    # Deferral/retry counter. NOTE(review): not incremented in the visible
    # write-loop code path — confirm whether it is still needed.
    retried: int = 0
|
||||
|
||||
|
||||
@dataclass
class LayerTransferPlan:
    """Plan for transferring a single layer."""

    request_id: str
    layer_name: str
    # Index of the per-layer transfer session to use.
    sess_idx: int
    # Per-block byte offsets into the local KV cache tensor.
    transfer_local_offsets: list[int]
    # Matching byte offsets into the remote KV cache tensor.
    transfer_remote_offsets: list[int]
    # Byte size of each block transfer (parallel to the offset lists).
    transfer_sizes: list[int]
    # True: issue one batched RDMA op; False: one op per block.
    use_batch: bool = True
|
||||
|
||||
|
||||
@dataclass
class RemoteAllocInfo:
    """Information about remote block allocation."""

    # Remote KV block IDs allocated by the decode instance.
    block_ids: list[int]
    # Number of layers written so far; compared against num_layers to
    # detect completion.
    writes_done: int = 0
    # Data-parallel rank of the decode instance that allocated the blocks.
    decode_dp_rank: int = 0
    # Cached (local_offsets, remote_offsets, sizes) for the transfer,
    # computed once on first use.
    transfer_offset: tuple[list[int], list[int], list[int]] | None = None
|
||||
|
||||
|
||||
class ROLE(Enum):
    """Role of this instance in the KV transfer: prefill side (producer),
    decode side (consumer), or not yet initialized."""

    PRODUCER = "producer"
    CONSUMER = "consumer"
    NOTINIT = "notinit"
|
||||
|
||||
|
||||
class MoRIIOAgentMetadata(
    msgspec.Struct,
    omit_defaults=True,  # type: ignore[call-arg]
    # required for @cached_property
    dict=True,
):
    """Agent metadata exchanged during the MoRIIO handshake (msgpack-encoded)."""

    engine_id: str  # unique ID of the remote engine
    agent_metadata: bytes  # packed mori EngineDesc for remote registration
    kv_caches_base_addr: list[int]  # base address of each registered KV tensor
    num_blocks: int  # number of KV blocks per layer
    block_len: int  # byte length of one KV block
    attn_backend_name: str  # attention backend; must match across peers
|
||||
|
||||
|
||||
class RoleManager:
    """Process-wide singleton tracking the connector's current ROLE."""

    _instance: Optional["RoleManager"] = None
    _lock = threading.Lock()

    def __init__(self) -> None:
        # Starts out uninitialized until set_role() is called.
        self._role: ROLE = ROLE.NOTINIT

    @classmethod
    def get_instance(cls) -> "RoleManager":
        """Return the shared instance, creating it lazily (double-checked)."""
        if cls._instance is not None:
            return cls._instance
        with cls._lock:
            if cls._instance is None:
                cls._instance = cls()
        return cls._instance

    def set_role(self, role: ROLE) -> None:
        """Set the current role."""
        with self._lock:
            self._role = role

    def get_role(self) -> ROLE:
        """Get the current role."""
        return self._role
|
||||
|
||||
|
||||
def set_role(role: ROLE):
    """Set the global role."""
    manager = RoleManager.get_instance()
    manager.set_role(role)
|
||||
|
||||
|
||||
def get_role() -> ROLE:
    """Get the global role."""
    manager = RoleManager.get_instance()
    return manager.get_role()
|
||||
|
||||
|
||||
class MoRIIOMode(Enum):
    """Transfer mode: decode pulls (READ) or prefill pushes (WRITE)."""

    READ = "read"
    WRITE = "write"
|
||||
|
||||
|
||||
class MoRIIOError(Exception):
    """Base exception for MoRIIO operations."""

    pass
|
||||
|
||||
|
||||
class HandshakeError(MoRIIOError):
    """Exception raised when handshake fails."""

    pass
|
||||
|
||||
|
||||
class TransferError(MoRIIOError):
    """Exception raised when transfer fails."""

    pass
|
||||
|
||||
|
||||
def get_moriio_mode() -> MoRIIOMode:
    """Return the configured transfer mode, selected via environment."""
    read_mode = envs.VLLM_MORIIO_CONNECTOR_READ_MODE
    logger.debug("MoRIIO Connector read_mode: %s", read_mode)
    return MoRIIOMode.READ if read_mode else MoRIIOMode.WRITE
|
||||
|
||||
|
||||
def get_port_offset(dp_rank: int, tp_rank: int, tp_size: int = 1) -> int:
    """Flatten (dp_rank, tp_rank) into a per-rank port offset."""
    base = dp_rank * tp_size
    return base + tp_rank
|
||||
|
||||
|
||||
@dataclass
class MoRIIOConfig:
    """Resolved network and parallelism settings for the MoRIIO connector."""

    local_ip: str
    local_kv_port: int  # service port for mori engine
    proxy_ip: str
    local_ping_port: int  # outgoing heartbeat to proxy
    proxy_ping_port: int  # remote proxy's heartbeat ingress port
    http_port: int  # instance's HTTP service endpoint
    handshake_port: int  # initial handshake between mori engines
    notify_port: int  # stage synchronization between prefill and decode
    tp_rank: int
    dp_rank: int
    dp_size: int
    tp_size: int

    @classmethod
    def from_vllm_config(cls, vllm_config: VllmConfig) -> "MoRIIOConfig":
        """Build the connector configuration from a VllmConfig.

        Reads the kv_connector_extra_config dict (proxy_ip, proxy_ping_port,
        http_port, handshake_port, notify_port) and the parallel layout,
        allocating free local ports for the KV service and proxy heartbeat.

        Raises:
            AssertionError: if kv_transfer_config is unset.
            KeyError: if a required extra-config key is missing.
        """
        # Port Configuration:
        # local_ping_port -> Outgoing heartbeat to proxy
        # proxy_ping_port -> Remote proxy's heartbeat ingress port
        # http_port -> Instance's HTTP service endpoint
        # local_kv_port -> service port for mori engine
        # notify_port -> For synchronizing stages between prefill and decode
        # handshake_port -> For initial handshake between mori engine

        # TODO : merge notify_port and handshake_port to simplify port management
        # supports non-contiguous ports
        assert vllm_config.kv_transfer_config is not None, (
            "kv_transfer_config must be set for MoRIIOConnector"
        )
        kv_transfer_config = vllm_config.kv_transfer_config
        extra_config = kv_transfer_config.kv_connector_extra_config
        tp_rank = get_tensor_model_parallel_rank()
        dp_rank = vllm_config.parallel_config.data_parallel_rank
        base_notify_port = int(extra_config["notify_port"])
        dp_size = vllm_config.parallel_config.data_parallel_size
        tp_size = get_tensor_model_parallel_world_size()
        # NOTE(review): tp_size is not passed here, so the offset is
        # dp_rank + tp_rank; with tp_size > 1 distinct (dp, tp) pairs can
        # collide on the same notify port — confirm this is intentional
        # (the write path computes the matching offset the same way).
        port_offset = get_port_offset(dp_rank, tp_rank)

        return cls(
            local_ip=get_ip(),
            local_kv_port=get_open_port(),
            proxy_ip=extra_config["proxy_ip"],
            local_ping_port=get_open_port(),
            proxy_ping_port=int(extra_config["proxy_ping_port"]),
            http_port=int(extra_config["http_port"]),
            handshake_port=int(extra_config["handshake_port"]),
            notify_port=base_notify_port + port_offset,
            tp_rank=tp_rank,
            dp_rank=dp_rank,
            dp_size=dp_size,
            tp_size=tp_size,
        )
|
||||
|
||||
|
||||
class MoRIIOConstants:
    """Constants for MoRIIO connector."""

    # ZMQ message types
    GET_META_MSG = b"get_meta_msg"  # request agent metadata during handshake
    POP_DONE_RECV = b"pop_done_recv"  # poll completed receives
    OVER = b"OVER"  # end-of-stream marker
    # Prefix identifying completion notifications (request-id messages).
    COMPLETION_PREFIX = "cmpl"

    # Proxy heartbeat cadence (seconds) and retry budget.
    PING_INTERVAL = 5
    MAX_PING_RETRIES = 100
    # Fallback ports when not provided in kv_connector_extra_config.
    DEFAULT_HANDSHAKE_PORT = "6301"
    DEFAULT_NOTIFY_PORT = "61005"

    # Seconds before an un-fetched read-mode request is considered aborted.
    VLLM_MORI_READ_ABORT_REQUEST_TIMEOUT = 3600
|
||||
|
||||
|
||||
@dataclass
class ReqMeta:
    """Metadata for a single request."""

    local_block_ids: list[int]  # KV blocks on this instance
    remote_block_ids: list[int]  # matching KV blocks on the peer
    remote_host: str  # peer IP address
    remote_port: int  # peer mori engine service port
    remote_handshake_port: int  # peer handshake listener port
    remote_notify_port: int  # peer notification listener port
    remote_engine_id: str  # peer engine identifier
    tp_size: int  # peer tensor-parallel size
    remote_dp_size: int  # peer data-parallel size
|
||||
|
||||
|
||||
class MoRIIOConnectorMetadata(KVConnectorMetadata):
    """Scheduler-to-worker metadata: requests to receive, save, and send."""

    def __init__(self):
        self.reqs_to_recv: dict[ReqId, ReqMeta] = {}
        self.reqs_to_save: dict[ReqId, ReqMeta] = {}
        self.reqs_to_send: dict[ReqId, float] = {}

    def __repr__(self):
        recv_str = "".join(
            f"{req_id = },{req_meta.local_block_ids = },"
            f"{req_meta.remote_host = },{req_meta.remote_port = }"
            f"{req_meta.remote_engine_id = },{req_meta.tp_size = }"
            for req_id, req_meta in self.reqs_to_recv.items()
        )
        summary = f"MoRIIOConnectorMetadata:reqs_to_recv:{recv_str},"
        summary += "".join(
            f"{req_id = },{expiry = }"
            for req_id, expiry in self.reqs_to_send.items()
        )
        return f"MoRIIOConnectorMetadata:reqs_to_send:{summary},"

    def add_new_req(
        self,
        request_id: ReqId,
        local_block_ids: list[int],
        kv_transfer_params: dict[str, Any],
        write_mode=False,
    ):
        """Record a request for transfer; write_mode picks the save queue."""
        meta = ReqMeta(
            local_block_ids=local_block_ids,
            remote_block_ids=kv_transfer_params["remote_block_ids"],
            remote_engine_id=kv_transfer_params["remote_engine_id"],
            remote_host=kv_transfer_params["remote_host"],
            remote_port=kv_transfer_params["remote_port"],
            remote_handshake_port=kv_transfer_params["remote_handshake_port"],
            remote_notify_port=kv_transfer_params["remote_notify_port"],
            tp_size=kv_transfer_params.get("tp_size", 1),
            remote_dp_size=kv_transfer_params.get("remote_dp_size", 1),
        )
        target = self.reqs_to_save if write_mode else self.reqs_to_recv
        target[request_id] = meta
|
||||
|
||||
|
||||
@contextlib.contextmanager
def zmq_ctx(socket_type: Any, addr: str) -> Iterator[zmq.Socket]:
    """Context manager for a ZMQ socket"""

    allowed_types = (zmq.ROUTER, zmq.REQ, zmq.DEALER)
    if socket_type not in allowed_types:
        raise ValueError(f"Unexpected socket type: {socket_type}")

    ctx: zmq.Context | None = None
    try:
        ctx = zmq.Context()  # type: ignore[attr-defined]
        # Only ROUTER sockets bind; REQ/DEALER connect out.
        should_bind = socket_type == zmq.ROUTER
        socket = make_zmq_socket(
            ctx=ctx, path=addr, socket_type=socket_type, bind=should_bind
        )
        yield socket
    finally:
        if ctx is not None:
            ctx.destroy(linger=0)
|
||||
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,609 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import threading
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
from weakref import ref as weakref_ref
|
||||
|
||||
import msgpack
|
||||
import torch
|
||||
import zmq
|
||||
|
||||
from vllm import envs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.network_utils import (
|
||||
make_zmq_path,
|
||||
make_zmq_socket,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
from queue import Empty, Queue
|
||||
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_common import (
|
||||
ROLE,
|
||||
HandshakeError,
|
||||
LayerTransferPlan,
|
||||
MoRIIOAgentMetadata,
|
||||
MoRIIOConstants,
|
||||
MoRIIOError,
|
||||
RemoteAllocInfo,
|
||||
TransferError,
|
||||
WriteTask,
|
||||
get_port_offset,
|
||||
get_role,
|
||||
zmq_ctx,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector import (
|
||||
MoRIIOConnectorWorker,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
try:
|
||||
from mori.io import (
|
||||
EngineDesc,
|
||||
IOEngine,
|
||||
MemoryDesc,
|
||||
PollCqMode,
|
||||
RdmaBackendConfig,
|
||||
)
|
||||
|
||||
logger.info("MoRIIO is available")
|
||||
except ImportError:
|
||||
logger.error("MoRIIO is not available")
|
||||
|
||||
|
||||
"""Write task execution logic for MoRIIO connector."""
|
||||
|
||||
|
||||
class MoRIIOWriter:
    """Handles write operations for KV cache transfers.

    Implements distributed KV cache transfer using the MoRIIO library
    for RDMA-based communication between prefill and decode instances."""

    def __init__(self, worker: "MoRIIOConnectorWorker"):
        """Initialize the writer.

        Args:
            worker: Reference to the parent worker
        """
        # Weak reference: the worker owns this writer, so a strong
        # reference would create a cycle.
        self._worker_ref: weakref_ref[MoRIIOConnectorWorker] = weakref_ref(worker)
        self._write_task_q: Queue[WriteTask] = Queue()
        self._write_worker_started = False
        self._write_worker_lock = threading.Lock()
        # Tasks whose remote blocks were not yet allocated when dequeued.
        self._deferred_tasks: list[WriteTask] = []

    @property
    def worker(self) -> "MoRIIOConnectorWorker":
        """Get the worker instance.

        Returns:
            The parent worker instance

        Raises:
            RuntimeError: If worker has been garbage collected
        """
        worker = self._worker_ref()
        if worker is None:
            raise RuntimeError("Parent worker has been garbage collected")
        return worker

    def ensure_worker_started(self) -> None:
        """Ensure the background write worker is running (idempotent).

        Fix: the original set ``_write_worker_started`` BEFORE acquiring
        the lock, so two racing callers could both see the flag unset and
        start two worker threads. The flag is now checked and set while
        holding the lock (double-checked locking).
        """
        if self._write_worker_started:
            return
        with self._write_worker_lock:
            if self._write_worker_started:
                return
            thread = threading.Thread(
                target=self._write_worker_loop, daemon=True, name="moriio-write-worker"
            )
            thread.start()
            self._write_worker_started = True
            logger.info("Started MoRIIO write worker thread")

    def schedule_write(self, task: WriteTask) -> None:
        """Schedule a write task.

        Args:
            task: The write task to schedule
        """
        self.ensure_worker_started()
        self._write_task_q.put(task)

    def _write_worker_loop(self) -> None:
        """Main loop for the write worker thread.

        Drains the task queue; tasks whose remote blocks are not yet
        allocated are parked in ``_deferred_tasks`` and retried first on
        each iteration.
        """
        while True:
            # Process deferred tasks first
            self._process_deferred_tasks()

            # Get new task; short timeout keeps deferred tasks serviced.
            try:
                task = self._write_task_q.get(timeout=0.01)
            except Empty:
                continue

            # Check if remote blocks are ready; if not, retry later.
            if not self._is_remote_ready(task):
                self._deferred_tasks.append(task)
                continue

            # Execute the task
            self._execute_write_task(task)

    def _process_deferred_tasks(self) -> None:
        """Process tasks that were previously deferred."""
        if not self._deferred_tasks:
            return

        still_deferred: list[WriteTask] = []
        for task in self._deferred_tasks:
            if self._is_remote_ready(task):
                self._execute_write_task(task)
            else:
                still_deferred.append(task)

        self._deferred_tasks = still_deferred

    def _is_remote_ready(self, task: WriteTask) -> bool:
        """Check if remote blocks are allocated for this task.

        Args:
            task: The write task

        Returns:
            True if remote blocks are ready
        """
        return (
            task.request_id in self.worker.moriio_wrapper.done_remote_allocate_req_dict
        )

    def _get_remote_alloc_info(self, request_id: str) -> RemoteAllocInfo:
        """Get remote allocation info for a request.

        Args:
            request_id: The request ID

        Returns:
            Remote allocation information

        Raises:
            KeyError: If allocation info is missing
        """
        try:
            return self.worker.moriio_wrapper.done_remote_allocate_req_dict[request_id]
        except KeyError as e:
            raise KeyError(
                f"Remote allocation info missing for request {request_id}"
            ) from e

    def _execute_write_task(self, task: WriteTask) -> None:
        """Execute a single write task.

        Args:
            task: The write task to execute
        """
        # Get remote allocation info
        request_info = self._get_remote_alloc_info(task.request_id)

        if request_info.block_ids is None:
            logger.debug("Request %s remote block IDs not ready", task.request_id)
            return

        # Wait for CUDA event
        # The attention computation of the current layer cannot
        # overlap with the kv transfer task,
        # otherwise it will cause precision issues.
        # This event is used to synchronize the kv transfer and computation tasks.
        task.event.synchronize()

        # Update engine ID with DP rank
        task.dst_engine_id = self.worker.get_engine_name_with_dp(
            task.dst_engine_id, request_info.decode_dp_rank
        )

        # Get or create sessions
        sessions, remote_moriio_meta = self.worker._get_built_session(
            task.dst_engine_id
        )

        # Prepare transfer plan
        plan = self._prepare_transfer_plan(task, request_info, remote_moriio_meta)

        # Execute transfer
        self._do_layer_write(plan, sessions)

        # Finalize if all layers complete
        self._finalize_if_complete(task, request_info)

    def _prepare_transfer_plan(
        self,
        task: WriteTask,
        request_info: RemoteAllocInfo,
        remote_moriio_meta: MoRIIOAgentMetadata,
    ) -> LayerTransferPlan:
        """Prepare the transfer plan for a layer.

        Args:
            task: The write task
            request_info: Remote allocation information
            remote_moriio_meta: Agent metadata of the remote engine

        Returns:
            The transfer plan
        """
        # Compute offsets once per request and cache them; all layers of
        # the same request share the block mapping.
        if request_info.transfer_offset is None:
            offsets = self.worker._compute_block_transfer_offsets(
                task.layer_name,
                task.local_block_ids,
                request_info.block_ids,
                remote_moriio_meta,
            )
            request_info.transfer_offset = offsets

        # Session index follows layer registration order.
        layer_names = list(self.worker.layer_name_to_local_kv_cache_metadata.keys())
        sess_idx = layer_names.index(task.layer_name)

        local_off, remote_off, sizes = request_info.transfer_offset

        return LayerTransferPlan(
            request_id=task.request_id,
            layer_name=task.layer_name,
            sess_idx=sess_idx,
            transfer_local_offsets=local_off,
            transfer_remote_offsets=remote_off,
            transfer_sizes=sizes,
            use_batch=True,
        )

    def _do_layer_write(self, plan: LayerTransferPlan, sessions: list) -> None:
        """Perform the actual layer write.

        Args:
            plan: The transfer plan
            sessions: List of transfer sessions
        """
        if plan.use_batch:
            self.worker.moriio_wrapper.write_remote_data(
                plan.transfer_sizes,
                plan.transfer_local_offsets,
                plan.transfer_remote_offsets,
                sessions[plan.sess_idx],
            )
        else:
            for i in range(len(plan.transfer_local_offsets)):
                self.worker.moriio_wrapper.write_remote_data_single(
                    plan.transfer_sizes[i],
                    plan.transfer_local_offsets[i],
                    plan.transfer_remote_offsets[i],
                    plan.sess_idx,
                )

    def _finalize_if_complete(
        self, task: WriteTask, request_info: RemoteAllocInfo
    ) -> None:
        """Finalize transfer if all layers are complete.

        Args:
            task: The write task
            request_info: Remote allocation information
        """
        request_info.writes_done += 1

        if request_info.writes_done >= self.worker.num_layers:
            # Wait for transfer to complete
            self.worker.moriio_wrapper.waiting_for_transfer_complete()

            remote_port = task.remote_notify_port + get_port_offset(
                request_info.decode_dp_rank, self.worker.tp_rank
            )
            # Consider using RDMA immediate data in decode side
            # to eliminate the need for this notification.
            # Consider including the first gen token from prefill in the notification

            # Send completion notification
            self.worker.moriio_wrapper.send_notify(
                task.request_id, task.remote_ip, remote_port
            )
            # mark request as done, then we can free the blocks
            with self.worker.moriio_wrapper.lock:
                self.worker.moriio_wrapper.done_req_ids.append(task.request_id)
                del self.worker.moriio_wrapper.done_remote_allocate_req_dict[
                    task.request_id
                ]
            logger.debug(
                "Completed transfer for request %s, notified port %d",
                task.request_id,
                remote_port,
            )
|
||||
|
||||
|
||||
class MoRIIOWrapper:
|
||||
"""Wrapper for MoRIIO engine operations.
|
||||
|
||||
Handles both producer and consumer roles for KV cache transfers.
|
||||
|
||||
Args:
|
||||
moriio_engine: MoRIIO engine instance
|
||||
tp_rank: Tensor parallel rank
|
||||
dp_rank: Data parallel rank
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
moriio_engine: Optional["IOEngine"] = None,
|
||||
tp_rank: int = 0,
|
||||
dp_rank: int = 0,
|
||||
):
|
||||
self.tp_rank = tp_rank
|
||||
self.dp_rank = dp_rank
|
||||
self.moriio_engine = moriio_engine
|
||||
self.remote_memory_metadata = None
|
||||
self.local_memory_registered = False
|
||||
self.local_memory_metadata = None
|
||||
self.transfer_status: list[Any] = []
|
||||
self.remote_engine_ip: str | None = None
|
||||
self.notify_port: int | None = None
|
||||
self.lock = threading.Lock()
|
||||
self.done_req_ids: list[str] = []
|
||||
self.done_remote_allocate_req_dict: dict[str, RemoteAllocInfo] = {}
|
||||
self.done_write_cache_req_ids: list[str] = []
|
||||
self.notify_thread: threading.Thread | None = None
|
||||
self.sessions: list[IOEngine.Session] = []
|
||||
self.paths: dict[str, zmq.Socket] = {}
|
||||
|
||||
def set_moriio_engine(self, moriio_engine):
|
||||
assert moriio_engine is not None, (
|
||||
"You Cannot pass None engine to MoRIIOWrapper!"
|
||||
)
|
||||
self.moriio_engine = moriio_engine
|
||||
|
||||
def set_backend_type(self, backend_type):
|
||||
assert self.moriio_engine is not None, "MoRIIO engine must be set first"
|
||||
qp_per_transfer = envs.VLLM_MORIIO_QP_PER_TRANSFER
|
||||
post_batch_size = envs.VLLM_MORIIO_POST_BATCH_SIZE
|
||||
num_worker_threads = envs.VLLM_MORIIO_NUM_WORKERS
|
||||
poll_mode = PollCqMode.POLLING
|
||||
rdma_cfg = RdmaBackendConfig(
|
||||
qp_per_transfer,
|
||||
post_batch_size,
|
||||
num_worker_threads,
|
||||
poll_mode,
|
||||
)
|
||||
self.moriio_engine.create_backend(backend_type, rdma_cfg)
|
||||
|
||||
def get_agent_metadata(self):
|
||||
assert self.moriio_engine is not None, "MoRIIO engine must be set first"
|
||||
engine_metadata = self.moriio_engine.get_engine_desc()
|
||||
engine_metadata_packed = engine_metadata.pack()
|
||||
return engine_metadata_packed
|
||||
|
||||
def register_remote_engine(self, remote_packed_engine_metadata):
|
||||
assert self.moriio_engine is not None, "MoRIIO engine must be set first"
|
||||
consumer_engine_metadata = EngineDesc.unpack(remote_packed_engine_metadata)
|
||||
self.moriio_engine.register_remote_engine(consumer_engine_metadata)
|
||||
return consumer_engine_metadata.key
|
||||
|
||||
def register_local_tensor(self, tensor: torch.Tensor):
|
||||
assert self.moriio_engine is not None, "MoRIIO engine must be set first"
|
||||
try:
|
||||
self.local_memory_metadata = self.moriio_engine.register_torch_tensor(
|
||||
tensor
|
||||
)
|
||||
assert self.local_memory_metadata is not None, (
|
||||
"register_torch_tensor returned None"
|
||||
)
|
||||
local_memory_metadata_packed = self.local_memory_metadata.pack()
|
||||
except Exception as e:
|
||||
raise MoRIIOError(f"Failed to register local memory: {e}") from e
|
||||
self.local_memory_registered = True
|
||||
return local_memory_metadata_packed
|
||||
|
||||
def get_unpack_memory_metadata(self, packed_memory_metadata):
|
||||
return MemoryDesc.unpack(packed_memory_metadata)
|
||||
|
||||
def build_session(self, local_memory_metadata, remote_memory_metadata):
|
||||
assert self.moriio_engine is not None, "MoRIIO engine must be set first"
|
||||
return self.moriio_engine.create_session(
|
||||
local_memory_metadata, remote_memory_metadata
|
||||
)
|
||||
|
||||
def read_remote_data(
|
||||
self, transfer_size_byte, local_offset=0, remote_offset=0, session=None
|
||||
):
|
||||
assert self.local_memory_registered, "You have not register local memory data!"
|
||||
assert self.moriio_engine is not None, "MoRIIO engine must be set first"
|
||||
transfer_status = session.batch_read(
|
||||
local_offset,
|
||||
remote_offset,
|
||||
transfer_size_byte,
|
||||
self.moriio_engine.allocate_transfer_uid(),
|
||||
)
|
||||
|
||||
return transfer_status
|
||||
|
||||
def write_remote_data(
|
||||
self, transfer_size_byte, local_offset=0, remote_offset=0, session=None
|
||||
):
|
||||
assert self.local_memory_registered, "You have not register local memory data!"
|
||||
assert self.moriio_engine is not None, "MoRIIO engine must be set first"
|
||||
write_uid = self.moriio_engine.allocate_transfer_uid()
|
||||
|
||||
transfer_status = session.batch_write(
|
||||
local_offset, remote_offset, transfer_size_byte, write_uid
|
||||
)
|
||||
with self.lock:
|
||||
self.transfer_status.append(transfer_status)
|
||||
|
||||
def write_remote_data_single(
|
||||
self, transfer_size_byte, local_offset=0, remote_offset=0, sess_idx=0
|
||||
):
|
||||
assert self.local_memory_registered, "You have not register local memory data!"
|
||||
assert self.moriio_engine is not None, "MoRIIO engine must be set first"
|
||||
transfer_status = self.sessions[sess_idx].write(
|
||||
local_offset,
|
||||
remote_offset,
|
||||
transfer_size_byte,
|
||||
self.moriio_engine.allocate_transfer_uid(),
|
||||
)
|
||||
with self.lock:
|
||||
self.transfer_status.append(transfer_status)
|
||||
|
||||
def waiting_for_transfer_complete(self):
|
||||
if not self.transfer_status:
|
||||
return
|
||||
|
||||
transfers_to_wait = []
|
||||
with self.lock:
|
||||
transfers_to_wait = self.transfer_status[:]
|
||||
self.transfer_status.clear()
|
||||
|
||||
for status in transfers_to_wait:
|
||||
try:
|
||||
status.Wait()
|
||||
if not status.Succeeded():
|
||||
logger.error(
|
||||
"Transfer failed: %s, Code: %s", status.Message(), status.Code()
|
||||
)
|
||||
raise TransferError("MoRIIO transfer failed!")
|
||||
except Exception as e:
|
||||
logger.error("Transfer %s failed: %s", status, e)
|
||||
raise
|
||||
|
||||
def async_wait_reqid(self):
|
||||
assert self.notify_port is not None, "Notify port cannot be None"
|
||||
|
||||
if self.notify_thread is not None:
|
||||
return
|
||||
|
||||
def _async_wait():
|
||||
host = "*"
|
||||
path = make_zmq_path("tcp", host, self.notify_port)
|
||||
logger.info("Node starting to listen notify from path = %s", path)
|
||||
|
||||
with zmq_ctx(zmq.ROUTER, path) as sock:
|
||||
while True:
|
||||
try:
|
||||
identity, msg = sock.recv_multipart()
|
||||
self._handle_message(msg)
|
||||
except Exception as e:
|
||||
logger.error("Error processing message: %s", e)
|
||||
raise HandshakeError(f"Error processing message: {e}") from e
|
||||
|
||||
self.notify_thread = threading.Thread(
|
||||
target=_async_wait, daemon=True, name="moriio-notify-listener"
|
||||
)
|
||||
self.notify_thread.start()
|
||||
|
||||
def _handle_message(self, msg: bytes):
|
||||
"""Handles incoming messages from remote nodes."""
|
||||
# Handles incoming remote messages:
|
||||
# Prefill Role:
|
||||
# [write] mode: receives block information (allocation)
|
||||
# [read] mode: receives block release messages from decode side
|
||||
# Decode Role:
|
||||
# [write] mode: receives KV cache write completion notifications
|
||||
handled = False
|
||||
try:
|
||||
data = msgpack.loads(msg)
|
||||
if isinstance(data, dict) and "req_id" in data:
|
||||
self._handle_structured_message(data)
|
||||
|
||||
return
|
||||
except (msgpack.exceptions.ExtraData, msgpack.exceptions.UnpackException):
|
||||
logger.debug("Failed to decode msgpack message, will try as string")
|
||||
pass
|
||||
|
||||
try:
|
||||
msg_str = msg.decode("UTF-8")
|
||||
if msg_str.startswith(MoRIIOConstants.COMPLETION_PREFIX):
|
||||
self._handle_completion_message(msg_str)
|
||||
handled = True
|
||||
except UnicodeDecodeError:
|
||||
logger.warning("Received non-UTF8 message: %s", msg_str)
|
||||
if not handled:
|
||||
raise MoRIIOError(f"Unhandled message format: {msg_str}")
|
||||
|
||||
def _handle_structured_message(self, data: dict):
|
||||
assert get_role() == ROLE.PRODUCER, "Only prefill can get block messages"
|
||||
req_id = data["req_id"]
|
||||
block_notify_list = data.get("block_notify_list", [])
|
||||
decode_dp_rank = data.get("decode_rank", 0)
|
||||
assert len(block_notify_list) > 0, (
|
||||
"block_notify_list cannot be empty in remote allocate message"
|
||||
)
|
||||
|
||||
with self.lock:
|
||||
self.done_remote_allocate_req_dict[req_id] = RemoteAllocInfo(
|
||||
block_ids=block_notify_list, decode_dp_rank=decode_dp_rank
|
||||
)
|
||||
|
||||
def _handle_completion_message(self, msg: str):
|
||||
with self.lock:
|
||||
if get_role() == ROLE.PRODUCER:
|
||||
self.done_req_ids.append(msg)
|
||||
else:
|
||||
self.done_write_cache_req_ids.append(msg)
|
||||
|
||||
def send_notify(self, req_ids, remote_ip, remote_port):
|
||||
if not remote_ip or not remote_port:
|
||||
logger.warning("Missing remote_ip or remote_port for notification")
|
||||
return
|
||||
|
||||
path = make_zmq_path("tcp", remote_ip, remote_port)
|
||||
|
||||
if path not in self.paths:
|
||||
ctx = zmq.Context.instance()
|
||||
sock = make_zmq_socket(
|
||||
ctx=ctx, path=path, socket_type=zmq.DEALER, bind=False
|
||||
)
|
||||
self.paths[path] = sock
|
||||
|
||||
req_list = req_ids if isinstance(req_ids, list) else [req_ids]
|
||||
|
||||
sock = self.paths[path]
|
||||
try:
|
||||
for req_id in req_list:
|
||||
if not isinstance(req_id, str):
|
||||
logger.warning(
|
||||
"Invalid req_id type: %s, expected str", type(req_id)
|
||||
)
|
||||
continue
|
||||
sock.send(req_id.encode("utf-8"))
|
||||
except Exception as e:
|
||||
logger.error("Failed to send notification to %s: %s", path, e)
|
||||
self.paths.pop(path, None)
|
||||
raise
|
||||
|
||||
def pop_finished_req_ids(self):
    """Producer side: drain and return the set of requests that the decode
    side has reported as completed."""
    with self.lock:
        finished, self.done_req_ids = set(self.done_req_ids), []
    return finished
|
||||
|
||||
def pop_finished_write_req_ids(self):
    """Consumer side (write mode): drain and return the set of requests
    whose cache writes have completed."""
    with self.lock:
        completed, self.done_write_cache_req_ids = (
            set(self.done_write_cache_req_ids),
            [],
        )
    return completed
|
||||
|
||||
def shutdown(self):
    """Close every cached ZMQ socket and forget all remote paths."""
    logger.debug("Closing MoRIIOWrapper and cleaning up ZMQ sockets")
    for zmq_path, zmq_sock in self.paths.items():
        try:
            zmq_sock.close(linger=0)
        except Exception as err:
            logger.warning(
                "Error closing ZMQ socket for path %s: %s", zmq_path, err
            )
        else:
            logger.debug("Closed ZMQ socket for path: %s", zmq_path)
    self.paths.clear()
|
||||
18
vllm/envs.py
18
vllm/envs.py
@ -203,6 +203,10 @@ if TYPE_CHECKING:
|
||||
VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16: bool = True
|
||||
VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB: int | None = None
|
||||
VLLM_NIXL_ABORT_REQUEST_TIMEOUT: int = 480
|
||||
VLLM_MORIIO_CONNECTOR_READ_MODE: bool = False
|
||||
VLLM_MORIIO_QP_PER_TRANSFER: int = 1
|
||||
VLLM_MORIIO_POST_BATCH_SIZE: int = -1
|
||||
VLLM_MORIIO_NUM_WORKERS: int = 1
|
||||
VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT: int = 480
|
||||
VLLM_USE_CUDNN_PREFILL: bool = False
|
||||
VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL: bool = False
|
||||
@ -1379,6 +1383,20 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
"VLLM_NIXL_ABORT_REQUEST_TIMEOUT": lambda: int(
|
||||
os.getenv("VLLM_NIXL_ABORT_REQUEST_TIMEOUT", "480")
|
||||
),
|
||||
# Controls the read mode for the Mori-IO connector
|
||||
"VLLM_MORIIO_CONNECTOR_READ_MODE": lambda: (
|
||||
os.getenv("VLLM_MORIIO_CONNECTOR_READ_MODE", "False").lower() in ("true", "1")
|
||||
),
|
||||
# Controls the QP (Queue Pair) per transfer configuration for the Mori-IO connector
|
||||
"VLLM_MORIIO_QP_PER_TRANSFER": lambda: int(
|
||||
os.getenv("VLLM_MORIIO_QP_PER_TRANSFER", "1")
|
||||
),
|
||||
# Controls the post-processing batch size for the Mori-IO connector
|
||||
"VLLM_MORIIO_POST_BATCH_SIZE": lambda: int(
|
||||
os.getenv("VLLM_MORIIO_POST_BATCH_SIZE", "-1")
|
||||
),
|
||||
# Controls the number of worker threads used for Mori-IO connector operations
|
||||
"VLLM_MORIIO_NUM_WORKERS": lambda: int(os.getenv("VLLM_MORIIO_NUM_WORKERS", "1")),
|
||||
# Timeout (in seconds) for MooncakeConnector in PD disaggregated setup.
|
||||
"VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT": lambda: int(
|
||||
os.getenv("VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT", "480")
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user