mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-03 09:17:03 +08:00
break long line
Signed-off-by: inkcherry <mingzhi.liu@amd.com>
This commit is contained in:
parent
f75eecde0a
commit
e0885e52d9
@ -73,7 +73,6 @@ class MoRIIOConstants:
|
||||
|
||||
|
||||
try:
|
||||
import mori
|
||||
from mori.io import (
|
||||
BackendType,
|
||||
EngineDesc,
|
||||
@ -260,7 +259,8 @@ class MoRIIOConfig:
|
||||
|
||||
|
||||
class MoRIIOWriter:
|
||||
"""Handles write operations for KV cache transfers. Implements distributed KV cache transfer using the MoRIIO library
|
||||
"""Handles write operations for KV cache transfers.
|
||||
Implements distributed KV cache transfer using the MoRIIO library
|
||||
for RDMA-based communication between prefill and decode instances."""
|
||||
|
||||
def __init__(self, worker: "MoRIIOConnectorWorker"):
|
||||
@ -400,7 +400,9 @@ class MoRIIOWriter:
|
||||
return
|
||||
|
||||
# Wait for CUDA event
|
||||
# The attention computation of the current layer cannot overlap with the kv transfer task, otherwise it will cause precision issues.
|
||||
# The attention computation of the current layer cannot
|
||||
# overlap with the kv transfer task,
|
||||
# otherwise it will cause precision issues.
|
||||
# This event is used to synchronize the kv transfer and computation tasks.
|
||||
task.event.synchronize()
|
||||
|
||||
@ -497,8 +499,8 @@ class MoRIIOWriter:
|
||||
remote_port = task.remote_notify_port + get_port_offset(
|
||||
request_info.decode_dp_rank, self.worker.tp_rank
|
||||
)
|
||||
# TODO:
|
||||
# Consider using RDMA immediate data in decode side to eliminate the need for this notification.
|
||||
# Consider using RDMA immediate data in decode side
|
||||
# to eliminate the need for this notification.
|
||||
# Consider including the first gen token from prefill in the notification
|
||||
|
||||
# Send completion notification
|
||||
@ -799,7 +801,11 @@ class MoRIIOConnectorMetadata(KVConnectorMetadata):
|
||||
def __repr__(self):
|
||||
return_str = ""
|
||||
for req_id, req_meta in self.reqs_to_recv.items():
|
||||
return_str += f"{req_id = },{req_meta.local_block_ids = },{req_meta.remote_block_ids = },{req_meta.remote_host = },{req_meta.remote_port = },{req_meta.remote_engine_id = },{req_meta.tp_size = }"
|
||||
return_str += (
|
||||
f"{req_id = },{req_meta.local_block_ids = },"
|
||||
f"{req_meta.remote_host = },{req_meta.remote_port = }"
|
||||
f"{req_meta.remote_engine_id = },{req_meta.tp_size = }"
|
||||
)
|
||||
return_str = f"MoRIIOConnectorMetadata:reqs_to_recv:{return_str},"
|
||||
|
||||
for req_id, expiry in self.reqs_to_send.items():
|
||||
@ -862,7 +868,7 @@ class MoRIIOConnector(KVConnectorBase_V1):
|
||||
self.connector_scheduler = None
|
||||
self.connector_worker = MoRIIOConnectorWorker(vllm_config, self.engine_id)
|
||||
logger.info(
|
||||
f"Initialized MoRIIO Connector,engine_id: {self.engine_id},role: {role.value}"
|
||||
"Initialized MoRIIO Connector,engine_id:{self.engine_id},role: {role.value}"
|
||||
)
|
||||
|
||||
############################################################
|
||||
@ -1067,7 +1073,6 @@ class MoRIIOConnectorScheduler:
|
||||
"decode_rank": self.dp_rank,
|
||||
"type": "remote_blocks",
|
||||
}
|
||||
# logger.debug(f"MoRIIO send notify block for prefill, {data= },{host= },{port= }")
|
||||
serialized_data = msgpack.dumps(data)
|
||||
self.paths[path].send(serialized_data)
|
||||
|
||||
@ -1139,7 +1144,8 @@ class MoRIIOConnectorScheduler:
|
||||
meta = MoRIIOConnectorMetadata()
|
||||
|
||||
if self.mode == MoRIIOMode.WRITE:
|
||||
# when async_load_kv finished, will add new reqs to scheduler_output.scheduled_new_reqs
|
||||
# when async_load_kv finished,
|
||||
# new reqs will be added to scheduler_output.scheduled_new_reqs
|
||||
|
||||
if get_role() == ROLE.CONSUMER:
|
||||
for new_req in scheduler_output.scheduled_new_reqs:
|
||||
@ -1161,7 +1167,8 @@ class MoRIIOConnectorScheduler:
|
||||
)
|
||||
if get_role() == ROLE.PRODUCER:
|
||||
# This is the logic for checking against chunked prefill.
|
||||
# When the last chunk is identified, it places the request metadata into the saving queue.
|
||||
# When the last chunk is identified,
|
||||
# It places the request metadata into the saving queue.
|
||||
|
||||
for i, req_id in enumerate(
|
||||
scheduler_output.scheduled_cached_reqs.req_ids
|
||||
@ -1376,11 +1383,10 @@ class MoRIIOConnectorWorker:
|
||||
self._ping_thread.start()
|
||||
|
||||
logger.info(
|
||||
f"Initializing MoRIIO Engine ,engine = {self.moriio_engine},role = {'producer' if self.is_producer else 'consumer'}"
|
||||
)
|
||||
logger.debug(
|
||||
f"{self.local_ip = },{self._rank = },{self._local_rank = },{self.local_kv_port = },{self.proxy_ip = },{self.local_ping_port = },{self.proxy_ping_port = }"
|
||||
f"Initializing MoRIIO Engine ,engine = {self.moriio_engine},"
|
||||
f"role = {'producer' if self.is_producer else 'consumer'}"
|
||||
)
|
||||
|
||||
# Agent.
|
||||
self.moriio_wrapper = MoRIIOWrapper(tp_rank=self.tp_rank, dp_rank=self.dp_rank)
|
||||
self.moriio_wrapper.set_moriio_engine(self.moriio_engine)
|
||||
@ -1568,7 +1574,8 @@ class MoRIIOConnectorWorker:
|
||||
|
||||
except ConnectionRefusedError:
|
||||
logger.info(
|
||||
f"Connection refused: {self.local_ip}:{self.local_ping_port} -> "
|
||||
f"Connection refused: {self.local_ip}:"
|
||||
f"{self.local_ping_port} -> "
|
||||
f"{self.proxy_ip}:{self.proxy_ping_port}"
|
||||
)
|
||||
retry_count += 1
|
||||
@ -1584,7 +1591,8 @@ class MoRIIOConnectorWorker:
|
||||
finally:
|
||||
if retry_count >= MoRIIOConstants.MAX_PING_RETRIES:
|
||||
logger.error(
|
||||
f"Max retries ({MoRIIOConstants.MAX_PING_RETRIES}) exceeded. Stopping ping loop."
|
||||
f"Max retries ({MoRIIOConstants.MAX_PING_RETRIES})"
|
||||
"exceeded. Stopping ping loop."
|
||||
)
|
||||
break
|
||||
|
||||
@ -1718,12 +1726,14 @@ class MoRIIOConnectorWorker:
|
||||
)
|
||||
if len(self.local_kv_cache_metadata) > 0:
|
||||
logger.warning(
|
||||
f"{len(self.local_kv_cache_metadata) = },maybe you didnt clear this buffer correctly"
|
||||
f"{len(self.local_kv_cache_metadata) = },"
|
||||
"maybe you didnt clear this buffer correctly"
|
||||
)
|
||||
self.local_kv_cache_metadata = []
|
||||
if len(self.remote_kv_cache_metadata) > 0:
|
||||
logger.warning(
|
||||
f" {len(self.remote_kv_cache_metadata) = },maybe you didnt clear this buffer correctly"
|
||||
f" {len(self.remote_kv_cache_metadata) = },"
|
||||
"maybe you didnt clear this buffer correctly"
|
||||
)
|
||||
self.remote_kv_cache_metadata = []
|
||||
|
||||
@ -1798,7 +1808,6 @@ class MoRIIOConnectorWorker:
|
||||
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
|
||||
"""Register the KV Cache data in moriio."""
|
||||
|
||||
# kv_caches,KEY layer name,VALUE cache tensor,(2,numblocks,blocksize,headnum,headsize)
|
||||
_, first_kv_cache = next(iter(kv_caches.items()))
|
||||
kv_elem_size = first_kv_cache.element_size()
|
||||
|
||||
@ -1836,8 +1845,6 @@ class MoRIIOConnectorWorker:
|
||||
self.block_shape = block_shape
|
||||
self.kv_element_size = kv_elem_size
|
||||
|
||||
# logger.info(f"Registering KV_Caches: {use_mla=}, {self.num_blocks=}, {block_shape=}, per_layer_kv_cache_shape={first_kv_cache.shape}")
|
||||
|
||||
self.dst_num_blocks[self.engine_id] = self.num_blocks
|
||||
self.kv_caches = kv_caches # layer name to kv cache
|
||||
kv_caches_base_addr = []
|
||||
@ -1849,7 +1856,6 @@ class MoRIIOConnectorWorker:
|
||||
if use_mla or self._use_flashinfer
|
||||
else cache_or_caches
|
||||
)
|
||||
# logger.debug(f"prepare register local kv cache tensor for local mori io engine,{len(cache_list) = },{kv_caches.keys() = }")
|
||||
for cache in cache_list:
|
||||
base_addr = cache.data_ptr()
|
||||
region_len = self.num_blocks * self.block_len
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user