fix format

Signed-off-by: inkcherry <mingzhi.liu@amd.com>
This commit is contained in:
inkcherry 2025-12-01 06:25:53 +00:00
parent ea9b6871f6
commit 536668602c

View File

@ -544,23 +544,26 @@ class MoRIIOWrapper:
dp_rank: Data parallel rank
"""
def __init__(self, moriio_engine=None, tp_rank=0, dp_rank=0):
def __init__(
self,
moriio_engine: Optional["IOEngine"] = None,
tp_rank: int = 0,
dp_rank: int = 0,
):
self.tp_rank = tp_rank
self.dp_rank = dp_rank
self.moriio_engine = moriio_engine
self.remote_memory_metadata = None
self.local_memory_registered = False
self.local_memory_metadata = None
self.transfer_status = []
self.remote_engine_ip = None
self.notify_port = None
self.notify_sock = None
self.transfer_status: list[Any] = []
self.remote_engine_ip: str | None = None
self.notify_port: int | None = None
self.lock = threading.Lock()
self.done_req_ids = []
self.done_req_ids: list[str] = []
self.done_remote_allocate_req_dict: dict[str, RemoteAllocInfo] = {}
self.done_write_cache_req_ids = []
self.notify_thread = None
self.sock = None
self.done_write_cache_req_ids: list[str] = []
self.notify_thread: threading.Thread | None = None
self.sessions: list[IOEngine.Session] = []
self.paths: dict[str, zmq.Socket] = {}
@ -571,6 +574,7 @@ class MoRIIOWrapper:
self.moriio_engine = moriio_engine
def set_backend_type(self, backend_type):
assert self.moriio_engine is not None, "MoRIIO engine must be set first"
qp_per_transfer = int(os.getenv("VLLM_MORI_QP_PER_TRANSFER", "1"))
post_batch_size = int(os.getenv("VLLM_MORI_POST_BATCH_SIZE", "-1"))
num_worker_threads = int(os.getenv("VLLM_MORI_NUM_WORKERS", "1"))
@ -584,20 +588,26 @@ class MoRIIOWrapper:
self.moriio_engine.create_backend(backend_type, rdma_cfg)
def get_agent_metadata(self):
    """Return the packed descriptor of the local MoRIIO engine.

    The descriptor is obtained from ``self.moriio_engine.get_engine_desc()``
    and serialized with its ``pack()`` method.

    Returns:
        Whatever ``pack()`` produces for the engine descriptor.

    Raises:
        AssertionError: If ``self.moriio_engine`` is ``None``.
    """
    engine = self.moriio_engine
    # The engine attribute is optional; fail early with a clear message.
    assert engine is not None, "MoRIIO engine must be set first"
    return engine.get_engine_desc().pack()
def register_remote_engine(self, remote_packed_engine_metadata):
    """Register a remote engine from its packed descriptor.

    The packed metadata is decoded with ``EngineDesc.unpack`` and the
    resulting descriptor is handed to
    ``self.moriio_engine.register_remote_engine``.

    Args:
        remote_packed_engine_metadata: Packed engine descriptor accepted by
            ``EngineDesc.unpack``.

    Returns:
        The ``key`` attribute of the unpacked descriptor.

    Raises:
        AssertionError: If ``self.moriio_engine`` is ``None``.
    """
    engine = self.moriio_engine
    # Check the engine before decoding, so a missing engine is reported first.
    assert engine is not None, "MoRIIO engine must be set first"
    remote_desc = EngineDesc.unpack(remote_packed_engine_metadata)
    engine.register_remote_engine(remote_desc)
    return remote_desc.key
def register_local_tensor(self, tensor: torch.Tensor):
assert self.moriio_engine is not None, "MoRIIO engine must be set first"
try:
self.local_memory_metadata = self.moriio_engine.register_torch_tensor(
tensor
)
assert self.local_memory_metadata is not None, (
"register_torch_tensor returned None"
)
local_memory_metadata_packed = self.local_memory_metadata.pack()
except Exception as e:
raise MoRIIOError(f"Failed to register local memory: {e}") from e
@ -608,6 +618,7 @@ class MoRIIOWrapper:
return MemoryDesc.unpack(packed_memory_metadata)
def build_session(self, local_memory_metadata, remote_memory_metadata):
    """Create a session between a local and a remote memory descriptor.

    Delegates to ``self.moriio_engine.create_session``.

    Args:
        local_memory_metadata: Descriptor passed as the first argument to
            ``create_session``.
        remote_memory_metadata: Descriptor passed as the second argument to
            ``create_session``.

    Returns:
        The object returned by ``create_session``.

    Raises:
        AssertionError: If ``self.moriio_engine`` is ``None``.
    """
    engine = self.moriio_engine
    assert engine is not None, "MoRIIO engine must be set first"
    session = engine.create_session(local_memory_metadata, remote_memory_metadata)
    return session
@ -616,7 +627,7 @@ class MoRIIOWrapper:
self, transfer_size_byte, local_offset=0, remote_offset=0, session=None
):
assert self.local_memory_registered, "You have not register local memory data!"
assert self.moriio_engine is not None, "MoRIIO engine must be set first"
transfer_status = session.batch_read(
local_offset,
remote_offset,
@ -630,6 +641,7 @@ class MoRIIOWrapper:
self, transfer_size_byte, local_offset=0, remote_offset=0, session=None
):
assert self.local_memory_registered, "You have not register local memory data!"
assert self.moriio_engine is not None, "MoRIIO engine must be set first"
write_uid = self.moriio_engine.allocate_transfer_uid()
transfer_status = session.batch_write(
@ -642,7 +654,7 @@ class MoRIIOWrapper:
self, transfer_size_byte, local_offset=0, remote_offset=0, sess_idx=0
):
assert self.local_memory_registered, "You have not register local memory data!"
assert self.moriio_engine is not None, "MoRIIO engine must be set first"
transfer_status = self.sessions[sess_idx].write(
local_offset,
remote_offset,
@ -1052,7 +1064,6 @@ class MoRIIOConnectorScheduler:
set_role(ROLE.CONSUMER)
# Reqs to send and their expiration time
self._reqs_need_send: dict[ReqId, float] = {}
self.sock = None
self.paths: dict[str, zmq.Socket] = {}
def get_num_new_matched_tokens(