mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-18 12:26:59 +08:00
refine code
Signed-off-by: inkcherry <mingzhi.liu@amd.com>
This commit is contained in:
parent
fd63437837
commit
38d51f6dd8
@ -37,7 +37,7 @@ from vllm.distributed.parallel_state import (
|
|||||||
)
|
)
|
||||||
from vllm.forward_context import ForwardContext
|
from vllm.forward_context import ForwardContext
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.utils.network_utils import get_ip, make_zmq_path, make_zmq_socket
|
from vllm.utils.network_utils import get_ip, make_zmq_path, make_zmq_socket, get_open_port
|
||||||
from vllm.v1.core.sched.output import SchedulerOutput
|
from vllm.v1.core.sched.output import SchedulerOutput
|
||||||
from vllm.v1.request import RequestStatus
|
from vllm.v1.request import RequestStatus
|
||||||
|
|
||||||
@ -69,7 +69,8 @@ class MoRIIOConstants:
|
|||||||
|
|
||||||
PING_INTERVAL = 5
|
PING_INTERVAL = 5
|
||||||
MAX_PING_RETRIES = 100000
|
MAX_PING_RETRIES = 100000
|
||||||
|
DEFAULT_HANDSHAKE_PORT = "6301"
|
||||||
|
DEFAULT_NOTIFY_PORT="61005"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from mori.io import (
|
from mori.io import (
|
||||||
@ -78,6 +79,8 @@ try:
|
|||||||
IOEngine,
|
IOEngine,
|
||||||
IOEngineConfig,
|
IOEngineConfig,
|
||||||
MemoryDesc,
|
MemoryDesc,
|
||||||
|
PollCqMode,
|
||||||
|
RdmaBackendConfig
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info("MoRIIO is available")
|
logger.info("MoRIIO is available")
|
||||||
@ -192,11 +195,12 @@ class TransferError(MoRIIOError):
|
|||||||
|
|
||||||
def get_moriio_mode() -> MoRIIOMode:
|
def get_moriio_mode() -> MoRIIOMode:
|
||||||
read_mode = os.environ.get("MORIIO_CONNECTOR_READ_MODE", "false").lower()
|
read_mode = os.environ.get("MORIIO_CONNECTOR_READ_MODE", "false").lower()
|
||||||
# logger.info(f"MoRIIO Connector Read Mode = {read_mode}")
|
logger.debug("MoRIIO Connector read_mode: %s", read_mode)
|
||||||
if read_mode in ("true", "1", "yes", "on"):
|
if read_mode in ("true", "1", "yes", "on"):
|
||||||
return MoRIIOMode.READ
|
return MoRIIOMode.READ
|
||||||
else:
|
else:
|
||||||
return MoRIIOMode.WRITE
|
return MoRIIOMode.WRITE
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def get_port_offset(dp_rank: int, tp_rank: int, tp_size: int = 1) -> int:
|
def get_port_offset(dp_rank: int, tp_rank: int, tp_size: int = 1) -> int:
|
||||||
@ -220,19 +224,21 @@ class MoRIIOConfig:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_vllm_config(cls, vllm_config: VllmConfig) -> "MoRIIOConfig":
|
def from_vllm_config(cls, vllm_config: VllmConfig) -> "MoRIIOConfig":
|
||||||
|
|
||||||
# Port Configuration:
|
# Port Configuration:
|
||||||
# local_ping_port -> Outgoing heartbeat to proxy(only rank0 need it)
|
# local_ping_port -> Outgoing heartbeat to proxy
|
||||||
# proxy_ping_port -> Remote proxy's heartbeat ingress port
|
# proxy_ping_port -> Remote proxy's heartbeat ingress port
|
||||||
# http_port -> Instance's HTTP service endpoint
|
# http_port -> Instance's HTTP service endpoint
|
||||||
# local_kv_port -> KV service port for Mori engine
|
# local_kv_port -> service port for mori engine
|
||||||
# notify_port -> For synchronizing stages between nodes
|
# notify_port -> For synchronizing stages between prefill and decode
|
||||||
|
# handshake_port -> For the initial handshake between mori engines
|
||||||
|
|
||||||
|
# TODO: merge notify_port and handshake_port to simplify port management and support non-contiguous ports
|
||||||
|
|
||||||
kv_transfer_config = vllm_config.kv_transfer_config
|
kv_transfer_config = vllm_config.kv_transfer_config
|
||||||
extra_config = kv_transfer_config.kv_connector_extra_config
|
extra_config = kv_transfer_config.kv_connector_extra_config
|
||||||
tp_rank = get_tensor_model_parallel_rank()
|
tp_rank = get_tensor_model_parallel_rank()
|
||||||
dp_rank = vllm_config.parallel_config.data_parallel_rank
|
dp_rank = vllm_config.parallel_config.data_parallel_rank
|
||||||
base_kv_port = int(kv_transfer_config.kv_port)
|
|
||||||
base_ping_port = int(extra_config["local_ping_port"])
|
|
||||||
base_notify_port = int(extra_config["notify_port"])
|
base_notify_port = int(extra_config["notify_port"])
|
||||||
dp_size = vllm_config.parallel_config.data_parallel_size
|
dp_size = vllm_config.parallel_config.data_parallel_size
|
||||||
tp_size = get_tensor_model_parallel_world_size()
|
tp_size = get_tensor_model_parallel_world_size()
|
||||||
@ -240,9 +246,9 @@ class MoRIIOConfig:
|
|||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
local_ip=get_ip(),
|
local_ip=get_ip(),
|
||||||
local_kv_port=base_kv_port + port_offset,
|
local_kv_port=get_open_port(),
|
||||||
proxy_ip=extra_config["proxy_ip"],
|
proxy_ip=extra_config["proxy_ip"],
|
||||||
local_ping_port=base_ping_port + port_offset,
|
local_ping_port=get_open_port(),
|
||||||
proxy_ping_port=int(extra_config["proxy_ping_port"]),
|
proxy_ping_port=int(extra_config["proxy_ping_port"]),
|
||||||
http_port=int(extra_config["http_port"]),
|
http_port=int(extra_config["http_port"]),
|
||||||
handshake_port=int(extra_config["handshake_port"]),
|
handshake_port=int(extra_config["handshake_port"]),
|
||||||
@ -545,7 +551,7 @@ class MoRIIOWrapper:
|
|||||||
self.notify_thread = None
|
self.notify_thread = None
|
||||||
self.sock = None
|
self.sock = None
|
||||||
self.sessions: list[IOEngine.Session] = []
|
self.sessions: list[IOEngine.Session] = []
|
||||||
self.paths = {}
|
self.paths: dict[str, zmq.Socket] = {}
|
||||||
|
|
||||||
def set_moriio_engine(self, moriio_engine):
|
def set_moriio_engine(self, moriio_engine):
|
||||||
assert moriio_engine is not None, (
|
assert moriio_engine is not None, (
|
||||||
@ -554,7 +560,17 @@ class MoRIIOWrapper:
|
|||||||
self.moriio_engine = moriio_engine
|
self.moriio_engine = moriio_engine
|
||||||
|
|
||||||
def set_backend_type(self, backend_type):
|
def set_backend_type(self, backend_type):
|
||||||
self.moriio_engine.create_backend(backend_type)
|
qp_per_transfer = int(os.getenv("VLLM_MORI_QP_PER_TRANSFER", "1"))
|
||||||
|
post_batch_size = int(os.getenv("VLLM_MORI_POST_BATCH_SIZE", "-1"))
|
||||||
|
num_worker_threads = int(os.getenv("VLLM_MORI_NUM_WORKERS", "1"))
|
||||||
|
poll_mode = PollCqMode.POLLING
|
||||||
|
rdma_cfg = RdmaBackendConfig(
|
||||||
|
qp_per_transfer,
|
||||||
|
post_batch_size,
|
||||||
|
num_worker_threads,
|
||||||
|
poll_mode,
|
||||||
|
)
|
||||||
|
self.moriio_engine.create_backend(backend_type, rdma_cfg)
|
||||||
|
|
||||||
def get_agent_metadata(self):
|
def get_agent_metadata(self):
|
||||||
engine_metadata = self.moriio_engine.get_engine_desc()
|
engine_metadata = self.moriio_engine.get_engine_desc()
|
||||||
@ -700,6 +716,8 @@ class MoRIIOWrapper:
|
|||||||
raise MoRIIOError(f"Unhandled message format: {msg_str}")
|
raise MoRIIOError(f"Unhandled message format: {msg_str}")
|
||||||
|
|
||||||
def _handle_structured_message(self, data: dict):
|
def _handle_structured_message(self, data: dict):
|
||||||
|
|
||||||
|
assert get_role()==ROLE.PRODUCER, "Only prefill can get block messages"
|
||||||
req_id = data["req_id"]
|
req_id = data["req_id"]
|
||||||
block_notify_list = data.get("block_notify_list", [])
|
block_notify_list = data.get("block_notify_list", [])
|
||||||
decode_dp_rank = data.get("decode_rank", 0)
|
decode_dp_rank = data.get("decode_rank", 0)
|
||||||
@ -882,16 +900,12 @@ class MoRIIOConnector(KVConnectorBase_V1):
|
|||||||
extra_config = kv_transfer_config.kv_connector_extra_config
|
extra_config = kv_transfer_config.kv_connector_extra_config
|
||||||
|
|
||||||
if "handshake_port" not in extra_config or not extra_config["handshake_port"]:
|
if "handshake_port" not in extra_config or not extra_config["handshake_port"]:
|
||||||
extra_config["handshake_port"] = "6301"
|
extra_config["handshake_port"] = MoRIIOConstants.DEFAULT_HANDSHAKE_PORT
|
||||||
|
|
||||||
if "notify_port" not in extra_config or not extra_config["notify_port"]:
|
if "notify_port" not in extra_config or not extra_config["notify_port"]:
|
||||||
extra_config["notify_port"] = "61005"
|
extra_config["notify_port"] = MoRIIOConstants.DEFAULT_NOTIFY_PORT
|
||||||
|
|
||||||
if "local_ping_port" not in extra_config or not extra_config["local_ping_port"]:
|
|
||||||
extra_config["local_ping_port"] = "7583"
|
|
||||||
|
|
||||||
if not kv_transfer_config.kv_port:
|
|
||||||
kv_transfer_config.kv_port = "7305"
|
|
||||||
|
|
||||||
def get_num_new_matched_tokens(
|
def get_num_new_matched_tokens(
|
||||||
self, request: "Request", num_computed_tokens: int
|
self, request: "Request", num_computed_tokens: int
|
||||||
@ -1180,7 +1194,7 @@ class MoRIIOConnectorScheduler:
|
|||||||
|
|
||||||
if new_block_ids is not None:
|
if new_block_ids is not None:
|
||||||
block_ids = new_block_ids[0]
|
block_ids = new_block_ids[0]
|
||||||
|
#TODO : hybrid attn, etc
|
||||||
req, existing_blocks = self._reqs_need_pending_save[req_id]
|
req, existing_blocks = self._reqs_need_pending_save[req_id]
|
||||||
updated_blocks = list(existing_blocks) + (block_ids)
|
updated_blocks = list(existing_blocks) + (block_ids)
|
||||||
self._reqs_need_pending_save[req_id] = (req, updated_blocks)
|
self._reqs_need_pending_save[req_id] = (req, updated_blocks)
|
||||||
@ -2261,6 +2275,7 @@ class MoRIIOConnectorWorker:
|
|||||||
sess_idx = list(self.layer_name_to_local_kv_cache_metadata.keys()).index(
|
sess_idx = list(self.layer_name_to_local_kv_cache_metadata.keys()).index(
|
||||||
layer_name
|
layer_name
|
||||||
)
|
)
|
||||||
|
# TODO: apply multi-session batch-read once moriio supports it
|
||||||
transfer_status = self.moriio_wrapper.read_remote_data(
|
transfer_status = self.moriio_wrapper.read_remote_data(
|
||||||
offs[2], offs[0], offs[1], sessions[sess_idx]
|
offs[2], offs[0], offs[1], sessions[sess_idx]
|
||||||
)
|
)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user