From 536668602c3413ca137eca5d8d48b526abb895b4 Mon Sep 17 00:00:00 2001 From: inkcherry Date: Mon, 1 Dec 2025 06:25:53 +0000 Subject: [PATCH] fix format Signed-off-by: inkcherry --- .../kv_connector/v1/moriio_connector.py | 35 ++++++++++++------- 1 file changed, 23 insertions(+), 12 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py index b1063f5487955..1299faff78336 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py @@ -544,23 +544,26 @@ class MoRIIOWrapper: dp_rank: Data parallel rank """ - def __init__(self, moriio_engine=None, tp_rank=0, dp_rank=0): + def __init__( + self, + moriio_engine: Optional["IOEngine"] = None, + tp_rank: int = 0, + dp_rank: int = 0, + ): self.tp_rank = tp_rank self.dp_rank = dp_rank self.moriio_engine = moriio_engine self.remote_memory_metadata = None self.local_memory_registered = False self.local_memory_metadata = None - self.transfer_status = [] - self.remote_engine_ip = None - self.notify_port = None - self.notify_sock = None + self.transfer_status: list[Any] = [] + self.remote_engine_ip: str | None = None + self.notify_port: int | None = None self.lock = threading.Lock() - self.done_req_ids = [] + self.done_req_ids: list[str] = [] self.done_remote_allocate_req_dict: dict[str, RemoteAllocInfo] = {} - self.done_write_cache_req_ids = [] - self.notify_thread = None - self.sock = None + self.done_write_cache_req_ids: list[str] = [] + self.notify_thread: threading.Thread | None = None self.sessions: list[IOEngine.Session] = [] self.paths: dict[str, zmq.Socket] = {} @@ -571,6 +574,7 @@ class MoRIIOWrapper: self.moriio_engine = moriio_engine def set_backend_type(self, backend_type): + assert self.moriio_engine is not None, "MoRIIO engine must be set first" qp_per_transfer = int(os.getenv("VLLM_MORI_QP_PER_TRANSFER", "1")) post_batch_size = int(os.getenv("VLLM_MORI_POST_BATCH_SIZE", "-1")) num_worker_threads = int(os.getenv("VLLM_MORI_NUM_WORKERS", "1")) @@ -584,20 +588,26 @@ class MoRIIOWrapper: self.moriio_engine.create_backend(backend_type, rdma_cfg) def get_agent_metadata(self): + assert self.moriio_engine is not None, "MoRIIO engine must be set first" engine_metadata = self.moriio_engine.get_engine_desc() engine_metadata_packed = engine_metadata.pack() return engine_metadata_packed def register_remote_engine(self, remote_packed_engine_metadata): + assert self.moriio_engine is not None, "MoRIIO engine must be set first" consumer_engine_metadata = EngineDesc.unpack(remote_packed_engine_metadata) self.moriio_engine.register_remote_engine(consumer_engine_metadata) return consumer_engine_metadata.key def register_local_tensor(self, tensor: torch.Tensor): + assert self.moriio_engine is not None, "MoRIIO engine must be set first" try: self.local_memory_metadata = self.moriio_engine.register_torch_tensor( tensor ) + assert self.local_memory_metadata is not None, ( + "register_torch_tensor returned None" + ) local_memory_metadata_packed = self.local_memory_metadata.pack() except Exception as e: raise MoRIIOError(f"Failed to register local memory: {e}") from e @@ -608,6 +618,7 @@ class MoRIIOWrapper: return MemoryDesc.unpack(packed_memory_metadata) def build_session(self, local_memory_metadata, remote_memory_metadata): + assert self.moriio_engine is not None, "MoRIIO engine must be set first" return self.moriio_engine.create_session( local_memory_metadata, remote_memory_metadata ) @@ -616,7 +627,7 @@ class MoRIIOWrapper: self, transfer_size_byte, local_offset=0, remote_offset=0, session=None ): assert self.local_memory_registered, "You have not register local memory data!" - + assert self.moriio_engine is not None, "MoRIIO engine must be set first" transfer_status = session.batch_read( local_offset, remote_offset, @@ -630,6 +641,7 @@ class MoRIIOWrapper: self, transfer_size_byte, local_offset=0, remote_offset=0, session=None ): assert self.local_memory_registered, "You have not register local memory data!" + assert self.moriio_engine is not None, "MoRIIO engine must be set first" write_uid = self.moriio_engine.allocate_transfer_uid() transfer_status = session.batch_write( @@ -642,7 +654,7 @@ class MoRIIOWrapper: self, transfer_size_byte, local_offset=0, remote_offset=0, sess_idx=0 ): assert self.local_memory_registered, "You have not register local memory data!" - + assert self.moriio_engine is not None, "MoRIIO engine must be set first" transfer_status = self.sessions[sess_idx].write( local_offset, remote_offset, @@ -1052,7 +1064,6 @@ class MoRIIOConnectorScheduler: set_role(ROLE.CONSUMER) # Reqs to send and their expiration time self._reqs_need_send: dict[ReqId, float] = {} - self.sock = None self.paths: dict[str, zmq.Socket] = {} def get_num_new_matched_tokens(