break long line

Signed-off-by: inkcherry <mingzhi.liu@amd.com>
This commit is contained in:
inkcherry 2025-11-21 03:55:32 +00:00
parent f75eecde0a
commit e0885e52d9

View File

@ -73,7 +73,6 @@ class MoRIIOConstants:
try:
import mori
from mori.io import (
BackendType,
EngineDesc,
@ -260,7 +259,8 @@ class MoRIIOConfig:
class MoRIIOWriter:
"""Handles write operations for KV cache transfers. Implements distributed KV cache transfer using the MoRIIO library
"""Handles write operations for KV cache transfers.
Implements distributed KV cache transfer using the MoRIIO library
for RDMA-based communication between prefill and decode instances."""
def __init__(self, worker: "MoRIIOConnectorWorker"):
@ -400,7 +400,9 @@ class MoRIIOWriter:
return
# Wait for CUDA event
# The attention computation of the current layer cannot overlap with the kv transfer task, otherwise it will cause precision issues.
# The attention computation of the current layer cannot
# overlap with the kv transfer task,
# otherwise it will cause precision issues.
# This event is used to synchronize the kv transfer and computation tasks.
task.event.synchronize()
@ -497,8 +499,8 @@ class MoRIIOWriter:
remote_port = task.remote_notify_port + get_port_offset(
request_info.decode_dp_rank, self.worker.tp_rank
)
# TODO:
# Consider using RDMA immediate data in decode side to eliminate the need for this notification.
# Consider using RDMA immediate data in decode side
# to eliminate the need for this notification.
# Consider including the first gen token from prefill in the notification
# Send completion notification
@ -799,7 +801,11 @@ class MoRIIOConnectorMetadata(KVConnectorMetadata):
def __repr__(self):
return_str = ""
for req_id, req_meta in self.reqs_to_recv.items():
return_str += f"{req_id = },{req_meta.local_block_ids = },{req_meta.remote_block_ids = },{req_meta.remote_host = },{req_meta.remote_port = },{req_meta.remote_engine_id = },{req_meta.tp_size = }"
return_str += (
f"{req_id = },{req_meta.local_block_ids = },"
f"{req_meta.remote_host = },{req_meta.remote_port = }"
f"{req_meta.remote_engine_id = },{req_meta.tp_size = }"
)
return_str = f"MoRIIOConnectorMetadata:reqs_to_recv:{return_str},"
for req_id, expiry in self.reqs_to_send.items():
@ -862,7 +868,7 @@ class MoRIIOConnector(KVConnectorBase_V1):
self.connector_scheduler = None
self.connector_worker = MoRIIOConnectorWorker(vllm_config, self.engine_id)
logger.info(
f"Initialized MoRIIO Connector,engine_id: {self.engine_id},role: {role.value}"
"Initialized MoRIIO Connector,engine_id:{self.engine_id},role: {role.value}"
)
############################################################
@ -1067,7 +1073,6 @@ class MoRIIOConnectorScheduler:
"decode_rank": self.dp_rank,
"type": "remote_blocks",
}
# logger.debug(f"MoRIIO send notify block for prefill, {data= },{host= },{port= }")
serialized_data = msgpack.dumps(data)
self.paths[path].send(serialized_data)
@ -1139,7 +1144,8 @@ class MoRIIOConnectorScheduler:
meta = MoRIIOConnectorMetadata()
if self.mode == MoRIIOMode.WRITE:
# when async_load_kv finished, will add new reqs to scheduler_output.scheduled_new_reqs
# when async_load_kv finished,
# new reqs will be added to scheduler_output.scheduled_new_reqs
if get_role() == ROLE.CONSUMER:
for new_req in scheduler_output.scheduled_new_reqs:
@ -1161,7 +1167,8 @@ class MoRIIOConnectorScheduler:
)
if get_role() == ROLE.PRODUCER:
# This is the logic for checking against chunked prefill.
# When the last chunk is identified, it places the request metadata into the saving queue.
# When the last chunk is identified,
# it places the request metadata into the saving queue.
for i, req_id in enumerate(
scheduler_output.scheduled_cached_reqs.req_ids
@ -1376,11 +1383,10 @@ class MoRIIOConnectorWorker:
self._ping_thread.start()
logger.info(
f"Initializing MoRIIO Engine ,engine = {self.moriio_engine},role = {'producer' if self.is_producer else 'consumer'}"
)
logger.debug(
f"{self.local_ip = },{self._rank = },{self._local_rank = },{self.local_kv_port = },{self.proxy_ip = },{self.local_ping_port = },{self.proxy_ping_port = }"
f"Initializing MoRIIO Engine ,engine = {self.moriio_engine},"
f"role = {'producer' if self.is_producer else 'consumer'}"
)
# Agent.
self.moriio_wrapper = MoRIIOWrapper(tp_rank=self.tp_rank, dp_rank=self.dp_rank)
self.moriio_wrapper.set_moriio_engine(self.moriio_engine)
@ -1568,7 +1574,8 @@ class MoRIIOConnectorWorker:
except ConnectionRefusedError:
logger.info(
f"Connection refused: {self.local_ip}:{self.local_ping_port} -> "
f"Connection refused: {self.local_ip}:"
f"{self.local_ping_port} -> "
f"{self.proxy_ip}:{self.proxy_ping_port}"
)
retry_count += 1
@ -1584,7 +1591,8 @@ class MoRIIOConnectorWorker:
finally:
if retry_count >= MoRIIOConstants.MAX_PING_RETRIES:
logger.error(
f"Max retries ({MoRIIOConstants.MAX_PING_RETRIES}) exceeded. Stopping ping loop."
f"Max retries ({MoRIIOConstants.MAX_PING_RETRIES})"
"exceeded. Stopping ping loop."
)
break
@ -1718,12 +1726,14 @@ class MoRIIOConnectorWorker:
)
if len(self.local_kv_cache_metadata) > 0:
logger.warning(
f"{len(self.local_kv_cache_metadata) = },maybe you didnt clear this buffer correctly"
f"{len(self.local_kv_cache_metadata) = },"
"maybe you didnt clear this buffer correctly"
)
self.local_kv_cache_metadata = []
if len(self.remote_kv_cache_metadata) > 0:
logger.warning(
f" {len(self.remote_kv_cache_metadata) = },maybe you didnt clear this buffer correctly"
f" {len(self.remote_kv_cache_metadata) = },"
"maybe you didnt clear this buffer correctly"
)
self.remote_kv_cache_metadata = []
@ -1798,7 +1808,6 @@ class MoRIIOConnectorWorker:
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
"""Register the KV Cache data in moriio."""
# kv_caches,KEY layer name,VALUE cache tensor,(2,numblocks,blocksize,headnum,headsize)
_, first_kv_cache = next(iter(kv_caches.items()))
kv_elem_size = first_kv_cache.element_size()
@ -1836,8 +1845,6 @@ class MoRIIOConnectorWorker:
self.block_shape = block_shape
self.kv_element_size = kv_elem_size
# logger.info(f"Registering KV_Caches: {use_mla=}, {self.num_blocks=}, {block_shape=}, per_layer_kv_cache_shape={first_kv_cache.shape}")
self.dst_num_blocks[self.engine_id] = self.num_blocks
self.kv_caches = kv_caches # layer name to kv cache
kv_caches_base_addr = []
@ -1849,7 +1856,6 @@ class MoRIIOConnectorWorker:
if use_mla or self._use_flashinfer
else cache_or_caches
)
# logger.debug(f"prepare register local kv cache tensor for local mori io engine,{len(cache_list) = },{kv_caches.keys() = }")
for cache in cache_list:
base_addr = cache.data_ptr()
region_len = self.num_blocks * self.block_len