mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-03 18:47:04 +08:00
more
Signed-off-by: inkcherry <mingzhi.liu@amd.com>
This commit is contained in:
parent
72ccb5d77c
commit
4776e2ddcf
@ -37,7 +37,12 @@ from vllm.distributed.parallel_state import (
|
||||
)
|
||||
from vllm.forward_context import ForwardContext
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.network_utils import get_ip, make_zmq_path, make_zmq_socket, get_open_port
|
||||
from vllm.utils.network_utils import (
|
||||
get_ip,
|
||||
get_open_port,
|
||||
make_zmq_path,
|
||||
make_zmq_socket,
|
||||
)
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.request import RequestStatus
|
||||
|
||||
@ -70,7 +75,8 @@ class MoRIIOConstants:
|
||||
PING_INTERVAL = 5
|
||||
MAX_PING_RETRIES = 100000
|
||||
DEFAULT_HANDSHAKE_PORT = "6301"
|
||||
DEFAULT_NOTIFY_PORT="61005"
|
||||
DEFAULT_NOTIFY_PORT = "61005"
|
||||
|
||||
|
||||
try:
|
||||
from mori.io import (
|
||||
@ -80,7 +86,7 @@ try:
|
||||
IOEngineConfig,
|
||||
MemoryDesc,
|
||||
PollCqMode,
|
||||
RdmaBackendConfig
|
||||
RdmaBackendConfig,
|
||||
)
|
||||
|
||||
logger.info("MoRIIO is available")
|
||||
@ -200,7 +206,6 @@ def get_moriio_mode() -> MoRIIOMode:
|
||||
return MoRIIOMode.READ
|
||||
else:
|
||||
return MoRIIOMode.WRITE
|
||||
|
||||
|
||||
|
||||
def get_port_offset(dp_rank: int, tp_rank: int, tp_size: int = 1) -> int:
|
||||
@ -224,7 +229,6 @@ class MoRIIOConfig:
|
||||
|
||||
@classmethod
|
||||
def from_vllm_config(cls, vllm_config: VllmConfig) -> "MoRIIOConfig":
|
||||
|
||||
# Port Configuration:
|
||||
# local_ping_port -> Outgoing heartbeat to proxy
|
||||
# proxy_ping_port -> Remote proxy's heartbeat ingress port
|
||||
@ -233,8 +237,8 @@ class MoRIIOConfig:
|
||||
# notify_port -> For synchronizing stages between prefill and decode
|
||||
# handshake_port -> For initial handshake between mori engine
|
||||
|
||||
#TODO : merge notify_port and handshake_port to simplify port management, supports non-contiguous ports
|
||||
|
||||
# TODO: merge notify_port and handshake_port to simplify port management, and support non-contiguous ports
|
||||
|
||||
kv_transfer_config = vllm_config.kv_transfer_config
|
||||
extra_config = kv_transfer_config.kv_connector_extra_config
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
@ -716,8 +720,7 @@ class MoRIIOWrapper:
|
||||
raise MoRIIOError(f"Unhandled message format: {msg_str}")
|
||||
|
||||
def _handle_structured_message(self, data: dict):
|
||||
|
||||
assert get_role()==ROLE.PRODUCER, "Only prefill can get block messages"
|
||||
assert get_role() == ROLE.PRODUCER, "Only prefill can get block messages"
|
||||
req_id = data["req_id"]
|
||||
block_notify_list = data.get("block_notify_list", [])
|
||||
decode_dp_rank = data.get("decode_rank", 0)
|
||||
@ -887,8 +890,9 @@ class MoRIIOConnector(KVConnectorBase_V1):
|
||||
self.connector_scheduler = None
|
||||
self.connector_worker = MoRIIOConnectorWorker(vllm_config, self.engine_id)
|
||||
logger.info(
|
||||
"Initialized MoRIIO Connector,engine_id:%s,role: %s",
|
||||
self.engine_id, role.value
|
||||
"Initialized MoRIIO Connector,engine_id:%s,role: %s",
|
||||
self.engine_id,
|
||||
role.value,
|
||||
)
|
||||
|
||||
############################################################
|
||||
@ -905,8 +909,6 @@ class MoRIIOConnector(KVConnectorBase_V1):
|
||||
if "notify_port" not in extra_config or not extra_config["notify_port"]:
|
||||
extra_config["notify_port"] = MoRIIOConstants.DEFAULT_NOTIFY_PORT
|
||||
|
||||
|
||||
|
||||
def get_num_new_matched_tokens(
|
||||
self, request: "Request", num_computed_tokens: int
|
||||
) -> tuple[int, bool]:
|
||||
@ -1194,7 +1196,7 @@ class MoRIIOConnectorScheduler:
|
||||
|
||||
if new_block_ids is not None:
|
||||
block_ids = new_block_ids[0]
|
||||
#TODO : hybrid attn, etc
|
||||
# TODO: only new_block_ids[0] is used here; handle hybrid attention, etc.
|
||||
req, existing_blocks = self._reqs_need_pending_save[req_id]
|
||||
updated_blocks = list(existing_blocks) + (block_ids)
|
||||
self._reqs_need_pending_save[req_id] = (req, updated_blocks)
|
||||
@ -1356,26 +1358,17 @@ class MoRIIOConnectorWorker:
|
||||
self._ping_thread = None
|
||||
self._writer = MoRIIOWriter(self)
|
||||
|
||||
engine_suffix = (
|
||||
f"{self.moriio_config.local_ip}:{self.moriio_config.handshake_port}"
|
||||
f":tp {self.tp_rank}:dp {self.dp_rank}"
|
||||
role = "producer" if self.is_producer else "consumer"
|
||||
engine_suffix = f"{self.moriio_config.local_ip}:{self.moriio_config.handshake_port}:tp{self.tp_rank}:dp{self.dp_rank}"
|
||||
self.moriio_engine = IOEngine(
|
||||
f"{role}:{engine_suffix}",
|
||||
IOEngineConfig(
|
||||
self.moriio_config.local_ip, self.moriio_config.local_kv_port
|
||||
),
|
||||
)
|
||||
if not self.is_producer:
|
||||
self.moriio_engine = IOEngine(
|
||||
"consumer:" + engine_suffix,
|
||||
IOEngineConfig(
|
||||
self.moriio_config.local_ip, self.moriio_config.local_kv_port
|
||||
),
|
||||
)
|
||||
else:
|
||||
self.moriio_engine = IOEngine(
|
||||
"producer:" + engine_suffix,
|
||||
IOEngineConfig(
|
||||
self.moriio_config.local_ip, self.moriio_config.local_kv_port
|
||||
),
|
||||
)
|
||||
logger.debug(
|
||||
"build MORI IOEngine %s:%s",
|
||||
"build MORI IOEngine %s (ip=%s port=%s)",
|
||||
f"{role}:{engine_suffix}",
|
||||
self.moriio_config.local_ip,
|
||||
self.moriio_config.local_kv_port,
|
||||
)
|
||||
@ -1604,18 +1597,6 @@ class MoRIIOConnectorWorker:
|
||||
if should_break:
|
||||
break
|
||||
|
||||
# def handle_proxy_request(self):
|
||||
# if self.is_producer:
|
||||
# raise NotImplementedError(
|
||||
# "prefill instance doesn't need to send kv cache in pull mode"
|
||||
# )
|
||||
# while True:
|
||||
# socks = dict(self.poller.poll())
|
||||
# logger.debug("handle_proxy_request: socks = %s", socks)
|
||||
|
||||
# if self.metadata_socket not in socks:
|
||||
# continue
|
||||
|
||||
def close(self):
|
||||
if hasattr(self, "_handshake_initiation_executor"):
|
||||
self._handshake_initiation_executor.shutdown(wait=False)
|
||||
@ -2264,7 +2245,7 @@ class MoRIIOConnectorWorker:
|
||||
sess_idx = list(self.layer_name_to_local_kv_cache_metadata.keys()).index(
|
||||
layer_name
|
||||
)
|
||||
#TODO : apply multi-session batch-read when moriio support it
|
||||
# TODO: apply multi-session batch-read once moriio supports it
|
||||
transfer_status = self.moriio_wrapper.read_remote_data(
|
||||
offs[2], offs[0], offs[1], sessions[sess_idx]
|
||||
)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user