mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-03 09:27:03 +08:00
format
Signed-off-by: inkcherry <mingzhi.liu@amd.com>
This commit is contained in:
parent
4f592ae696
commit
b60ee86585
@ -1,19 +1,13 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import contextlib
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import pickle
|
||||
import queue
|
||||
import threading
|
||||
import time
|
||||
import msgpack
|
||||
import msgspec
|
||||
import numpy as np
|
||||
import torch
|
||||
import zmq
|
||||
import logging
|
||||
|
||||
from collections import defaultdict
|
||||
from collections.abc import Iterator
|
||||
from concurrent.futures import Future, ThreadPoolExecutor
|
||||
@ -21,6 +15,12 @@ from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
from weakref import ref as weakref_ref
|
||||
|
||||
import msgpack
|
||||
import msgspec
|
||||
import numpy as np
|
||||
import torch
|
||||
import zmq
|
||||
|
||||
from vllm import envs
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.attention.selector import get_attn_backend
|
||||
@ -68,10 +68,10 @@ class MoRIIOConstants:
|
||||
OVER = b"OVER"
|
||||
COMPLETION_PREFIX = "cmpl"
|
||||
|
||||
|
||||
PING_INTERVAL = 5
|
||||
MAX_PING_RETRIES = 100000
|
||||
|
||||
|
||||
try:
|
||||
import mori
|
||||
from mori.io import (
|
||||
@ -202,7 +202,7 @@ def get_moriio_mode() -> MoRIIOMode:
|
||||
|
||||
|
||||
def get_port_offset(dp_rank: int, tp_rank: int, tp_size: int = 1) -> int:
|
||||
return ((dp_rank) * tp_size + tp_rank)
|
||||
return (dp_rank) * tp_size + tp_rank
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -316,7 +316,6 @@ class MoRIIOWriter:
|
||||
|
||||
def _write_worker_loop(self) -> None:
|
||||
"""Main loop for the write worker thread."""
|
||||
logger.info("Write worker loop started")
|
||||
|
||||
while True:
|
||||
# Process deferred tasks first
|
||||
@ -408,7 +407,9 @@ class MoRIIOWriter:
|
||||
task.event.synchronize()
|
||||
|
||||
# Update engine ID with DP rank
|
||||
task.dst_engine_id= self.worker.get_engine_name_with_dp(task.dst_engine_id,request_info.decode_dp_rank)
|
||||
task.dst_engine_id = self.worker.get_engine_name_with_dp(
|
||||
task.dst_engine_id, request_info.decode_dp_rank
|
||||
)
|
||||
|
||||
# Get or create sessions
|
||||
sessions = self.worker._get_built_session(task.dst_engine_id)
|
||||
@ -544,7 +545,7 @@ class MoRIIOWrapper:
|
||||
self.done_write_cache_req_ids = []
|
||||
self.notify_thread = None
|
||||
self.sock = None
|
||||
self.sessions: list["IOEngine.Session"] = []
|
||||
self.sessions: list[IOEngine.Session] = []
|
||||
self.paths = {}
|
||||
|
||||
def set_moriio_engine(self, moriio_engine):
|
||||
@ -968,7 +969,7 @@ class MoRIIOConnectorScheduler:
|
||||
"handshake_port"
|
||||
]
|
||||
)
|
||||
logger.info(f"==========> Initializing MoRIIO Scheduler {engine_id = }")
|
||||
logger.info(f"Initializing MoRIIO Scheduler {engine_id = }")
|
||||
|
||||
self.side_notify_port = (
|
||||
self.vllm_config.kv_transfer_config.kv_connector_extra_config["notify_port"]
|
||||
@ -1149,7 +1150,7 @@ class MoRIIOConnectorScheduler:
|
||||
self._reqs_need_pending_save[req_id] = (req, updated_blocks)
|
||||
if (
|
||||
len(self._reqs_need_pending_save[req_id][1])
|
||||
* self.block_size
|
||||
* self.block_size
|
||||
>= req.num_prompt_tokens
|
||||
):
|
||||
meta.add_new_req(
|
||||
@ -1171,7 +1172,7 @@ class MoRIIOConnectorScheduler:
|
||||
|
||||
for req_id, (req, block_ids) in self._reqs_need_save.items():
|
||||
assert req.kv_transfer_params is not None
|
||||
if req.num_prompt_tokens > len(block_ids)*self.block_size:
|
||||
if req.num_prompt_tokens > len(block_ids) * self.block_size:
|
||||
# not last chunk prefill
|
||||
self._reqs_need_pending_save[req_id] = (req, block_ids)
|
||||
continue
|
||||
@ -1265,7 +1266,7 @@ class MoRIIOConnectorWorker:
|
||||
self.mode = get_moriio_mode()
|
||||
|
||||
logger.info("Initializing MoRIIO worker %s", engine_id)
|
||||
|
||||
|
||||
logging.getLogger("aiter").disabled = True
|
||||
|
||||
# Config.
|
||||
@ -1283,11 +1284,6 @@ class MoRIIOConnectorWorker:
|
||||
self.tp_rank = self.moriio_config.tp_rank
|
||||
self.dp_rank = self.moriio_config.dp_rank
|
||||
|
||||
logger.info(
|
||||
f"MoRIIO Worker init {self.tp_rank = },{self.dp_rank= }"
|
||||
f",{self.is_producer= }"
|
||||
)
|
||||
|
||||
self.local_ip = self.moriio_config.local_ip
|
||||
self.local_kv_port = self.moriio_config.local_kv_port
|
||||
self.proxy_ip = self.moriio_config.proxy_ip
|
||||
@ -1340,8 +1336,8 @@ class MoRIIOConnectorWorker:
|
||||
),
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"build IOEngine %s:%s",
|
||||
logger.debug(
|
||||
"build MORI IOEngine %s:%s",
|
||||
self.moriio_config.local_ip,
|
||||
self.moriio_config.local_kv_port,
|
||||
)
|
||||
@ -1390,8 +1386,6 @@ class MoRIIOConnectorWorker:
|
||||
self.moriio_config.handshake_port
|
||||
+ get_port_offset(self.dp_rank, self.tp_rank)
|
||||
)
|
||||
logger.info(f"MoRIIO Worker init {self.tp_rank = },{self.dp_rank= }")
|
||||
logger.info(f"MoRIIO side channel_port port: {self.side_channel_port}")
|
||||
self.engine_id: EngineId = engine_id
|
||||
|
||||
self.world_size = get_tensor_model_parallel_world_size()
|
||||
@ -1522,7 +1516,6 @@ class MoRIIOConnectorWorker:
|
||||
return self.built_write_session[remote_engine_id]
|
||||
|
||||
def _ping(self, zmq_context):
|
||||
|
||||
http_request_address = f"http://{self.request_address}/v1/completions"
|
||||
role = "P" if self.is_producer else "D"
|
||||
|
||||
@ -1586,15 +1579,18 @@ class MoRIIOConnectorWorker:
|
||||
|
||||
if self.metadata_socket not in socks:
|
||||
continue
|
||||
|
||||
|
||||
def close(self):
|
||||
if hasattr(self, '_handshake_initiation_executor'):
|
||||
if hasattr(self, "_handshake_initiation_executor"):
|
||||
self._handshake_initiation_executor.shutdown(wait=False)
|
||||
|
||||
if hasattr(self, '_moriio_handshake_listener_t') and self._moriio_handshake_listener_t:
|
||||
|
||||
if (
|
||||
hasattr(self, "_moriio_handshake_listener_t")
|
||||
and self._moriio_handshake_listener_t
|
||||
):
|
||||
self._moriio_handshake_listener_t.join(timeout=0)
|
||||
|
||||
if hasattr(self, 'zmq_context') and self.zmq_context:
|
||||
|
||||
if hasattr(self, "zmq_context") and self.zmq_context:
|
||||
self.zmq_context.destroy(linger=0)
|
||||
self.zmq_context = None
|
||||
|
||||
@ -1621,12 +1617,9 @@ class MoRIIOConnectorWorker:
|
||||
|
||||
# Listen for new requests for metadata.
|
||||
host = "*"
|
||||
logger.info(
|
||||
f"======> mori handshake starting listening on baseport: {base_port}"
|
||||
)
|
||||
|
||||
path = make_zmq_path("tcp", host, base_port)
|
||||
logger.info(f"======> mori handshake starting listening on path: {path}")
|
||||
logger.debug(f" mori handshake starting listening on path: {path}")
|
||||
|
||||
with zmq_ctx(zmq.ROUTER, path) as sock:
|
||||
ready_event.set()
|
||||
@ -1636,20 +1629,20 @@ class MoRIIOConnectorWorker:
|
||||
msg != MoRIIOConstants.GET_META_MSG
|
||||
and msg != MoRIIOConstants.POP_DONE_RECV
|
||||
):
|
||||
logger.error("Connection listener got unexpected message %s", msg)
|
||||
logger.error("Connection listener got unexpected message")
|
||||
raise HandshakeError("handshake failed, unexpected msg type")
|
||||
elif msg == MoRIIOConstants.GET_META_MSG:
|
||||
sock.send_multipart(
|
||||
(identity, b"", encoded_data)
|
||||
) # send local mori io engine meta data
|
||||
logger.info("MoRIIO handshake listener sent metadata to %s")
|
||||
logger.debug("MoRIIO handshake listener sent metadata")
|
||||
# now we send tensor meta data for each block
|
||||
buf = pickle.dumps(layer_name_to_local_kv_cache_metadata)
|
||||
sock.send_multipart((identity, b"", buf))
|
||||
elif msg == MoRIIOConstants.POP_DONE_RECV:
|
||||
_, req_id = sock.recv_multipart()
|
||||
logger.info(
|
||||
"MoRIIO handshake listener received done recv for req %s",
|
||||
logger.debug(
|
||||
"MoRIIO handshake listener received done recv for req",
|
||||
req_id.decode(),
|
||||
)
|
||||
|
||||
@ -1671,10 +1664,7 @@ class MoRIIOConnectorWorker:
|
||||
|
||||
port_offset = get_port_offset(remote_dp_rank, self.tp_rank)
|
||||
path = make_zmq_path("tcp", host, port + port_offset)
|
||||
logger.info(
|
||||
"handshake Querying metadata on path: %s at remote rank %s",
|
||||
path,
|
||||
)
|
||||
logger.debug(f"handshake Querying metadata on path:{path}")
|
||||
|
||||
# Send query for the request.
|
||||
with zmq_ctx(zmq.DEALER, path) as sock:
|
||||
@ -1694,7 +1684,7 @@ class MoRIIOConnectorWorker:
|
||||
)
|
||||
|
||||
self.moriio_wrapper.remote_engine_ip = host
|
||||
remote_agent_name=self.moriio_wrapper.register_remote_engine(
|
||||
remote_agent_name = self.moriio_wrapper.register_remote_engine(
|
||||
metadata.agent_metadata
|
||||
)
|
||||
|
||||
@ -1748,7 +1738,7 @@ class MoRIIOConnectorWorker:
|
||||
def request_ready(_f: Future[Any], entry=(req_id, meta)):
|
||||
logger.info("MoRIIO handshake done for request %s", req_id)
|
||||
self._ready_requests.put(entry)
|
||||
self.load_ready_flag [remote_engine_id] = True
|
||||
self.load_ready_flag[remote_engine_id] = True
|
||||
self.write_ready_flags[remote_engine_id] = True
|
||||
|
||||
fut_list = []
|
||||
@ -1999,10 +1989,10 @@ class MoRIIOConnectorWorker:
|
||||
break
|
||||
else:
|
||||
break
|
||||
|
||||
|
||||
def get_engine_name_with_dp(self, engine_name,dp_rank):
|
||||
|
||||
def get_engine_name_with_dp(self, engine_name, dp_rank):
|
||||
return f"{engine_name}_dp{dp_rank}"
|
||||
|
||||
def start_load_kv(self, metadata: MoRIIOConnectorMetadata):
|
||||
"""
|
||||
Start loading by triggering non-blocking moriio_xfer.
|
||||
@ -2022,7 +2012,7 @@ class MoRIIOConnectorWorker:
|
||||
str(meta.remote_host) + ":" + str(meta.remote_handshake_port)
|
||||
)
|
||||
meta.remote_engine_id = remote_engine_id
|
||||
dp0_remote_engine_id = self.get_engine_name_with_dp(remote_engine_id,0)
|
||||
dp0_remote_engine_id = self.get_engine_name_with_dp(remote_engine_id, 0)
|
||||
if dp0_remote_engine_id not in self._remote_agents:
|
||||
# Initiate handshake with remote engine to exchange metadata.
|
||||
with self._handshake_lock:
|
||||
@ -2038,20 +2028,21 @@ class MoRIIOConnectorWorker:
|
||||
self._read_blocks_for_req(req_id, meta)
|
||||
# Start transfers for requests whose handshakes have now finished.
|
||||
|
||||
while True:
|
||||
while True:
|
||||
if (
|
||||
self._ready_requests.empty()
|
||||
and remote_engine_id not in self.load_ready_flag
|
||||
and wait_handshake_readd_req
|
||||
):
|
||||
continue
|
||||
elif not self._ready_requests.empty() and remote_engine_id in self.load_ready_flag:
|
||||
elif (
|
||||
not self._ready_requests.empty()
|
||||
and remote_engine_id in self.load_ready_flag
|
||||
):
|
||||
self._read_blocks_for_req(*self._ready_requests.get_nowait())
|
||||
break
|
||||
else:
|
||||
break
|
||||
|
||||
|
||||
|
||||
self._reqs_to_send.update(metadata.reqs_to_send)
|
||||
|
||||
@ -2089,11 +2080,6 @@ class MoRIIOConnectorWorker:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _is_first_layer(self, layer_name):
|
||||
if layer_name == list(self.kv_caches.keys())[0]:
|
||||
return True
|
||||
return False
|
||||
|
||||
def merge_contiguous_blocks(
|
||||
self,
|
||||
offsets_local: list[int],
|
||||
@ -2230,9 +2216,8 @@ class MoRIIOConnectorWorker:
|
||||
) -> None:
|
||||
if self.mode == MoRIIOMode.WRITE:
|
||||
return
|
||||
|
||||
|
||||
dp0_engine_id=self.get_engine_name_with_dp(dst_engine_id,0)
|
||||
|
||||
dp0_engine_id = self.get_engine_name_with_dp(dst_engine_id, 0)
|
||||
sessions = self._get_built_session(dp0_engine_id)
|
||||
|
||||
first_layer = list(self.layer_name_to_local_kv_cache_metadata.keys())[0]
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user