mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-03 04:44:25 +08:00
split large file
Signed-off-by: inkcherry <mingzhi.liu@amd.com>
This commit is contained in:
parent
0b0c33d59e
commit
66e2332682
@ -181,7 +181,7 @@ KVConnectorFactory.register_connector(
|
|||||||
|
|
||||||
KVConnectorFactory.register_connector(
|
KVConnectorFactory.register_connector(
|
||||||
"MoRIIOConnector",
|
"MoRIIOConnector",
|
||||||
"vllm.distributed.kv_transfer.kv_connector.v1.moriio_connector",
|
"vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector",
|
||||||
"MoRIIOConnector",
|
"MoRIIOConnector",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -0,0 +1,321 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
import contextlib
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from collections.abc import Iterator
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import TYPE_CHECKING, Any, Optional
|
||||||
|
|
||||||
|
import msgspec
|
||||||
|
import torch
|
||||||
|
import zmq
|
||||||
|
|
||||||
|
from vllm import envs
|
||||||
|
from vllm.config import VllmConfig
|
||||||
|
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||||
|
KVConnectorMetadata,
|
||||||
|
)
|
||||||
|
from vllm.distributed.parallel_state import (
|
||||||
|
get_tensor_model_parallel_rank,
|
||||||
|
get_tensor_model_parallel_world_size,
|
||||||
|
)
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.utils.network_utils import (
|
||||||
|
get_ip,
|
||||||
|
get_open_port,
|
||||||
|
make_zmq_socket,
|
||||||
|
)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
pass
|
||||||
|
|
||||||
|
from dataclasses import field
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
Transfer = tuple[int, float]
|
||||||
|
EngineId = str
|
||||||
|
ReqId = str
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class WriteTask:
|
||||||
|
request_id: str
|
||||||
|
dst_engine_id: str
|
||||||
|
local_block_ids: list[int]
|
||||||
|
remote_block_ids_hint: list[int] | None
|
||||||
|
layer_name: str
|
||||||
|
event: torch.cuda.Event
|
||||||
|
remote_notify_port: int
|
||||||
|
remote_ip: str
|
||||||
|
enqueue_time: float = field(default_factory=time.perf_counter)
|
||||||
|
retried: int = 0
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class LayerTransferPlan:
|
||||||
|
"""Plan for transferring a single layer."""
|
||||||
|
|
||||||
|
request_id: str
|
||||||
|
layer_name: str
|
||||||
|
sess_idx: int
|
||||||
|
transfer_local_offsets: list[int]
|
||||||
|
transfer_remote_offsets: list[int]
|
||||||
|
transfer_sizes: list[int]
|
||||||
|
use_batch: bool = True
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RemoteAllocInfo:
|
||||||
|
"""Information about remote block allocation."""
|
||||||
|
|
||||||
|
block_ids: list[int]
|
||||||
|
writes_done: int = 0
|
||||||
|
decode_dp_rank: int = 0
|
||||||
|
transfer_offset: tuple[list[int], list[int], list[int]] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class ROLE(Enum):
|
||||||
|
PRODUCER = "producer"
|
||||||
|
CONSUMER = "consumer"
|
||||||
|
NOTINIT = "notinit"
|
||||||
|
|
||||||
|
|
||||||
|
class MoRIIOAgentMetadata(
|
||||||
|
msgspec.Struct,
|
||||||
|
omit_defaults=True, # type: ignore[call-arg]
|
||||||
|
# required for @cached_property.d
|
||||||
|
dict=True,
|
||||||
|
):
|
||||||
|
engine_id: str
|
||||||
|
agent_metadata: bytes
|
||||||
|
kv_caches_base_addr: list[int]
|
||||||
|
num_blocks: int
|
||||||
|
block_len: int
|
||||||
|
attn_backend_name: str
|
||||||
|
|
||||||
|
|
||||||
|
class RoleManager:
|
||||||
|
"""Manages role state across the connector."""
|
||||||
|
|
||||||
|
_instance: Optional["RoleManager"] = None
|
||||||
|
_lock = threading.Lock()
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._role: ROLE = ROLE.NOTINIT
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_instance(cls) -> "RoleManager":
|
||||||
|
if cls._instance is None:
|
||||||
|
with cls._lock:
|
||||||
|
if cls._instance is None:
|
||||||
|
cls._instance = cls()
|
||||||
|
return cls._instance
|
||||||
|
|
||||||
|
def set_role(self, role: ROLE) -> None:
|
||||||
|
"""Set the current role."""
|
||||||
|
with self._lock:
|
||||||
|
self._role = role
|
||||||
|
|
||||||
|
def get_role(self) -> ROLE:
|
||||||
|
"""Get the current role."""
|
||||||
|
return self._role
|
||||||
|
|
||||||
|
|
||||||
|
def set_role(role: ROLE):
|
||||||
|
"""Set the global role."""
|
||||||
|
RoleManager.get_instance().set_role(role)
|
||||||
|
|
||||||
|
|
||||||
|
def get_role() -> ROLE:
|
||||||
|
"""Get the global role."""
|
||||||
|
return RoleManager.get_instance().get_role()
|
||||||
|
|
||||||
|
|
||||||
|
class MoRIIOMode(Enum):
|
||||||
|
READ = "read"
|
||||||
|
WRITE = "write"
|
||||||
|
|
||||||
|
|
||||||
|
class MoRIIOError(Exception):
|
||||||
|
"""Base exception for MoRIIO operations."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class HandshakeError(MoRIIOError):
|
||||||
|
"""Exception raised when handshake fails."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class TransferError(MoRIIOError):
|
||||||
|
"""Exception raised when transfer fails."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def get_moriio_mode() -> MoRIIOMode:
|
||||||
|
read_mode = envs.VLLM_MORIIO_CONNECTOR_READ_MODE
|
||||||
|
logger.debug("MoRIIO Connector read_mode: %s", read_mode)
|
||||||
|
if read_mode:
|
||||||
|
return MoRIIOMode.READ
|
||||||
|
else:
|
||||||
|
return MoRIIOMode.WRITE
|
||||||
|
|
||||||
|
|
||||||
|
def get_port_offset(dp_rank: int, tp_rank: int, tp_size: int = 1) -> int:
|
||||||
|
return (dp_rank) * tp_size + tp_rank
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MoRIIOConfig:
|
||||||
|
local_ip: str
|
||||||
|
local_kv_port: int
|
||||||
|
proxy_ip: str
|
||||||
|
local_ping_port: int
|
||||||
|
proxy_ping_port: int
|
||||||
|
http_port: int
|
||||||
|
handshake_port: int
|
||||||
|
notify_port: int
|
||||||
|
tp_rank: int
|
||||||
|
dp_rank: int
|
||||||
|
dp_size: int
|
||||||
|
tp_size: int
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_vllm_config(cls, vllm_config: VllmConfig) -> "MoRIIOConfig":
|
||||||
|
# Port Configuration:
|
||||||
|
# local_ping_port -> Outgoing heartbeat to proxy
|
||||||
|
# proxy_ping_port -> Remote proxy's heartbeat ingress port
|
||||||
|
# http_port -> Instance's HTTP service endpoint
|
||||||
|
# local_kv_port -> service port for mori engine
|
||||||
|
# notify_port -> For synchronizing stages between prefill and decode
|
||||||
|
# handshake_port -> For initial handshake between mori engine
|
||||||
|
|
||||||
|
# TODO : merge notify_port and handshake_port to simplify port management
|
||||||
|
# supports non-contiguous ports
|
||||||
|
assert vllm_config.kv_transfer_config is not None, (
|
||||||
|
"kv_transfer_config must be set for MoRIIOConnector"
|
||||||
|
)
|
||||||
|
kv_transfer_config = vllm_config.kv_transfer_config
|
||||||
|
extra_config = kv_transfer_config.kv_connector_extra_config
|
||||||
|
tp_rank = get_tensor_model_parallel_rank()
|
||||||
|
dp_rank = vllm_config.parallel_config.data_parallel_rank
|
||||||
|
base_notify_port = int(extra_config["notify_port"])
|
||||||
|
dp_size = vllm_config.parallel_config.data_parallel_size
|
||||||
|
tp_size = get_tensor_model_parallel_world_size()
|
||||||
|
port_offset = get_port_offset(dp_rank, tp_rank)
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
local_ip=get_ip(),
|
||||||
|
local_kv_port=get_open_port(),
|
||||||
|
proxy_ip=extra_config["proxy_ip"],
|
||||||
|
local_ping_port=get_open_port(),
|
||||||
|
proxy_ping_port=int(extra_config["proxy_ping_port"]),
|
||||||
|
http_port=int(extra_config["http_port"]),
|
||||||
|
handshake_port=int(extra_config["handshake_port"]),
|
||||||
|
notify_port=base_notify_port + port_offset,
|
||||||
|
tp_rank=tp_rank,
|
||||||
|
dp_rank=dp_rank,
|
||||||
|
dp_size=dp_size,
|
||||||
|
tp_size=tp_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MoRIIOConstants:
|
||||||
|
"""Constants for MoRIIO connector."""
|
||||||
|
|
||||||
|
# ZMQ message types
|
||||||
|
GET_META_MSG = b"get_meta_msg"
|
||||||
|
POP_DONE_RECV = b"pop_done_recv"
|
||||||
|
OVER = b"OVER"
|
||||||
|
COMPLETION_PREFIX = "cmpl"
|
||||||
|
|
||||||
|
PING_INTERVAL = 5
|
||||||
|
MAX_PING_RETRIES = 100
|
||||||
|
DEFAULT_HANDSHAKE_PORT = "6301"
|
||||||
|
DEFAULT_NOTIFY_PORT = "61005"
|
||||||
|
|
||||||
|
VLLM_MORI_READ_ABORT_REQUEST_TIMEOUT = 3600
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ReqMeta:
|
||||||
|
"""Metadata for a single request."""
|
||||||
|
|
||||||
|
local_block_ids: list[int]
|
||||||
|
remote_block_ids: list[int]
|
||||||
|
remote_host: str
|
||||||
|
remote_port: int
|
||||||
|
remote_handshake_port: int
|
||||||
|
remote_notify_port: int
|
||||||
|
remote_engine_id: str
|
||||||
|
tp_size: int
|
||||||
|
remote_dp_size: int
|
||||||
|
|
||||||
|
|
||||||
|
class MoRIIOConnectorMetadata(KVConnectorMetadata):
|
||||||
|
def __init__(self):
|
||||||
|
self.reqs_to_recv: dict[ReqId, ReqMeta] = {}
|
||||||
|
self.reqs_to_save: dict[ReqId, ReqMeta] = {}
|
||||||
|
self.reqs_to_send: dict[ReqId, float] = {}
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return_str = ""
|
||||||
|
for req_id, req_meta in self.reqs_to_recv.items():
|
||||||
|
return_str += (
|
||||||
|
f"{req_id = },{req_meta.local_block_ids = },"
|
||||||
|
f"{req_meta.remote_host = },{req_meta.remote_port = }"
|
||||||
|
f"{req_meta.remote_engine_id = },{req_meta.tp_size = }"
|
||||||
|
)
|
||||||
|
return_str = f"MoRIIOConnectorMetadata:reqs_to_recv:{return_str},"
|
||||||
|
|
||||||
|
for req_id, expiry in self.reqs_to_send.items():
|
||||||
|
return_str += f"{req_id = },{expiry = }"
|
||||||
|
return_str = f"MoRIIOConnectorMetadata:reqs_to_send:{return_str},"
|
||||||
|
return return_str
|
||||||
|
|
||||||
|
def add_new_req(
|
||||||
|
self,
|
||||||
|
request_id: ReqId,
|
||||||
|
local_block_ids: list[int],
|
||||||
|
kv_transfer_params: dict[str, Any],
|
||||||
|
write_mode=False,
|
||||||
|
):
|
||||||
|
_req = ReqMeta(
|
||||||
|
local_block_ids=local_block_ids,
|
||||||
|
remote_block_ids=kv_transfer_params["remote_block_ids"],
|
||||||
|
remote_engine_id=kv_transfer_params["remote_engine_id"],
|
||||||
|
remote_host=kv_transfer_params["remote_host"],
|
||||||
|
remote_port=kv_transfer_params["remote_port"],
|
||||||
|
remote_handshake_port=kv_transfer_params["remote_handshake_port"],
|
||||||
|
remote_notify_port=kv_transfer_params["remote_notify_port"],
|
||||||
|
tp_size=kv_transfer_params.get("tp_size", 1),
|
||||||
|
remote_dp_size=kv_transfer_params.get("remote_dp_size", 1),
|
||||||
|
)
|
||||||
|
if write_mode:
|
||||||
|
self.reqs_to_save[request_id] = _req
|
||||||
|
else:
|
||||||
|
self.reqs_to_recv[request_id] = _req
|
||||||
|
|
||||||
|
|
||||||
|
@contextlib.contextmanager
|
||||||
|
def zmq_ctx(socket_type: Any, addr: str) -> Iterator[zmq.Socket]:
|
||||||
|
"""Context manager for a ZMQ socket"""
|
||||||
|
|
||||||
|
if socket_type not in (zmq.ROUTER, zmq.REQ, zmq.DEALER):
|
||||||
|
raise ValueError(f"Unexpected socket type: {socket_type}")
|
||||||
|
|
||||||
|
ctx: zmq.Context | None = None
|
||||||
|
try:
|
||||||
|
ctx = zmq.Context() # type: ignore[attr-defined]
|
||||||
|
yield make_zmq_socket(
|
||||||
|
ctx=ctx, path=addr, socket_type=socket_type, bind=socket_type == zmq.ROUTER
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
if ctx is not None:
|
||||||
|
ctx.destroy(linger=0)
|
||||||
@ -1,17 +1,13 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
import contextlib
|
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
import queue
|
import queue
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from collections.abc import Iterator
|
|
||||||
from concurrent.futures import Future, ThreadPoolExecutor
|
from concurrent.futures import Future, ThreadPoolExecutor
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import TYPE_CHECKING, Any, Optional
|
from typing import TYPE_CHECKING, Any, Optional
|
||||||
from weakref import ref as weakref_ref
|
|
||||||
|
|
||||||
import msgpack
|
import msgpack
|
||||||
import msgspec
|
import msgspec
|
||||||
@ -19,7 +15,6 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
import zmq
|
import zmq
|
||||||
|
|
||||||
from vllm import envs
|
|
||||||
from vllm.attention.selector import get_attn_backend
|
from vllm.attention.selector import get_attn_backend
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||||
@ -27,8 +22,29 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
|||||||
KVConnectorMetadata,
|
KVConnectorMetadata,
|
||||||
KVConnectorRole,
|
KVConnectorRole,
|
||||||
)
|
)
|
||||||
|
from vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_common import (
|
||||||
|
ROLE,
|
||||||
|
EngineId,
|
||||||
|
HandshakeError,
|
||||||
|
MoRIIOAgentMetadata,
|
||||||
|
MoRIIOConfig,
|
||||||
|
MoRIIOConnectorMetadata,
|
||||||
|
MoRIIOConstants,
|
||||||
|
MoRIIOMode,
|
||||||
|
ReqId,
|
||||||
|
ReqMeta,
|
||||||
|
WriteTask,
|
||||||
|
get_moriio_mode,
|
||||||
|
get_port_offset,
|
||||||
|
get_role,
|
||||||
|
set_role,
|
||||||
|
zmq_ctx,
|
||||||
|
)
|
||||||
|
from vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_engine import (
|
||||||
|
MoRIIOWrapper,
|
||||||
|
MoRIIOWriter,
|
||||||
|
)
|
||||||
from vllm.distributed.parallel_state import (
|
from vllm.distributed.parallel_state import (
|
||||||
get_tensor_model_parallel_rank,
|
|
||||||
get_tensor_model_parallel_world_size,
|
get_tensor_model_parallel_world_size,
|
||||||
get_tp_group,
|
get_tp_group,
|
||||||
get_world_group,
|
get_world_group,
|
||||||
@ -37,7 +53,6 @@ from vllm.forward_context import ForwardContext
|
|||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.utils.network_utils import (
|
from vllm.utils.network_utils import (
|
||||||
get_ip,
|
get_ip,
|
||||||
get_open_port,
|
|
||||||
make_zmq_path,
|
make_zmq_path,
|
||||||
make_zmq_socket,
|
make_zmq_socket,
|
||||||
)
|
)
|
||||||
@ -50,43 +65,13 @@ if TYPE_CHECKING:
|
|||||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||||
from vllm.v1.request import Request
|
from vllm.v1.request import Request
|
||||||
|
|
||||||
from dataclasses import field
|
|
||||||
from enum import Enum
|
|
||||||
from queue import Empty, Queue
|
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
Transfer = tuple[int, float]
|
|
||||||
EngineId = str
|
|
||||||
ReqId = str
|
|
||||||
|
|
||||||
|
|
||||||
class MoRIIOConstants:
|
|
||||||
"""Constants for MoRIIO connector."""
|
|
||||||
|
|
||||||
# ZMQ message types
|
|
||||||
GET_META_MSG = b"get_meta_msg"
|
|
||||||
POP_DONE_RECV = b"pop_done_recv"
|
|
||||||
OVER = b"OVER"
|
|
||||||
COMPLETION_PREFIX = "cmpl"
|
|
||||||
|
|
||||||
PING_INTERVAL = 5
|
|
||||||
MAX_PING_RETRIES = 100
|
|
||||||
DEFAULT_HANDSHAKE_PORT = "6301"
|
|
||||||
DEFAULT_NOTIFY_PORT = "61005"
|
|
||||||
|
|
||||||
VLLM_MORI_READ_ABORT_REQUEST_TIMEOUT = 3600
|
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from mori.io import (
|
from mori.io import (
|
||||||
BackendType,
|
BackendType,
|
||||||
EngineDesc,
|
|
||||||
IOEngine,
|
IOEngine,
|
||||||
IOEngineConfig,
|
IOEngineConfig,
|
||||||
MemoryDesc,
|
|
||||||
PollCqMode,
|
|
||||||
RdmaBackendConfig,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info("MoRIIO is available")
|
logger.info("MoRIIO is available")
|
||||||
@ -96,803 +81,8 @@ except ImportError:
|
|||||||
MoRIIO_enabled = False
|
MoRIIO_enabled = False
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
def is_moriio_available() -> bool:
|
||||||
class WriteTask:
|
return MoRIIO_enabled
|
||||||
request_id: str
|
|
||||||
dst_engine_id: str
|
|
||||||
local_block_ids: list[int]
|
|
||||||
remote_block_ids_hint: list[int] | None
|
|
||||||
layer_name: str
|
|
||||||
event: torch.cuda.Event
|
|
||||||
remote_notify_port: int
|
|
||||||
remote_ip: str
|
|
||||||
enqueue_time: float = field(default_factory=time.perf_counter)
|
|
||||||
retried: int = 0
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class LayerTransferPlan:
|
|
||||||
"""Plan for transferring a single layer."""
|
|
||||||
|
|
||||||
request_id: str
|
|
||||||
layer_name: str
|
|
||||||
sess_idx: int
|
|
||||||
transfer_local_offsets: list[int]
|
|
||||||
transfer_remote_offsets: list[int]
|
|
||||||
transfer_sizes: list[int]
|
|
||||||
use_batch: bool = True
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class RemoteAllocInfo:
|
|
||||||
"""Information about remote block allocation."""
|
|
||||||
|
|
||||||
block_ids: list[int]
|
|
||||||
writes_done: int = 0
|
|
||||||
decode_dp_rank: int = 0
|
|
||||||
transfer_offset: tuple[list[int], list[int], list[int]] | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class ROLE(Enum):
|
|
||||||
PRODUCER = "producer"
|
|
||||||
CONSUMER = "consumer"
|
|
||||||
NOTINIT = "notinit"
|
|
||||||
|
|
||||||
|
|
||||||
class MoRIIOAgentMetadata(
|
|
||||||
msgspec.Struct,
|
|
||||||
omit_defaults=True, # type: ignore[call-arg]
|
|
||||||
# required for @cached_property.d
|
|
||||||
dict=True,
|
|
||||||
):
|
|
||||||
engine_id: str
|
|
||||||
agent_metadata: bytes
|
|
||||||
kv_caches_base_addr: list[int]
|
|
||||||
num_blocks: int
|
|
||||||
block_len: int
|
|
||||||
attn_backend_name: str
|
|
||||||
|
|
||||||
|
|
||||||
class RoleManager:
|
|
||||||
"""Manages role state across the connector."""
|
|
||||||
|
|
||||||
_instance: Optional["RoleManager"] = None
|
|
||||||
_lock = threading.Lock()
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
|
||||||
self._role: ROLE = ROLE.NOTINIT
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_instance(cls) -> "RoleManager":
|
|
||||||
if cls._instance is None:
|
|
||||||
with cls._lock:
|
|
||||||
if cls._instance is None:
|
|
||||||
cls._instance = cls()
|
|
||||||
return cls._instance
|
|
||||||
|
|
||||||
def set_role(self, role: ROLE) -> None:
|
|
||||||
"""Set the current role."""
|
|
||||||
with self._lock:
|
|
||||||
self._role = role
|
|
||||||
|
|
||||||
def get_role(self) -> ROLE:
|
|
||||||
"""Get the current role."""
|
|
||||||
return self._role
|
|
||||||
|
|
||||||
|
|
||||||
def set_role(role: ROLE):
|
|
||||||
"""Set the global role."""
|
|
||||||
RoleManager.get_instance().set_role(role)
|
|
||||||
|
|
||||||
|
|
||||||
def get_role() -> ROLE:
|
|
||||||
"""Get the global role."""
|
|
||||||
return RoleManager.get_instance().get_role()
|
|
||||||
|
|
||||||
|
|
||||||
class MoRIIOMode(Enum):
|
|
||||||
READ = "read"
|
|
||||||
WRITE = "write"
|
|
||||||
|
|
||||||
|
|
||||||
class MoRIIOError(Exception):
|
|
||||||
"""Base exception for MoRIIO operations."""
|
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class HandshakeError(MoRIIOError):
|
|
||||||
"""Exception raised when handshake fails."""
|
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class TransferError(MoRIIOError):
|
|
||||||
"""Exception raised when transfer fails."""
|
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def get_moriio_mode() -> MoRIIOMode:
|
|
||||||
read_mode = envs.VLLM_MORIIO_CONNECTOR_READ_MODE
|
|
||||||
logger.debug("MoRIIO Connector read_mode: %s", read_mode)
|
|
||||||
if read_mode:
|
|
||||||
return MoRIIOMode.READ
|
|
||||||
else:
|
|
||||||
return MoRIIOMode.WRITE
|
|
||||||
|
|
||||||
|
|
||||||
def get_port_offset(dp_rank: int, tp_rank: int, tp_size: int = 1) -> int:
|
|
||||||
return (dp_rank) * tp_size + tp_rank
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class MoRIIOConfig:
|
|
||||||
local_ip: str
|
|
||||||
local_kv_port: int
|
|
||||||
proxy_ip: str
|
|
||||||
local_ping_port: int
|
|
||||||
proxy_ping_port: int
|
|
||||||
http_port: int
|
|
||||||
handshake_port: int
|
|
||||||
notify_port: int
|
|
||||||
tp_rank: int
|
|
||||||
dp_rank: int
|
|
||||||
dp_size: int
|
|
||||||
tp_size: int
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_vllm_config(cls, vllm_config: VllmConfig) -> "MoRIIOConfig":
|
|
||||||
# Port Configuration:
|
|
||||||
# local_ping_port -> Outgoing heartbeat to proxy
|
|
||||||
# proxy_ping_port -> Remote proxy's heartbeat ingress port
|
|
||||||
# http_port -> Instance's HTTP service endpoint
|
|
||||||
# local_kv_port -> service port for mori engine
|
|
||||||
# notify_port -> For synchronizing stages between prefill and decode
|
|
||||||
# handshake_port -> For initial handshake between mori engine
|
|
||||||
|
|
||||||
# TODO : merge notify_port and handshake_port to simplify port management
|
|
||||||
# supports non-contiguous ports
|
|
||||||
assert vllm_config.kv_transfer_config is not None, (
|
|
||||||
"kv_transfer_config must be set for MoRIIOConnector"
|
|
||||||
)
|
|
||||||
kv_transfer_config = vllm_config.kv_transfer_config
|
|
||||||
extra_config = kv_transfer_config.kv_connector_extra_config
|
|
||||||
tp_rank = get_tensor_model_parallel_rank()
|
|
||||||
dp_rank = vllm_config.parallel_config.data_parallel_rank
|
|
||||||
base_notify_port = int(extra_config["notify_port"])
|
|
||||||
dp_size = vllm_config.parallel_config.data_parallel_size
|
|
||||||
tp_size = get_tensor_model_parallel_world_size()
|
|
||||||
port_offset = get_port_offset(dp_rank, tp_rank)
|
|
||||||
|
|
||||||
return cls(
|
|
||||||
local_ip=get_ip(),
|
|
||||||
local_kv_port=get_open_port(),
|
|
||||||
proxy_ip=extra_config["proxy_ip"],
|
|
||||||
local_ping_port=get_open_port(),
|
|
||||||
proxy_ping_port=int(extra_config["proxy_ping_port"]),
|
|
||||||
http_port=int(extra_config["http_port"]),
|
|
||||||
handshake_port=int(extra_config["handshake_port"]),
|
|
||||||
notify_port=base_notify_port + port_offset,
|
|
||||||
tp_rank=tp_rank,
|
|
||||||
dp_rank=dp_rank,
|
|
||||||
dp_size=dp_size,
|
|
||||||
tp_size=tp_size,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
"""Write task execution logic for MoRIIO connector."""
|
|
||||||
|
|
||||||
|
|
||||||
class MoRIIOWriter:
|
|
||||||
"""Handles write operations for KV cache transfers.
|
|
||||||
Implements distributed KV cache transfer using the MoRIIO library
|
|
||||||
for RDMA-based communication between prefill and decode instances."""
|
|
||||||
|
|
||||||
def __init__(self, worker: "MoRIIOConnectorWorker"):
|
|
||||||
"""Initialize the writer.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
worker: Reference to the parent worker
|
|
||||||
"""
|
|
||||||
self._worker_ref: weakref_ref[MoRIIOConnectorWorker] = weakref_ref(worker)
|
|
||||||
self._write_task_q: Queue[WriteTask] = Queue()
|
|
||||||
self._write_worker_started = False
|
|
||||||
self._write_worker_lock = threading.Lock()
|
|
||||||
self._deferred_tasks: list[WriteTask] = []
|
|
||||||
|
|
||||||
@property
|
|
||||||
def worker(self) -> "MoRIIOConnectorWorker":
|
|
||||||
"""Get the worker instance.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The parent worker instance
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
RuntimeError: If worker has been garbage collected
|
|
||||||
"""
|
|
||||||
worker = self._worker_ref()
|
|
||||||
if worker is None:
|
|
||||||
raise RuntimeError("Parent worker has been garbage collected")
|
|
||||||
return worker
|
|
||||||
|
|
||||||
def ensure_worker_started(self) -> None:
|
|
||||||
"""Ensure the background write worker is running."""
|
|
||||||
if self._write_worker_started:
|
|
||||||
return
|
|
||||||
self._write_worker_started = True
|
|
||||||
with self._write_worker_lock:
|
|
||||||
thread = threading.Thread(
|
|
||||||
target=self._write_worker_loop, daemon=True, name="moriio-write-worker"
|
|
||||||
)
|
|
||||||
thread.start()
|
|
||||||
logger.info("Started MoRIIO write worker thread")
|
|
||||||
|
|
||||||
def schedule_write(self, task: WriteTask) -> None:
|
|
||||||
"""Schedule a write task.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
task: The write task to schedule
|
|
||||||
"""
|
|
||||||
self.ensure_worker_started()
|
|
||||||
self._write_task_q.put(task)
|
|
||||||
|
|
||||||
def _write_worker_loop(self) -> None:
|
|
||||||
"""Main loop for the write worker thread."""
|
|
||||||
|
|
||||||
while True:
|
|
||||||
# Process deferred tasks first
|
|
||||||
self._process_deferred_tasks()
|
|
||||||
|
|
||||||
# Get new task
|
|
||||||
try:
|
|
||||||
task = self._write_task_q.get(timeout=0.01)
|
|
||||||
except Empty:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Check if remote blocks are ready
|
|
||||||
if not self._is_remote_ready(task):
|
|
||||||
# task.retry_count += 1
|
|
||||||
self._deferred_tasks.append(task)
|
|
||||||
# logger.debug(
|
|
||||||
# "Deferred task for request %s (retry %d)",
|
|
||||||
# task.request_id, task.retry_count
|
|
||||||
# )
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Execute the task
|
|
||||||
|
|
||||||
self._execute_write_task(task)
|
|
||||||
|
|
||||||
def _process_deferred_tasks(self) -> None:
|
|
||||||
"""Process tasks that were previously deferred."""
|
|
||||||
if not self._deferred_tasks:
|
|
||||||
return
|
|
||||||
|
|
||||||
still_deferred: list[WriteTask] = []
|
|
||||||
for task in self._deferred_tasks:
|
|
||||||
if self._is_remote_ready(task):
|
|
||||||
self._execute_write_task(task)
|
|
||||||
else:
|
|
||||||
still_deferred.append(task)
|
|
||||||
|
|
||||||
self._deferred_tasks = still_deferred
|
|
||||||
|
|
||||||
def _is_remote_ready(self, task: WriteTask) -> bool:
|
|
||||||
"""Check if remote blocks are allocated for this task.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
task: The write task
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if remote blocks are ready
|
|
||||||
"""
|
|
||||||
return (
|
|
||||||
task.request_id in self.worker.moriio_wrapper.done_remote_allocate_req_dict
|
|
||||||
)
|
|
||||||
|
|
||||||
def _get_remote_alloc_info(self, request_id: str) -> RemoteAllocInfo:
|
|
||||||
"""Get remote allocation info for a request.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
request_id: The request ID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Remote allocation information
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
KeyError: If allocation info is missing
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
return self.worker.moriio_wrapper.done_remote_allocate_req_dict[request_id]
|
|
||||||
except KeyError as e:
|
|
||||||
raise KeyError(
|
|
||||||
f"Remote allocation info missing for request {request_id}"
|
|
||||||
) from e
|
|
||||||
|
|
||||||
def _execute_write_task(self, task: WriteTask) -> None:
|
|
||||||
"""Execute a single write task.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
task: The write task to execute
|
|
||||||
|
|
||||||
"""
|
|
||||||
# Get remote allocation info
|
|
||||||
request_info = self._get_remote_alloc_info(task.request_id)
|
|
||||||
|
|
||||||
if request_info.block_ids is None:
|
|
||||||
logger.debug("Request %s remote block IDs not ready", task.request_id)
|
|
||||||
return
|
|
||||||
|
|
||||||
# Wait for CUDA event
|
|
||||||
# The attention computation of the current layer cannot
|
|
||||||
# overlap with the kv transfer task,
|
|
||||||
# otherwise it will cause precision issues.
|
|
||||||
# This event is used to synchronize the kv transfer and computation tasks.
|
|
||||||
task.event.synchronize()
|
|
||||||
|
|
||||||
# Update engine ID with DP rank
|
|
||||||
task.dst_engine_id = self.worker.get_engine_name_with_dp(
|
|
||||||
task.dst_engine_id, request_info.decode_dp_rank
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get or create sessions
|
|
||||||
sessions, remote_moriio_meta = self.worker._get_built_session(
|
|
||||||
task.dst_engine_id
|
|
||||||
)
|
|
||||||
|
|
||||||
# Prepare transfer plan
|
|
||||||
plan = self._prepare_transfer_plan(task, request_info, remote_moriio_meta)
|
|
||||||
|
|
||||||
# Execute transfer
|
|
||||||
self._do_layer_write(plan, sessions)
|
|
||||||
|
|
||||||
# Finalize if all layers complete
|
|
||||||
self._finalize_if_complete(task, request_info)
|
|
||||||
|
|
||||||
def _prepare_transfer_plan(
|
|
||||||
self,
|
|
||||||
task: WriteTask,
|
|
||||||
request_info: RemoteAllocInfo,
|
|
||||||
remote_moriio_meta: MoRIIOAgentMetadata,
|
|
||||||
) -> LayerTransferPlan:
|
|
||||||
"""Prepare the transfer plan for a layer.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
task: The write task
|
|
||||||
request_info: Remote allocation information
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The transfer plan
|
|
||||||
"""
|
|
||||||
# Compute offsets if not cached
|
|
||||||
if request_info.transfer_offset is None:
|
|
||||||
offsets = self.worker._compute_block_transfer_offsets(
|
|
||||||
task.layer_name,
|
|
||||||
task.local_block_ids,
|
|
||||||
request_info.block_ids,
|
|
||||||
remote_moriio_meta,
|
|
||||||
)
|
|
||||||
request_info.transfer_offset = offsets
|
|
||||||
|
|
||||||
# Get session index
|
|
||||||
layer_names = list(self.worker.layer_name_to_local_kv_cache_metadata.keys())
|
|
||||||
sess_idx = layer_names.index(task.layer_name)
|
|
||||||
|
|
||||||
local_off, remote_off, sizes = request_info.transfer_offset
|
|
||||||
|
|
||||||
return LayerTransferPlan(
|
|
||||||
request_id=task.request_id,
|
|
||||||
layer_name=task.layer_name,
|
|
||||||
sess_idx=sess_idx,
|
|
||||||
transfer_local_offsets=local_off,
|
|
||||||
transfer_remote_offsets=remote_off,
|
|
||||||
transfer_sizes=sizes,
|
|
||||||
use_batch=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _do_layer_write(self, plan: LayerTransferPlan, sessions: list) -> None:
|
|
||||||
"""Perform the actual layer write.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
plan: The transfer plan
|
|
||||||
sessions: List of transfer sessions
|
|
||||||
"""
|
|
||||||
if plan.use_batch:
|
|
||||||
self.worker.moriio_wrapper.write_remote_data(
|
|
||||||
plan.transfer_sizes,
|
|
||||||
plan.transfer_local_offsets,
|
|
||||||
plan.transfer_remote_offsets,
|
|
||||||
sessions[plan.sess_idx],
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
for i in range(len(plan.transfer_local_offsets)):
|
|
||||||
self.worker.moriio_wrapper.write_remote_data_single(
|
|
||||||
plan.transfer_sizes[i],
|
|
||||||
plan.transfer_local_offsets[i],
|
|
||||||
plan.transfer_remote_offsets[i],
|
|
||||||
plan.sess_idx,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _finalize_if_complete(
|
|
||||||
self, task: WriteTask, request_info: RemoteAllocInfo
|
|
||||||
) -> None:
|
|
||||||
"""Finalize transfer if all layers are complete.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
task: The write task
|
|
||||||
request_info: Remote allocation information
|
|
||||||
"""
|
|
||||||
request_info.writes_done += 1
|
|
||||||
|
|
||||||
if request_info.writes_done >= self.worker.num_layers:
|
|
||||||
# Wait for transfer to complete
|
|
||||||
self.worker.moriio_wrapper.waiting_for_transfer_complete()
|
|
||||||
|
|
||||||
remote_port = task.remote_notify_port + get_port_offset(
|
|
||||||
request_info.decode_dp_rank, self.worker.tp_rank
|
|
||||||
)
|
|
||||||
# Consider using RDMA immediate data in decode side
|
|
||||||
# to eliminate the need for this notification.
|
|
||||||
# Consider including the first gen token from prefill in the notification
|
|
||||||
|
|
||||||
# Send completion notification
|
|
||||||
self.worker.moriio_wrapper.send_notify(
|
|
||||||
task.request_id, task.remote_ip, remote_port
|
|
||||||
)
|
|
||||||
# mark request as done, then we can free the blocks
|
|
||||||
with self.worker.moriio_wrapper.lock:
|
|
||||||
self.worker.moriio_wrapper.done_req_ids.append(task.request_id)
|
|
||||||
del self.worker.moriio_wrapper.done_remote_allocate_req_dict[
|
|
||||||
task.request_id
|
|
||||||
]
|
|
||||||
logger.debug(
|
|
||||||
"Completed transfer for request %s, notified port %d",
|
|
||||||
task.request_id,
|
|
||||||
remote_port,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class MoRIIOWrapper:
|
|
||||||
"""Wrapper for MoRIIO engine operations.
|
|
||||||
|
|
||||||
Handles both producer and consumer roles for KV cache transfers.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
moriio_engine: MoRIIO engine instance
|
|
||||||
tp_rank: Tensor parallel rank
|
|
||||||
dp_rank: Data parallel rank
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
moriio_engine: Optional["IOEngine"] = None,
|
|
||||||
tp_rank: int = 0,
|
|
||||||
dp_rank: int = 0,
|
|
||||||
):
|
|
||||||
self.tp_rank = tp_rank
|
|
||||||
self.dp_rank = dp_rank
|
|
||||||
self.moriio_engine = moriio_engine
|
|
||||||
self.remote_memory_metadata = None
|
|
||||||
self.local_memory_registered = False
|
|
||||||
self.local_memory_metadata = None
|
|
||||||
self.transfer_status: list[Any] = []
|
|
||||||
self.remote_engine_ip: str | None = None
|
|
||||||
self.notify_port: int | None = None
|
|
||||||
self.lock = threading.Lock()
|
|
||||||
self.done_req_ids: list[str] = []
|
|
||||||
self.done_remote_allocate_req_dict: dict[str, RemoteAllocInfo] = {}
|
|
||||||
self.done_write_cache_req_ids: list[str] = []
|
|
||||||
self.notify_thread: threading.Thread | None = None
|
|
||||||
self.sessions: list[IOEngine.Session] = []
|
|
||||||
self.paths: dict[str, zmq.Socket] = {}
|
|
||||||
|
|
||||||
def set_moriio_engine(self, moriio_engine):
|
|
||||||
assert moriio_engine is not None, (
|
|
||||||
"You Cannot pass None engine to MoRIIOWrapper!"
|
|
||||||
)
|
|
||||||
self.moriio_engine = moriio_engine
|
|
||||||
|
|
||||||
def set_backend_type(self, backend_type):
|
|
||||||
assert self.moriio_engine is not None, "MoRIIO engine must be set first"
|
|
||||||
qp_per_transfer = envs.VLLM_MORIIO_QP_PER_TRANSFER
|
|
||||||
post_batch_size = envs.VLLM_MORIIO_POST_BATCH_SIZE
|
|
||||||
num_worker_threads = envs.VLLM_MORIIO_NUM_WORKERS
|
|
||||||
poll_mode = PollCqMode.POLLING
|
|
||||||
rdma_cfg = RdmaBackendConfig(
|
|
||||||
qp_per_transfer,
|
|
||||||
post_batch_size,
|
|
||||||
num_worker_threads,
|
|
||||||
poll_mode,
|
|
||||||
)
|
|
||||||
self.moriio_engine.create_backend(backend_type, rdma_cfg)
|
|
||||||
|
|
||||||
def get_agent_metadata(self):
|
|
||||||
assert self.moriio_engine is not None, "MoRIIO engine must be set first"
|
|
||||||
engine_metadata = self.moriio_engine.get_engine_desc()
|
|
||||||
engine_metadata_packed = engine_metadata.pack()
|
|
||||||
return engine_metadata_packed
|
|
||||||
|
|
||||||
def register_remote_engine(self, remote_packed_engine_metadata):
|
|
||||||
assert self.moriio_engine is not None, "MoRIIO engine must be set first"
|
|
||||||
consumer_engine_metadata = EngineDesc.unpack(remote_packed_engine_metadata)
|
|
||||||
self.moriio_engine.register_remote_engine(consumer_engine_metadata)
|
|
||||||
return consumer_engine_metadata.key
|
|
||||||
|
|
||||||
def register_local_tensor(self, tensor: torch.Tensor):
|
|
||||||
assert self.moriio_engine is not None, "MoRIIO engine must be set first"
|
|
||||||
try:
|
|
||||||
self.local_memory_metadata = self.moriio_engine.register_torch_tensor(
|
|
||||||
tensor
|
|
||||||
)
|
|
||||||
assert self.local_memory_metadata is not None, (
|
|
||||||
"register_torch_tensor returned None"
|
|
||||||
)
|
|
||||||
local_memory_metadata_packed = self.local_memory_metadata.pack()
|
|
||||||
except Exception as e:
|
|
||||||
raise MoRIIOError(f"Failed to register local memory: {e}") from e
|
|
||||||
self.local_memory_registered = True
|
|
||||||
return local_memory_metadata_packed
|
|
||||||
|
|
||||||
def get_unpack_memory_metadata(self, packed_memory_metadata):
|
|
||||||
return MemoryDesc.unpack(packed_memory_metadata)
|
|
||||||
|
|
||||||
def build_session(self, local_memory_metadata, remote_memory_metadata):
|
|
||||||
assert self.moriio_engine is not None, "MoRIIO engine must be set first"
|
|
||||||
return self.moriio_engine.create_session(
|
|
||||||
local_memory_metadata, remote_memory_metadata
|
|
||||||
)
|
|
||||||
|
|
||||||
def read_remote_data(
|
|
||||||
self, transfer_size_byte, local_offset=0, remote_offset=0, session=None
|
|
||||||
):
|
|
||||||
assert self.local_memory_registered, "You have not register local memory data!"
|
|
||||||
assert self.moriio_engine is not None, "MoRIIO engine must be set first"
|
|
||||||
transfer_status = session.batch_read(
|
|
||||||
local_offset,
|
|
||||||
remote_offset,
|
|
||||||
transfer_size_byte,
|
|
||||||
self.moriio_engine.allocate_transfer_uid(),
|
|
||||||
)
|
|
||||||
|
|
||||||
return transfer_status
|
|
||||||
|
|
||||||
def write_remote_data(
|
|
||||||
self, transfer_size_byte, local_offset=0, remote_offset=0, session=None
|
|
||||||
):
|
|
||||||
assert self.local_memory_registered, "You have not register local memory data!"
|
|
||||||
assert self.moriio_engine is not None, "MoRIIO engine must be set first"
|
|
||||||
write_uid = self.moriio_engine.allocate_transfer_uid()
|
|
||||||
|
|
||||||
transfer_status = session.batch_write(
|
|
||||||
local_offset, remote_offset, transfer_size_byte, write_uid
|
|
||||||
)
|
|
||||||
with self.lock:
|
|
||||||
self.transfer_status.append(transfer_status)
|
|
||||||
|
|
||||||
def write_remote_data_single(
|
|
||||||
self, transfer_size_byte, local_offset=0, remote_offset=0, sess_idx=0
|
|
||||||
):
|
|
||||||
assert self.local_memory_registered, "You have not register local memory data!"
|
|
||||||
assert self.moriio_engine is not None, "MoRIIO engine must be set first"
|
|
||||||
transfer_status = self.sessions[sess_idx].write(
|
|
||||||
local_offset,
|
|
||||||
remote_offset,
|
|
||||||
transfer_size_byte,
|
|
||||||
self.moriio_engine.allocate_transfer_uid(),
|
|
||||||
)
|
|
||||||
with self.lock:
|
|
||||||
self.transfer_status.append(transfer_status)
|
|
||||||
|
|
||||||
def waiting_for_transfer_complete(self):
|
|
||||||
if not self.transfer_status:
|
|
||||||
return
|
|
||||||
|
|
||||||
transfers_to_wait = []
|
|
||||||
with self.lock:
|
|
||||||
transfers_to_wait = self.transfer_status[:]
|
|
||||||
self.transfer_status.clear()
|
|
||||||
|
|
||||||
for status in transfers_to_wait:
|
|
||||||
try:
|
|
||||||
status.Wait()
|
|
||||||
if not status.Succeeded():
|
|
||||||
logger.error(
|
|
||||||
"Transfer failed: %s, Code: %s", status.Message(), status.Code()
|
|
||||||
)
|
|
||||||
raise TransferError("MoRIIO transfer failed!")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error("Transfer %s failed: %s", status, e)
|
|
||||||
raise
|
|
||||||
|
|
||||||
def async_wait_reqid(self):
|
|
||||||
assert self.notify_port is not None, "Notify port cannot be None"
|
|
||||||
|
|
||||||
if self.notify_thread is not None:
|
|
||||||
return
|
|
||||||
|
|
||||||
def _async_wait():
|
|
||||||
host = "*"
|
|
||||||
path = make_zmq_path("tcp", host, self.notify_port)
|
|
||||||
logger.info("Node starting to listen notify from path = %s", path)
|
|
||||||
|
|
||||||
with zmq_ctx(zmq.ROUTER, path) as sock:
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
identity, msg = sock.recv_multipart()
|
|
||||||
self._handle_message(msg)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error("Error processing message: %s", e)
|
|
||||||
raise HandshakeError(f"Error processing message: {e}") from e
|
|
||||||
|
|
||||||
self.notify_thread = threading.Thread(
|
|
||||||
target=_async_wait, daemon=True, name="moriio-notify-listener"
|
|
||||||
)
|
|
||||||
self.notify_thread.start()
|
|
||||||
|
|
||||||
def _handle_message(self, msg: bytes):
|
|
||||||
"""Handles incoming messages from remote nodes."""
|
|
||||||
# Handles incoming remote messages:
|
|
||||||
# Prefill Role:
|
|
||||||
# [write] mode: receives block information (allocation)
|
|
||||||
# [read] mode: receives block release messages from decode side
|
|
||||||
# Decode Role:
|
|
||||||
# [write] mode: receives KV cache write completion notifications
|
|
||||||
handled = False
|
|
||||||
try:
|
|
||||||
data = msgpack.loads(msg)
|
|
||||||
if isinstance(data, dict) and "req_id" in data:
|
|
||||||
self._handle_structured_message(data)
|
|
||||||
|
|
||||||
return
|
|
||||||
except (msgpack.exceptions.ExtraData, msgpack.exceptions.UnpackException):
|
|
||||||
logger.debug("Failed to decode msgpack message, will try as string")
|
|
||||||
pass
|
|
||||||
|
|
||||||
try:
|
|
||||||
msg_str = msg.decode("UTF-8")
|
|
||||||
if msg_str.startswith(MoRIIOConstants.COMPLETION_PREFIX):
|
|
||||||
self._handle_completion_message(msg_str)
|
|
||||||
handled = True
|
|
||||||
except UnicodeDecodeError:
|
|
||||||
logger.warning("Received non-UTF8 message: %s", msg_str)
|
|
||||||
if not handled:
|
|
||||||
raise MoRIIOError(f"Unhandled message format: {msg_str}")
|
|
||||||
|
|
||||||
def _handle_structured_message(self, data: dict):
|
|
||||||
assert get_role() == ROLE.PRODUCER, "Only prefill can get block messages"
|
|
||||||
req_id = data["req_id"]
|
|
||||||
block_notify_list = data.get("block_notify_list", [])
|
|
||||||
decode_dp_rank = data.get("decode_rank", 0)
|
|
||||||
assert len(block_notify_list) > 0, (
|
|
||||||
"block_notify_list cannot be empty in remote allocate message"
|
|
||||||
)
|
|
||||||
|
|
||||||
with self.lock:
|
|
||||||
self.done_remote_allocate_req_dict[req_id] = RemoteAllocInfo(
|
|
||||||
block_ids=block_notify_list, decode_dp_rank=decode_dp_rank
|
|
||||||
)
|
|
||||||
|
|
||||||
def _handle_completion_message(self, msg: str):
|
|
||||||
with self.lock:
|
|
||||||
if get_role() == ROLE.PRODUCER:
|
|
||||||
self.done_req_ids.append(msg)
|
|
||||||
else:
|
|
||||||
self.done_write_cache_req_ids.append(msg)
|
|
||||||
|
|
||||||
def send_notify(self, req_ids, remote_ip, remote_port):
|
|
||||||
if not remote_ip or not remote_port:
|
|
||||||
logger.warning("Missing remote_ip or remote_port for notification")
|
|
||||||
return
|
|
||||||
|
|
||||||
path = make_zmq_path("tcp", remote_ip, remote_port)
|
|
||||||
|
|
||||||
if path not in self.paths:
|
|
||||||
ctx = zmq.Context.instance()
|
|
||||||
sock = make_zmq_socket(
|
|
||||||
ctx=ctx, path=path, socket_type=zmq.DEALER, bind=False
|
|
||||||
)
|
|
||||||
self.paths[path] = sock
|
|
||||||
|
|
||||||
req_list = req_ids if isinstance(req_ids, list) else [req_ids]
|
|
||||||
|
|
||||||
sock = self.paths[path]
|
|
||||||
try:
|
|
||||||
for req_id in req_list:
|
|
||||||
if not isinstance(req_id, str):
|
|
||||||
logger.warning(
|
|
||||||
"Invalid req_id type: %s, expected str", type(req_id)
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
sock.send(req_id.encode("utf-8"))
|
|
||||||
except Exception as e:
|
|
||||||
logger.error("Failed to send notification to %s: %s", path, e)
|
|
||||||
self.paths.pop(path, None)
|
|
||||||
raise
|
|
||||||
|
|
||||||
def pop_finished_req_ids(self):
|
|
||||||
# producer invocation: get the set of completed requests at the decode
|
|
||||||
with self.lock:
|
|
||||||
done_send = set(self.done_req_ids)
|
|
||||||
self.done_req_ids = []
|
|
||||||
return done_send
|
|
||||||
|
|
||||||
def pop_finished_write_req_ids(self):
|
|
||||||
# Call the consumer in write mode to get the collection after write completion
|
|
||||||
with self.lock:
|
|
||||||
done_write_cache = set(self.done_write_cache_req_ids)
|
|
||||||
self.done_write_cache_req_ids = []
|
|
||||||
return done_write_cache
|
|
||||||
|
|
||||||
def shutdown(self):
|
|
||||||
logger.debug("Closing MoRIIOWrapper and cleaning up ZMQ sockets")
|
|
||||||
for path, sock in self.paths.items():
|
|
||||||
try:
|
|
||||||
sock.close(linger=0)
|
|
||||||
logger.debug("Closed ZMQ socket for path: %s", path)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning("Error closing ZMQ socket for path %s: %s", path, e)
|
|
||||||
self.paths.clear()
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ReqMeta:
|
|
||||||
"""Metadata for a single request."""
|
|
||||||
|
|
||||||
local_block_ids: list[int]
|
|
||||||
remote_block_ids: list[int]
|
|
||||||
remote_host: str
|
|
||||||
remote_port: int
|
|
||||||
remote_handshake_port: int
|
|
||||||
remote_notify_port: int
|
|
||||||
remote_engine_id: str
|
|
||||||
tp_size: int
|
|
||||||
remote_dp_size: int
|
|
||||||
|
|
||||||
|
|
||||||
class MoRIIOConnectorMetadata(KVConnectorMetadata):
|
|
||||||
def __init__(self):
|
|
||||||
self.reqs_to_recv: dict[ReqId, ReqMeta] = {}
|
|
||||||
self.reqs_to_save: dict[ReqId, ReqMeta] = {}
|
|
||||||
self.reqs_to_send: dict[ReqId, float] = {}
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return_str = ""
|
|
||||||
for req_id, req_meta in self.reqs_to_recv.items():
|
|
||||||
return_str += (
|
|
||||||
f"{req_id = },{req_meta.local_block_ids = },"
|
|
||||||
f"{req_meta.remote_host = },{req_meta.remote_port = }"
|
|
||||||
f"{req_meta.remote_engine_id = },{req_meta.tp_size = }"
|
|
||||||
)
|
|
||||||
return_str = f"MoRIIOConnectorMetadata:reqs_to_recv:{return_str},"
|
|
||||||
|
|
||||||
for req_id, expiry in self.reqs_to_send.items():
|
|
||||||
return_str += f"{req_id = },{expiry = }"
|
|
||||||
return_str = f"MoRIIOConnectorMetadata:reqs_to_send:{return_str},"
|
|
||||||
return return_str
|
|
||||||
|
|
||||||
def add_new_req(
|
|
||||||
self,
|
|
||||||
request_id: ReqId,
|
|
||||||
local_block_ids: list[int],
|
|
||||||
kv_transfer_params: dict[str, Any],
|
|
||||||
write_mode=False,
|
|
||||||
):
|
|
||||||
_req = ReqMeta(
|
|
||||||
local_block_ids=local_block_ids,
|
|
||||||
remote_block_ids=kv_transfer_params["remote_block_ids"],
|
|
||||||
remote_engine_id=kv_transfer_params["remote_engine_id"],
|
|
||||||
remote_host=kv_transfer_params["remote_host"],
|
|
||||||
remote_port=kv_transfer_params["remote_port"],
|
|
||||||
remote_handshake_port=kv_transfer_params["remote_handshake_port"],
|
|
||||||
remote_notify_port=kv_transfer_params["remote_notify_port"],
|
|
||||||
tp_size=kv_transfer_params.get("tp_size", 1),
|
|
||||||
remote_dp_size=kv_transfer_params.get("remote_dp_size", 1),
|
|
||||||
)
|
|
||||||
if write_mode:
|
|
||||||
self.reqs_to_save[request_id] = _req
|
|
||||||
else:
|
|
||||||
self.reqs_to_recv[request_id] = _req
|
|
||||||
|
|
||||||
|
|
||||||
class MoRIIOConnector(KVConnectorBase_V1):
|
class MoRIIOConnector(KVConnectorBase_V1):
|
||||||
@ -1371,7 +561,7 @@ class MoRIIOConnectorWorker:
|
|||||||
"""Implementation of Worker side methods"""
|
"""Implementation of Worker side methods"""
|
||||||
|
|
||||||
def __init__(self, vllm_config: VllmConfig, engine_id: str):
|
def __init__(self, vllm_config: VllmConfig, engine_id: str):
|
||||||
if not MoRIIO_enabled:
|
if not is_moriio_available():
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"MoRIIO is not available. Please ensure the 'mori' package "
|
"MoRIIO is not available. Please ensure the 'mori' package "
|
||||||
"is installed and properly configured."
|
"is installed and properly configured."
|
||||||
@ -2323,21 +1513,3 @@ class MoRIIOConnectorWorker:
|
|||||||
remote_host,
|
remote_host,
|
||||||
str(remote_notify_port + self.tp_rank),
|
str(remote_notify_port + self.tp_rank),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@contextlib.contextmanager
|
|
||||||
def zmq_ctx(socket_type: Any, addr: str) -> Iterator[zmq.Socket]:
|
|
||||||
"""Context manager for a ZMQ socket"""
|
|
||||||
|
|
||||||
if socket_type not in (zmq.ROUTER, zmq.REQ, zmq.DEALER):
|
|
||||||
raise ValueError(f"Unexpected socket type: {socket_type}")
|
|
||||||
|
|
||||||
ctx: zmq.Context | None = None
|
|
||||||
try:
|
|
||||||
ctx = zmq.Context() # type: ignore[attr-defined]
|
|
||||||
yield make_zmq_socket(
|
|
||||||
ctx=ctx, path=addr, socket_type=socket_type, bind=socket_type == zmq.ROUTER
|
|
||||||
)
|
|
||||||
finally:
|
|
||||||
if ctx is not None:
|
|
||||||
ctx.destroy(linger=0)
|
|
||||||
@ -0,0 +1,607 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
import threading
|
||||||
|
from typing import TYPE_CHECKING, Any, Optional
|
||||||
|
from weakref import ref as weakref_ref
|
||||||
|
|
||||||
|
import msgpack
|
||||||
|
import torch
|
||||||
|
import zmq
|
||||||
|
|
||||||
|
from vllm import envs
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.utils.network_utils import (
|
||||||
|
make_zmq_path,
|
||||||
|
make_zmq_socket,
|
||||||
|
)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
pass
|
||||||
|
|
||||||
|
from queue import Empty, Queue
|
||||||
|
|
||||||
|
from vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_common import (
|
||||||
|
ROLE,
|
||||||
|
HandshakeError,
|
||||||
|
LayerTransferPlan,
|
||||||
|
MoRIIOAgentMetadata,
|
||||||
|
MoRIIOConstants,
|
||||||
|
MoRIIOError,
|
||||||
|
RemoteAllocInfo,
|
||||||
|
TransferError,
|
||||||
|
WriteTask,
|
||||||
|
get_port_offset,
|
||||||
|
get_role,
|
||||||
|
zmq_ctx,
|
||||||
|
)
|
||||||
|
from vllm.distributed.kv_transfer.kv_connector.v1.moriio.moriio_connector import (
|
||||||
|
MoRIIOConnectorWorker,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
try:
|
||||||
|
from mori.io import (
|
||||||
|
EngineDesc,
|
||||||
|
IOEngine,
|
||||||
|
MemoryDesc,
|
||||||
|
PollCqMode,
|
||||||
|
RdmaBackendConfig,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info("MoRIIO is available")
|
||||||
|
except ImportError:
|
||||||
|
logger.error("MoRIIO is not available")
|
||||||
|
|
||||||
|
|
||||||
|
"""Write task execution logic for MoRIIO connector."""
|
||||||
|
|
||||||
|
|
||||||
|
class MoRIIOWriter:
|
||||||
|
"""Handles write operations for KV cache transfers.
|
||||||
|
Implements distributed KV cache transfer using the MoRIIO library
|
||||||
|
for RDMA-based communication between prefill and decode instances."""
|
||||||
|
|
||||||
|
def __init__(self, worker: "MoRIIOConnectorWorker"):
|
||||||
|
"""Initialize the writer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
worker: Reference to the parent worker
|
||||||
|
"""
|
||||||
|
self._worker_ref: weakref_ref[MoRIIOConnectorWorker] = weakref_ref(worker)
|
||||||
|
self._write_task_q: Queue[WriteTask] = Queue()
|
||||||
|
self._write_worker_started = False
|
||||||
|
self._write_worker_lock = threading.Lock()
|
||||||
|
self._deferred_tasks: list[WriteTask] = []
|
||||||
|
|
||||||
|
@property
|
||||||
|
def worker(self) -> "MoRIIOConnectorWorker":
|
||||||
|
"""Get the worker instance.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The parent worker instance
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: If worker has been garbage collected
|
||||||
|
"""
|
||||||
|
worker = self._worker_ref()
|
||||||
|
if worker is None:
|
||||||
|
raise RuntimeError("Parent worker has been garbage collected")
|
||||||
|
return worker
|
||||||
|
|
||||||
|
def ensure_worker_started(self) -> None:
|
||||||
|
"""Ensure the background write worker is running."""
|
||||||
|
if self._write_worker_started:
|
||||||
|
return
|
||||||
|
self._write_worker_started = True
|
||||||
|
with self._write_worker_lock:
|
||||||
|
thread = threading.Thread(
|
||||||
|
target=self._write_worker_loop, daemon=True, name="moriio-write-worker"
|
||||||
|
)
|
||||||
|
thread.start()
|
||||||
|
logger.info("Started MoRIIO write worker thread")
|
||||||
|
|
||||||
|
def schedule_write(self, task: WriteTask) -> None:
|
||||||
|
"""Schedule a write task.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task: The write task to schedule
|
||||||
|
"""
|
||||||
|
self.ensure_worker_started()
|
||||||
|
self._write_task_q.put(task)
|
||||||
|
|
||||||
|
def _write_worker_loop(self) -> None:
|
||||||
|
"""Main loop for the write worker thread."""
|
||||||
|
|
||||||
|
while True:
|
||||||
|
# Process deferred tasks first
|
||||||
|
self._process_deferred_tasks()
|
||||||
|
|
||||||
|
# Get new task
|
||||||
|
try:
|
||||||
|
task = self._write_task_q.get(timeout=0.01)
|
||||||
|
except Empty:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Check if remote blocks are ready
|
||||||
|
if not self._is_remote_ready(task):
|
||||||
|
# task.retry_count += 1
|
||||||
|
self._deferred_tasks.append(task)
|
||||||
|
# logger.debug(
|
||||||
|
# "Deferred task for request %s (retry %d)",
|
||||||
|
# task.request_id, task.retry_count
|
||||||
|
# )
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Execute the task
|
||||||
|
|
||||||
|
self._execute_write_task(task)
|
||||||
|
|
||||||
|
def _process_deferred_tasks(self) -> None:
|
||||||
|
"""Process tasks that were previously deferred."""
|
||||||
|
if not self._deferred_tasks:
|
||||||
|
return
|
||||||
|
|
||||||
|
still_deferred: list[WriteTask] = []
|
||||||
|
for task in self._deferred_tasks:
|
||||||
|
if self._is_remote_ready(task):
|
||||||
|
self._execute_write_task(task)
|
||||||
|
else:
|
||||||
|
still_deferred.append(task)
|
||||||
|
|
||||||
|
self._deferred_tasks = still_deferred
|
||||||
|
|
||||||
|
def _is_remote_ready(self, task: WriteTask) -> bool:
|
||||||
|
"""Check if remote blocks are allocated for this task.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task: The write task
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if remote blocks are ready
|
||||||
|
"""
|
||||||
|
return (
|
||||||
|
task.request_id in self.worker.moriio_wrapper.done_remote_allocate_req_dict
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_remote_alloc_info(self, request_id: str) -> RemoteAllocInfo:
|
||||||
|
"""Get remote allocation info for a request.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request_id: The request ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Remote allocation information
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
KeyError: If allocation info is missing
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return self.worker.moriio_wrapper.done_remote_allocate_req_dict[request_id]
|
||||||
|
except KeyError as e:
|
||||||
|
raise KeyError(
|
||||||
|
f"Remote allocation info missing for request {request_id}"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
def _execute_write_task(self, task: WriteTask) -> None:
|
||||||
|
"""Execute a single write task.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task: The write task to execute
|
||||||
|
|
||||||
|
"""
|
||||||
|
# Get remote allocation info
|
||||||
|
request_info = self._get_remote_alloc_info(task.request_id)
|
||||||
|
|
||||||
|
if request_info.block_ids is None:
|
||||||
|
logger.debug("Request %s remote block IDs not ready", task.request_id)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Wait for CUDA event
|
||||||
|
# The attention computation of the current layer cannot
|
||||||
|
# overlap with the kv transfer task,
|
||||||
|
# otherwise it will cause precision issues.
|
||||||
|
# This event is used to synchronize the kv transfer and computation tasks.
|
||||||
|
task.event.synchronize()
|
||||||
|
|
||||||
|
# Update engine ID with DP rank
|
||||||
|
task.dst_engine_id = self.worker.get_engine_name_with_dp(
|
||||||
|
task.dst_engine_id, request_info.decode_dp_rank
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get or create sessions
|
||||||
|
sessions, remote_moriio_meta = self.worker._get_built_session(
|
||||||
|
task.dst_engine_id
|
||||||
|
)
|
||||||
|
|
||||||
|
# Prepare transfer plan
|
||||||
|
plan = self._prepare_transfer_plan(task, request_info, remote_moriio_meta)
|
||||||
|
|
||||||
|
# Execute transfer
|
||||||
|
self._do_layer_write(plan, sessions)
|
||||||
|
|
||||||
|
# Finalize if all layers complete
|
||||||
|
self._finalize_if_complete(task, request_info)
|
||||||
|
|
||||||
|
def _prepare_transfer_plan(
|
||||||
|
self,
|
||||||
|
task: WriteTask,
|
||||||
|
request_info: RemoteAllocInfo,
|
||||||
|
remote_moriio_meta: MoRIIOAgentMetadata,
|
||||||
|
) -> LayerTransferPlan:
|
||||||
|
"""Prepare the transfer plan for a layer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task: The write task
|
||||||
|
request_info: Remote allocation information
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The transfer plan
|
||||||
|
"""
|
||||||
|
# Compute offsets if not cached
|
||||||
|
if request_info.transfer_offset is None:
|
||||||
|
offsets = self.worker._compute_block_transfer_offsets(
|
||||||
|
task.layer_name,
|
||||||
|
task.local_block_ids,
|
||||||
|
request_info.block_ids,
|
||||||
|
remote_moriio_meta,
|
||||||
|
)
|
||||||
|
request_info.transfer_offset = offsets
|
||||||
|
|
||||||
|
# Get session index
|
||||||
|
layer_names = list(self.worker.layer_name_to_local_kv_cache_metadata.keys())
|
||||||
|
sess_idx = layer_names.index(task.layer_name)
|
||||||
|
|
||||||
|
local_off, remote_off, sizes = request_info.transfer_offset
|
||||||
|
|
||||||
|
return LayerTransferPlan(
|
||||||
|
request_id=task.request_id,
|
||||||
|
layer_name=task.layer_name,
|
||||||
|
sess_idx=sess_idx,
|
||||||
|
transfer_local_offsets=local_off,
|
||||||
|
transfer_remote_offsets=remote_off,
|
||||||
|
transfer_sizes=sizes,
|
||||||
|
use_batch=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _do_layer_write(self, plan: LayerTransferPlan, sessions: list) -> None:
|
||||||
|
"""Perform the actual layer write.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
plan: The transfer plan
|
||||||
|
sessions: List of transfer sessions
|
||||||
|
"""
|
||||||
|
if plan.use_batch:
|
||||||
|
self.worker.moriio_wrapper.write_remote_data(
|
||||||
|
plan.transfer_sizes,
|
||||||
|
plan.transfer_local_offsets,
|
||||||
|
plan.transfer_remote_offsets,
|
||||||
|
sessions[plan.sess_idx],
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
for i in range(len(plan.transfer_local_offsets)):
|
||||||
|
self.worker.moriio_wrapper.write_remote_data_single(
|
||||||
|
plan.transfer_sizes[i],
|
||||||
|
plan.transfer_local_offsets[i],
|
||||||
|
plan.transfer_remote_offsets[i],
|
||||||
|
plan.sess_idx,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _finalize_if_complete(
|
||||||
|
self, task: WriteTask, request_info: RemoteAllocInfo
|
||||||
|
) -> None:
|
||||||
|
"""Finalize transfer if all layers are complete.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task: The write task
|
||||||
|
request_info: Remote allocation information
|
||||||
|
"""
|
||||||
|
request_info.writes_done += 1
|
||||||
|
|
||||||
|
if request_info.writes_done >= self.worker.num_layers:
|
||||||
|
# Wait for transfer to complete
|
||||||
|
self.worker.moriio_wrapper.waiting_for_transfer_complete()
|
||||||
|
|
||||||
|
remote_port = task.remote_notify_port + get_port_offset(
|
||||||
|
request_info.decode_dp_rank, self.worker.tp_rank
|
||||||
|
)
|
||||||
|
# Consider using RDMA immediate data in decode side
|
||||||
|
# to eliminate the need for this notification.
|
||||||
|
# Consider including the first gen token from prefill in the notification
|
||||||
|
|
||||||
|
# Send completion notification
|
||||||
|
self.worker.moriio_wrapper.send_notify(
|
||||||
|
task.request_id, task.remote_ip, remote_port
|
||||||
|
)
|
||||||
|
# mark request as done, then we can free the blocks
|
||||||
|
with self.worker.moriio_wrapper.lock:
|
||||||
|
self.worker.moriio_wrapper.done_req_ids.append(task.request_id)
|
||||||
|
del self.worker.moriio_wrapper.done_remote_allocate_req_dict[
|
||||||
|
task.request_id
|
||||||
|
]
|
||||||
|
logger.debug(
|
||||||
|
"Completed transfer for request %s, notified port %d",
|
||||||
|
task.request_id,
|
||||||
|
remote_port,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MoRIIOWrapper:
|
||||||
|
"""Wrapper for MoRIIO engine operations.
|
||||||
|
|
||||||
|
Handles both producer and consumer roles for KV cache transfers.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
moriio_engine: MoRIIO engine instance
|
||||||
|
tp_rank: Tensor parallel rank
|
||||||
|
dp_rank: Data parallel rank
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
moriio_engine: Optional["IOEngine"] = None,
|
||||||
|
tp_rank: int = 0,
|
||||||
|
dp_rank: int = 0,
|
||||||
|
):
|
||||||
|
self.tp_rank = tp_rank
|
||||||
|
self.dp_rank = dp_rank
|
||||||
|
self.moriio_engine = moriio_engine
|
||||||
|
self.remote_memory_metadata = None
|
||||||
|
self.local_memory_registered = False
|
||||||
|
self.local_memory_metadata = None
|
||||||
|
self.transfer_status: list[Any] = []
|
||||||
|
self.remote_engine_ip: str | None = None
|
||||||
|
self.notify_port: int | None = None
|
||||||
|
self.lock = threading.Lock()
|
||||||
|
self.done_req_ids: list[str] = []
|
||||||
|
self.done_remote_allocate_req_dict: dict[str, RemoteAllocInfo] = {}
|
||||||
|
self.done_write_cache_req_ids: list[str] = []
|
||||||
|
self.notify_thread: threading.Thread | None = None
|
||||||
|
self.sessions: list[IOEngine.Session] = []
|
||||||
|
self.paths: dict[str, zmq.Socket] = {}
|
||||||
|
|
||||||
|
def set_moriio_engine(self, moriio_engine):
|
||||||
|
assert moriio_engine is not None, (
|
||||||
|
"You Cannot pass None engine to MoRIIOWrapper!"
|
||||||
|
)
|
||||||
|
self.moriio_engine = moriio_engine
|
||||||
|
|
||||||
|
def set_backend_type(self, backend_type):
|
||||||
|
assert self.moriio_engine is not None, "MoRIIO engine must be set first"
|
||||||
|
qp_per_transfer = envs.VLLM_MORIIO_QP_PER_TRANSFER
|
||||||
|
post_batch_size = envs.VLLM_MORIIO_POST_BATCH_SIZE
|
||||||
|
num_worker_threads = envs.VLLM_MORIIO_NUM_WORKERS
|
||||||
|
poll_mode = PollCqMode.POLLING
|
||||||
|
rdma_cfg = RdmaBackendConfig(
|
||||||
|
qp_per_transfer,
|
||||||
|
post_batch_size,
|
||||||
|
num_worker_threads,
|
||||||
|
poll_mode,
|
||||||
|
)
|
||||||
|
self.moriio_engine.create_backend(backend_type, rdma_cfg)
|
||||||
|
|
||||||
|
def get_agent_metadata(self):
|
||||||
|
assert self.moriio_engine is not None, "MoRIIO engine must be set first"
|
||||||
|
engine_metadata = self.moriio_engine.get_engine_desc()
|
||||||
|
engine_metadata_packed = engine_metadata.pack()
|
||||||
|
return engine_metadata_packed
|
||||||
|
|
||||||
|
def register_remote_engine(self, remote_packed_engine_metadata):
|
||||||
|
assert self.moriio_engine is not None, "MoRIIO engine must be set first"
|
||||||
|
consumer_engine_metadata = EngineDesc.unpack(remote_packed_engine_metadata)
|
||||||
|
self.moriio_engine.register_remote_engine(consumer_engine_metadata)
|
||||||
|
return consumer_engine_metadata.key
|
||||||
|
|
||||||
|
def register_local_tensor(self, tensor: torch.Tensor):
|
||||||
|
assert self.moriio_engine is not None, "MoRIIO engine must be set first"
|
||||||
|
try:
|
||||||
|
self.local_memory_metadata = self.moriio_engine.register_torch_tensor(
|
||||||
|
tensor
|
||||||
|
)
|
||||||
|
assert self.local_memory_metadata is not None, (
|
||||||
|
"register_torch_tensor returned None"
|
||||||
|
)
|
||||||
|
local_memory_metadata_packed = self.local_memory_metadata.pack()
|
||||||
|
except Exception as e:
|
||||||
|
raise MoRIIOError(f"Failed to register local memory: {e}") from e
|
||||||
|
self.local_memory_registered = True
|
||||||
|
return local_memory_metadata_packed
|
||||||
|
|
||||||
|
def get_unpack_memory_metadata(self, packed_memory_metadata):
|
||||||
|
return MemoryDesc.unpack(packed_memory_metadata)
|
||||||
|
|
||||||
|
def build_session(self, local_memory_metadata, remote_memory_metadata):
|
||||||
|
assert self.moriio_engine is not None, "MoRIIO engine must be set first"
|
||||||
|
return self.moriio_engine.create_session(
|
||||||
|
local_memory_metadata, remote_memory_metadata
|
||||||
|
)
|
||||||
|
|
||||||
|
def read_remote_data(
|
||||||
|
self, transfer_size_byte, local_offset=0, remote_offset=0, session=None
|
||||||
|
):
|
||||||
|
assert self.local_memory_registered, "You have not register local memory data!"
|
||||||
|
assert self.moriio_engine is not None, "MoRIIO engine must be set first"
|
||||||
|
transfer_status = session.batch_read(
|
||||||
|
local_offset,
|
||||||
|
remote_offset,
|
||||||
|
transfer_size_byte,
|
||||||
|
self.moriio_engine.allocate_transfer_uid(),
|
||||||
|
)
|
||||||
|
|
||||||
|
return transfer_status
|
||||||
|
|
||||||
|
def write_remote_data(
|
||||||
|
self, transfer_size_byte, local_offset=0, remote_offset=0, session=None
|
||||||
|
):
|
||||||
|
assert self.local_memory_registered, "You have not register local memory data!"
|
||||||
|
assert self.moriio_engine is not None, "MoRIIO engine must be set first"
|
||||||
|
write_uid = self.moriio_engine.allocate_transfer_uid()
|
||||||
|
|
||||||
|
transfer_status = session.batch_write(
|
||||||
|
local_offset, remote_offset, transfer_size_byte, write_uid
|
||||||
|
)
|
||||||
|
with self.lock:
|
||||||
|
self.transfer_status.append(transfer_status)
|
||||||
|
|
||||||
|
def write_remote_data_single(
|
||||||
|
self, transfer_size_byte, local_offset=0, remote_offset=0, sess_idx=0
|
||||||
|
):
|
||||||
|
assert self.local_memory_registered, "You have not register local memory data!"
|
||||||
|
assert self.moriio_engine is not None, "MoRIIO engine must be set first"
|
||||||
|
transfer_status = self.sessions[sess_idx].write(
|
||||||
|
local_offset,
|
||||||
|
remote_offset,
|
||||||
|
transfer_size_byte,
|
||||||
|
self.moriio_engine.allocate_transfer_uid(),
|
||||||
|
)
|
||||||
|
with self.lock:
|
||||||
|
self.transfer_status.append(transfer_status)
|
||||||
|
|
||||||
|
def waiting_for_transfer_complete(self):
|
||||||
|
if not self.transfer_status:
|
||||||
|
return
|
||||||
|
|
||||||
|
transfers_to_wait = []
|
||||||
|
with self.lock:
|
||||||
|
transfers_to_wait = self.transfer_status[:]
|
||||||
|
self.transfer_status.clear()
|
||||||
|
|
||||||
|
for status in transfers_to_wait:
|
||||||
|
try:
|
||||||
|
status.Wait()
|
||||||
|
if not status.Succeeded():
|
||||||
|
logger.error(
|
||||||
|
"Transfer failed: %s, Code: %s", status.Message(), status.Code()
|
||||||
|
)
|
||||||
|
raise TransferError("MoRIIO transfer failed!")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Transfer %s failed: %s", status, e)
|
||||||
|
raise
|
||||||
|
|
||||||
|
def async_wait_reqid(self):
|
||||||
|
assert self.notify_port is not None, "Notify port cannot be None"
|
||||||
|
|
||||||
|
if self.notify_thread is not None:
|
||||||
|
return
|
||||||
|
|
||||||
|
def _async_wait():
|
||||||
|
host = "*"
|
||||||
|
path = make_zmq_path("tcp", host, self.notify_port)
|
||||||
|
logger.info("Node starting to listen notify from path = %s", path)
|
||||||
|
|
||||||
|
with zmq_ctx(zmq.ROUTER, path) as sock:
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
identity, msg = sock.recv_multipart()
|
||||||
|
self._handle_message(msg)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Error processing message: %s", e)
|
||||||
|
raise HandshakeError(f"Error processing message: {e}") from e
|
||||||
|
|
||||||
|
self.notify_thread = threading.Thread(
|
||||||
|
target=_async_wait, daemon=True, name="moriio-notify-listener"
|
||||||
|
)
|
||||||
|
self.notify_thread.start()
|
||||||
|
|
||||||
|
def _handle_message(self, msg: bytes):
|
||||||
|
"""Handles incoming messages from remote nodes."""
|
||||||
|
# Handles incoming remote messages:
|
||||||
|
# Prefill Role:
|
||||||
|
# [write] mode: receives block information (allocation)
|
||||||
|
# [read] mode: receives block release messages from decode side
|
||||||
|
# Decode Role:
|
||||||
|
# [write] mode: receives KV cache write completion notifications
|
||||||
|
handled = False
|
||||||
|
try:
|
||||||
|
data = msgpack.loads(msg)
|
||||||
|
if isinstance(data, dict) and "req_id" in data:
|
||||||
|
self._handle_structured_message(data)
|
||||||
|
|
||||||
|
return
|
||||||
|
except (msgpack.exceptions.ExtraData, msgpack.exceptions.UnpackException):
|
||||||
|
logger.debug("Failed to decode msgpack message, will try as string")
|
||||||
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
msg_str = msg.decode("UTF-8")
|
||||||
|
if msg_str.startswith(MoRIIOConstants.COMPLETION_PREFIX):
|
||||||
|
self._handle_completion_message(msg_str)
|
||||||
|
handled = True
|
||||||
|
except UnicodeDecodeError:
|
||||||
|
logger.warning("Received non-UTF8 message: %s", msg_str)
|
||||||
|
if not handled:
|
||||||
|
raise MoRIIOError(f"Unhandled message format: {msg_str}")
|
||||||
|
|
||||||
|
def _handle_structured_message(self, data: dict):
|
||||||
|
assert get_role() == ROLE.PRODUCER, "Only prefill can get block messages"
|
||||||
|
req_id = data["req_id"]
|
||||||
|
block_notify_list = data.get("block_notify_list", [])
|
||||||
|
decode_dp_rank = data.get("decode_rank", 0)
|
||||||
|
assert len(block_notify_list) > 0, (
|
||||||
|
"block_notify_list cannot be empty in remote allocate message"
|
||||||
|
)
|
||||||
|
|
||||||
|
with self.lock:
|
||||||
|
self.done_remote_allocate_req_dict[req_id] = RemoteAllocInfo(
|
||||||
|
block_ids=block_notify_list, decode_dp_rank=decode_dp_rank
|
||||||
|
)
|
||||||
|
|
||||||
|
def _handle_completion_message(self, msg: str):
|
||||||
|
with self.lock:
|
||||||
|
if get_role() == ROLE.PRODUCER:
|
||||||
|
self.done_req_ids.append(msg)
|
||||||
|
else:
|
||||||
|
self.done_write_cache_req_ids.append(msg)
|
||||||
|
|
||||||
|
def send_notify(self, req_ids, remote_ip, remote_port):
|
||||||
|
if not remote_ip or not remote_port:
|
||||||
|
logger.warning("Missing remote_ip or remote_port for notification")
|
||||||
|
return
|
||||||
|
|
||||||
|
path = make_zmq_path("tcp", remote_ip, remote_port)
|
||||||
|
|
||||||
|
if path not in self.paths:
|
||||||
|
ctx = zmq.Context.instance()
|
||||||
|
sock = make_zmq_socket(
|
||||||
|
ctx=ctx, path=path, socket_type=zmq.DEALER, bind=False
|
||||||
|
)
|
||||||
|
self.paths[path] = sock
|
||||||
|
|
||||||
|
req_list = req_ids if isinstance(req_ids, list) else [req_ids]
|
||||||
|
|
||||||
|
sock = self.paths[path]
|
||||||
|
try:
|
||||||
|
for req_id in req_list:
|
||||||
|
if not isinstance(req_id, str):
|
||||||
|
logger.warning(
|
||||||
|
"Invalid req_id type: %s, expected str", type(req_id)
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
sock.send(req_id.encode("utf-8"))
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Failed to send notification to %s: %s", path, e)
|
||||||
|
self.paths.pop(path, None)
|
||||||
|
raise
|
||||||
|
|
||||||
|
def pop_finished_req_ids(self):
|
||||||
|
# producer invocation: get the set of completed requests at the decode
|
||||||
|
with self.lock:
|
||||||
|
done_send = set(self.done_req_ids)
|
||||||
|
self.done_req_ids = []
|
||||||
|
return done_send
|
||||||
|
|
||||||
|
def pop_finished_write_req_ids(self):
|
||||||
|
# Call the consumer in write mode to get the collection after write completion
|
||||||
|
with self.lock:
|
||||||
|
done_write_cache = set(self.done_write_cache_req_ids)
|
||||||
|
self.done_write_cache_req_ids = []
|
||||||
|
return done_write_cache
|
||||||
|
|
||||||
|
def shutdown(self):
|
||||||
|
logger.debug("Closing MoRIIOWrapper and cleaning up ZMQ sockets")
|
||||||
|
for path, sock in self.paths.items():
|
||||||
|
try:
|
||||||
|
sock.close(linger=0)
|
||||||
|
logger.debug("Closed ZMQ socket for path: %s", path)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("Error closing ZMQ socket for path %s: %s", path, e)
|
||||||
|
self.paths.clear()
|
||||||
Loading…
x
Reference in New Issue
Block a user