mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-03 09:17:03 +08:00
refine code
Signed-off-by: inkcherry <mingzhi.liu@amd.com>
This commit is contained in:
parent
68a2333339
commit
70ea1b2460
@ -7,6 +7,13 @@ import pickle
|
||||
import queue
|
||||
import threading
|
||||
import time
|
||||
import msgpack
|
||||
import msgspec
|
||||
import numpy as np
|
||||
import torch
|
||||
import zmq
|
||||
import logging
|
||||
|
||||
from collections import defaultdict
|
||||
from collections.abc import Iterator
|
||||
from concurrent.futures import Future, ThreadPoolExecutor
|
||||
@ -14,12 +21,6 @@ from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
from weakref import ref as weakref_ref
|
||||
|
||||
import msgpack
|
||||
import msgspec
|
||||
import numpy as np
|
||||
import torch
|
||||
import zmq
|
||||
|
||||
from vllm import envs
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.attention.selector import get_attn_backend
|
||||
@ -50,7 +51,6 @@ if TYPE_CHECKING:
|
||||
from dataclasses import field
|
||||
from enum import Enum
|
||||
from queue import Empty, Queue
|
||||
import logging
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -70,7 +70,9 @@ class MoRIIOConstants:
|
||||
|
||||
# Default GPU count per node for standard configurations
|
||||
RANK_PER_NODE = 8
|
||||
|
||||
|
||||
PING_INTERVAL = 5
|
||||
MAX_PING_RETRIES = 100000
|
||||
|
||||
try:
|
||||
import mori
|
||||
@ -80,7 +82,6 @@ try:
|
||||
IOEngine,
|
||||
IOEngineConfig,
|
||||
MemoryDesc,
|
||||
StatusCode,
|
||||
)
|
||||
|
||||
logger.info("MoRIIO is available")
|
||||
@ -404,10 +405,12 @@ class MoRIIOWriter:
|
||||
return
|
||||
|
||||
# Wait for CUDA event
|
||||
#The attention computation of the current layer cannot overlap with the kv transfer task, otherwise it will cause precision issues.
|
||||
#This event is used to synchronize the kv transfer and computation tasks.
|
||||
task.event.synchronize()
|
||||
|
||||
# Update engine ID with DP rank
|
||||
task.dst_engine_id = f"{task.dst_engine_id}_dp{request_info.decode_dp_rank}"
|
||||
task.dst_engine_id= self.worker.get_engine_name_with_dp(task.dst_engine_id,request_info.decode_dp_rank)
|
||||
|
||||
# Get or create sessions
|
||||
sessions = self.worker._get_built_session(task.dst_engine_id)
|
||||
@ -665,7 +668,6 @@ class MoRIIOWrapper:
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing message: {e}")
|
||||
raise HandshakeError(f"Error processing message: {e}") from e
|
||||
continue
|
||||
|
||||
self.notify_thread = threading.Thread(
|
||||
target=_async_wait, daemon=True, name="moriio-notify-listener"
|
||||
@ -685,7 +687,6 @@ class MoRIIOWrapper:
|
||||
data = msgpack.loads(msg)
|
||||
if isinstance(data, dict) and "req_id" in data:
|
||||
self._handle_structured_message(data)
|
||||
handled = True
|
||||
|
||||
return
|
||||
except (msgpack.exceptions.ExtraData, msgpack.exceptions.UnpackException):
|
||||
@ -841,6 +842,7 @@ class MoRIIOConnector(KVConnectorBase_V1):
|
||||
role: KVConnectorRole,
|
||||
kv_cache_config: Optional["KVCacheConfig"] = None,
|
||||
):
|
||||
super().__init__(vllm_config, role)
|
||||
assert vllm_config.kv_transfer_config is not None
|
||||
# assert vllm_config.kv_transfer_config.engine_id is not None
|
||||
self.engine_id = (
|
||||
@ -1149,10 +1151,8 @@ class MoRIIOConnectorScheduler:
|
||||
)
|
||||
self._reqs_need_pending_save[req_id] = (req, updated_blocks)
|
||||
if (
|
||||
len(
|
||||
self._reqs_need_pending_save[req_id][1]
|
||||
len(self._reqs_need_pending_save[req_id][1])
|
||||
* self.block_size
|
||||
)
|
||||
>= req.num_prompt_tokens
|
||||
):
|
||||
meta.add_new_req(
|
||||
@ -1363,13 +1363,9 @@ class MoRIIOConnectorWorker:
|
||||
)
|
||||
# Agent.
|
||||
self.moriio_wrapper = MoRIIOWrapper(tp_rank=self.tp_rank, dp_rank=self.dp_rank)
|
||||
|
||||
self.moriio_wrapper.set_moriio_engine(self.moriio_engine)
|
||||
|
||||
self.moriio_wrapper.set_backend_type(BackendType.RDMA)
|
||||
|
||||
self.moriio_wrapper.notify_port = self.moriio_config.notify_port
|
||||
|
||||
self.local_kv_cache_metadata = []
|
||||
self.local_kv_cache_size = []
|
||||
self.layer_name_to_local_kv_cache_metadata: dict[str, list[Any]] = dict()
|
||||
@ -1450,7 +1446,6 @@ class MoRIIOConnectorWorker:
|
||||
self.use_mla = self.model_config.use_mla
|
||||
self.built_session = False
|
||||
self.built_write_session: defaultdict[str, list] = defaultdict(list)
|
||||
self._write_session_lock = threading.Lock()
|
||||
self.debug_cache = []
|
||||
backend = get_attn_backend(
|
||||
self.model_config.get_head_size(),
|
||||
@ -1467,15 +1462,6 @@ class MoRIIOConnectorWorker:
|
||||
# self._use_flashinfer = attn_backend == _Backend.FLASHINFER
|
||||
logger.debug("Detected attention backend %s", self.backend_name)
|
||||
|
||||
self._tp_size: dict[EngineId, int] = {self.engine_id: self.world_size}
|
||||
|
||||
####write worker###
|
||||
# self._write_task_q: Queue[WriteTask] = Queue()
|
||||
# self._write_worker_started = False
|
||||
# self._write_worker_lock = threading.Lock()
|
||||
# self._deferred_tasks: list[WriteTask] = []
|
||||
# ####write worker###
|
||||
|
||||
def schedule_write_blocks(
|
||||
self,
|
||||
request_id: str,
|
||||
@ -1539,8 +1525,6 @@ class MoRIIOConnectorWorker:
|
||||
return self.built_write_session[remote_engine_id]
|
||||
|
||||
def _ping(self, zmq_context):
|
||||
PING_INTERVAL = 5
|
||||
MAX_RETRIES = 100000
|
||||
|
||||
http_request_address = f"http://{self.request_address}/v1/completions"
|
||||
role = "P" if self.is_producer else "D"
|
||||
@ -1585,13 +1569,13 @@ class MoRIIOConnectorWorker:
|
||||
retry_count += 1
|
||||
|
||||
finally:
|
||||
if retry_count >= MAX_RETRIES:
|
||||
if retry_count >= MoRIIOConstants.MAX_RETRIES:
|
||||
logger.error(
|
||||
f"Max retries ({MAX_RETRIES}) exceeded. Stopping ping loop."
|
||||
f"Max retries ({MoRIIOConstants.MAX_RETRIES}) exceeded. Stopping ping loop."
|
||||
)
|
||||
break
|
||||
|
||||
time.sleep(PING_INTERVAL)
|
||||
time.sleep(MoRIIOConstants.PING_INTERVAL)
|
||||
index += 1
|
||||
|
||||
def handle_proxy_request(self):
|
||||
@ -1603,7 +1587,6 @@ class MoRIIOConnectorWorker:
|
||||
socks = dict(self.poller.poll())
|
||||
logger.debug(f"handle_proxy_request: {socks = }")
|
||||
|
||||
# TODO: inkcherry , check here?
|
||||
if self.metadata_socket not in socks:
|
||||
continue
|
||||
|
||||
@ -1665,8 +1648,6 @@ class MoRIIOConnectorWorker:
|
||||
"MoRIIO handshake listener received done recv for req %s",
|
||||
req_id.decode(),
|
||||
)
|
||||
else:
|
||||
pass
|
||||
|
||||
def _moriio_handshake(
|
||||
self,
|
||||
@ -1709,9 +1690,7 @@ class MoRIIOConnectorWorker:
|
||||
)
|
||||
|
||||
self.moriio_wrapper.remote_engine_ip = host
|
||||
remote_agent_name = self.moriio_wrapper.register_remote_engine(
|
||||
metadata.agent_metadata
|
||||
)
|
||||
|
||||
remote_agent_name = EngineDesc.unpack(metadata.agent_metadata).key
|
||||
|
||||
logger.info(
|
||||
@ -1873,8 +1852,6 @@ class MoRIIOConnectorWorker:
|
||||
if layer_name not in self.layer_name_to_local_kv_cache_metadata:
|
||||
self.layer_name_to_local_kv_cache_metadata[layer_name] = []
|
||||
|
||||
# for cache in cache_list:
|
||||
# moriio_mem_metadata = self.moriio_wrapper.register_local_tensor(cache)
|
||||
moriio_mem_metadata = self.moriio_wrapper.register_local_tensor(kv_cache)
|
||||
self.layer_name_to_local_kv_cache_metadata[layer_name].append(
|
||||
moriio_mem_metadata
|
||||
@ -1942,7 +1919,7 @@ class MoRIIOConnectorWorker:
|
||||
to track which workers are done.
|
||||
"""
|
||||
|
||||
done_sending, done_recving = set(), set()
|
||||
done_sending = set()
|
||||
|
||||
if self.is_producer:
|
||||
done_sending = self.moriio_wrapper.pop_finished_req_ids()
|
||||
@ -1991,7 +1968,6 @@ class MoRIIOConnectorWorker:
|
||||
remote_engine_id = None
|
||||
|
||||
for req_id, meta in metadata.reqs_to_save.items():
|
||||
remote_engine_id = meta.remote_engine_id
|
||||
# we only need to check if dp0 in rank
|
||||
remote_engine_id = (
|
||||
str(meta.remote_host) + ":" + str(meta.remote_handshake_port)
|
||||
@ -2030,13 +2006,15 @@ class MoRIIOConnectorWorker:
|
||||
break
|
||||
else:
|
||||
break
|
||||
|
||||
|
||||
|
||||
def get_engine_name_with_dp(self, engine_name,dp_rank):
|
||||
return f"{engine_name}_dp{dp_rank}"
|
||||
def start_load_kv(self, metadata: MoRIIOConnectorMetadata):
|
||||
"""
|
||||
Start loading by triggering non-blocking moriio_xfer.
|
||||
We check for these trnxs to complete in each step().
|
||||
"""
|
||||
# print("start load kv")
|
||||
if self.is_producer:
|
||||
self.moriio_wrapper.async_wait_reqid()
|
||||
return
|
||||
@ -2051,7 +2029,7 @@ class MoRIIOConnectorWorker:
|
||||
str(meta.remote_host) + ":" + str(meta.remote_handshake_port)
|
||||
)
|
||||
meta.remote_engine_id = remote_engine_id
|
||||
dp0_remote_engine_id = f"{remote_engine_id}_dp0"
|
||||
dp0_remote_engine_id = self.get_engine_name_with_dp(remote_engine_id,0)
|
||||
if dp0_remote_engine_id not in self._remote_agents:
|
||||
# Initiate handshake with remote engine to exchange metadata.
|
||||
with self._handshake_lock:
|
||||
@ -2216,7 +2194,6 @@ class MoRIIOConnectorWorker:
|
||||
blknum, blksize, hs = self.kv_cache_shape
|
||||
hn = 1
|
||||
block_stride = stride[0]
|
||||
ktov_stride = None
|
||||
else:
|
||||
_, blknum, blksize, hn, hs = self.kv_cache_shape
|
||||
ktov_stride = stride[0]
|
||||
@ -2260,21 +2237,20 @@ class MoRIIOConnectorWorker:
|
||||
return
|
||||
# we only test TP<->TP in read mode
|
||||
# assert self.dp_rank>0, "only test TP<->TP in read mode"
|
||||
dst_engine_id += "_dp0"
|
||||
sessions = self._get_built_session(dst_engine_id)
|
||||
dp0_engine_id=self.get_engine_name_with_dp(dst_engine_id,0)
|
||||
sessions = self._get_built_session(dp0_engine_id)
|
||||
|
||||
first_layer = list(self.layer_name_to_local_kv_cache_metadata.keys())[0]
|
||||
offs = self._compute_block_transfer_offsets(
|
||||
first_layer, local_block_ids, remote_block_ids
|
||||
)
|
||||
a, b, c = offs[0], offs[1], offs[2]
|
||||
|
||||
for layer_name in self.layer_name_to_local_kv_cache_metadata.keys():
|
||||
sess_idx = list(self.layer_name_to_local_kv_cache_metadata.keys()).index(
|
||||
layer_name
|
||||
)
|
||||
transfer_status = self.moriio_wrapper.read_remote_data(
|
||||
c, a, b, sessions[sess_idx]
|
||||
offs[0], offs[1], offs[2], sessions[sess_idx]
|
||||
)
|
||||
|
||||
self._recving_transfers[request_id].append(transfer_status)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user