From e0885e52d91b501dc2a5f6001c36a6628e4d7eed Mon Sep 17 00:00:00 2001 From: inkcherry Date: Fri, 21 Nov 2025 03:55:32 +0000 Subject: [PATCH] break long line Signed-off-by: inkcherry --- .../kv_connector/v1/moriio_connector.py | 50 +++++++++++-------- 1 file changed, 28 insertions(+), 22 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py index 61b83c03b2328..62a8ddc5d1a44 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py @@ -73,7 +73,6 @@ class MoRIIOConstants: try: - import mori from mori.io import ( BackendType, EngineDesc, @@ -260,7 +259,8 @@ class MoRIIOConfig: class MoRIIOWriter: - """Handles write operations for KV cache transfers. Implements distributed KV cache transfer using the MoRIIO library + """Handles write operations for KV cache transfers. + Implements distributed KV cache transfer using the MoRIIO library for RDMA-based communication between prefill and decode instances.""" def __init__(self, worker: "MoRIIOConnectorWorker"): @@ -400,7 +400,9 @@ class MoRIIOWriter: return # Wait for CUDA event - # The attention computation of the current layer cannot overlap with the kv transfer task, otherwise it will cause precision issues. + # The attention computation of the current layer cannot + # overlap with the kv transfer task, + # otherwise it will cause precision issues. # This event is used to synchronize the kv transfer and computation tasks. task.event.synchronize() @@ -497,8 +499,8 @@ class MoRIIOWriter: remote_port = task.remote_notify_port + get_port_offset( request_info.decode_dp_rank, self.worker.tp_rank ) - # TODO: - # Consider using RDMA immediate data in decode side to eliminate the need for this notification. + # Consider using RDMA immediate data in decode side + # to eliminate the need for this notification. # Consider including the first gen token from prefill in the notification # Send completion notification @@ -799,7 +801,11 @@ class MoRIIOConnectorMetadata(KVConnectorMetadata): def __repr__(self): return_str = "" for req_id, req_meta in self.reqs_to_recv.items(): - return_str += f"{req_id = },{req_meta.local_block_ids = },{req_meta.remote_block_ids = },{req_meta.remote_host = },{req_meta.remote_port = },{req_meta.remote_engine_id = },{req_meta.tp_size = }" + return_str += ( + f"{req_id = },{req_meta.local_block_ids = }," + f"{req_meta.remote_host = },{req_meta.remote_port = }" + f"{req_meta.remote_engine_id = },{req_meta.tp_size = }" + ) return_str = f"MoRIIOConnectorMetadata:reqs_to_recv:{return_str}," for req_id, expiry in self.reqs_to_send.items(): @@ -862,7 +868,7 @@ class MoRIIOConnector(KVConnectorBase_V1): self.connector_scheduler = None self.connector_worker = MoRIIOConnectorWorker(vllm_config, self.engine_id) logger.info( - f"Initialized MoRIIO Connector,engine_id: {self.engine_id},role: {role.value}" + "Initialized MoRIIO Connector,engine_id:{self.engine_id},role: {role.value}" ) ############################################################ @@ -1067,7 +1073,6 @@ class MoRIIOConnectorScheduler: "decode_rank": self.dp_rank, "type": "remote_blocks", } - # logger.debug(f"MoRIIO send notify block for prefill, {data= },{host= },{port= }") serialized_data = msgpack.dumps(data) self.paths[path].send(serialized_data) @@ -1139,7 +1144,8 @@ class MoRIIOConnectorScheduler: meta = MoRIIOConnectorMetadata() if self.mode == MoRIIOMode.WRITE: - # when async_load_kv finished, will add new reqs to scheduler_output.scheduled_new_reqs + # when async_load_kv finished, + # new reqs will be added to scheduler_output.scheduled_new_reqs if get_role() == ROLE.CONSUMER: for new_req in scheduler_output.scheduled_new_reqs: @@ -1161,7 +1167,8 @@ class MoRIIOConnectorScheduler: ) if get_role() == ROLE.PRODUCER: # This is the logic for checking against chunked prefill. - # When the last chunk is identified, it places the request metadata into the saving queue. + # When the last chunk is identified, + # It places the request metadata into the saving queue. for i, req_id in enumerate( scheduler_output.scheduled_cached_reqs.req_ids @@ -1376,11 +1383,10 @@ class MoRIIOConnectorWorker: self._ping_thread.start() logger.info( - f"Initializing MoRIIO Engine ,engine = {self.moriio_engine},role = {'producer' if self.is_producer else 'consumer'}" - ) - logger.debug( - f"{self.local_ip = },{self._rank = },{self._local_rank = },{self.local_kv_port = },{self.proxy_ip = },{self.local_ping_port = },{self.proxy_ping_port = }" + f"Initializing MoRIIO Engine ,engine = {self.moriio_engine}," + f"role = {'producer' if self.is_producer else 'consumer'}" ) + # Agent. self.moriio_wrapper = MoRIIOWrapper(tp_rank=self.tp_rank, dp_rank=self.dp_rank) self.moriio_wrapper.set_moriio_engine(self.moriio_engine) @@ -1568,7 +1574,8 @@ class MoRIIOConnectorWorker: except ConnectionRefusedError: logger.info( - f"Connection refused: {self.local_ip}:{self.local_ping_port} -> " + f"Connection refused: {self.local_ip}:" + f"{self.local_ping_port} -> " f"{self.proxy_ip}:{self.proxy_ping_port}" ) retry_count += 1 @@ -1584,7 +1591,8 @@ class MoRIIOConnectorWorker: finally: if retry_count >= MoRIIOConstants.MAX_PING_RETRIES: logger.error( - f"Max retries ({MoRIIOConstants.MAX_PING_RETRIES}) exceeded. Stopping ping loop." + f"Max retries ({MoRIIOConstants.MAX_PING_RETRIES})" + "exceeded. Stopping ping loop." ) break @@ -1718,12 +1726,14 @@ class MoRIIOConnectorWorker: ) if len(self.local_kv_cache_metadata) > 0: logger.warning( - f"{len(self.local_kv_cache_metadata) = },maybe you didnt clear this buffer correctly" + f"{len(self.local_kv_cache_metadata) = }," + "maybe you didnt clear this buffer correctly" ) self.local_kv_cache_metadata = [] if len(self.remote_kv_cache_metadata) > 0: logger.warning( - f" {len(self.remote_kv_cache_metadata) = },maybe you didnt clear this buffer correctly" + f" {len(self.remote_kv_cache_metadata) = }," + "maybe you didnt clear this buffer correctly" ) self.remote_kv_cache_metadata = [] @@ -1798,7 +1808,6 @@ class MoRIIOConnectorWorker: def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): """Register the KV Cache data in moriio.""" - # kv_caches,KEY layer name,VALUE cache tensor,(2,numblocks,blocksize,headnum,headsize) _, first_kv_cache = next(iter(kv_caches.items())) kv_elem_size = first_kv_cache.element_size() @@ -1836,8 +1845,6 @@ class MoRIIOConnectorWorker: self.block_shape = block_shape self.kv_element_size = kv_elem_size - # logger.info(f"Registering KV_Caches: {use_mla=}, {self.num_blocks=}, {block_shape=}, per_layer_kv_cache_shape={first_kv_cache.shape}") - self.dst_num_blocks[self.engine_id] = self.num_blocks self.kv_caches = kv_caches # layer name to kv cache kv_caches_base_addr = [] @@ -1849,7 +1856,6 @@ class MoRIIOConnectorWorker: if use_mla or self._use_flashinfer else cache_or_caches ) - # logger.debug(f"prepare register local kv cache tensor for local mori io engine,{len(cache_list) = },{kv_caches.keys() = }") for cache in cache_list: base_addr = cache.data_ptr() region_len = self.num_blocks * self.block_len