fix format error

Signed-off-by: inkcherry <mingzhi.liu@amd.com>
This commit is contained in:
inkcherry 2025-11-20 11:05:53 +00:00
parent 08cd2efbb6
commit a0d74ebf7f

View File

@ -802,8 +802,8 @@ class MoRIIOConnectorMetadata(KVConnectorMetadata):
return_str += f"{req_id = },{req_meta.local_block_ids = },{req_meta.remote_block_ids = },{req_meta.remote_host = },{req_meta.remote_port = },{req_meta.remote_engine_id = },{req_meta.tp_size = }"
return_str = f"MoRIIOConnectorMetadata:reqs_to_recv:{return_str},"
for req_id, req_meta in self.reqs_to_send.items():
return_str += f"{req_id = },{req_meta = }"
for req_id, expiry in self.reqs_to_send.items():
return_str += f"{req_id = },{expiry = }"
return_str = f"MoRIIOConnectorMetadata:reqs_to_send:{return_str},"
return return_str
@ -929,10 +929,11 @@ class MoRIIOConnector(KVConnectorBase_V1):
return self.connector_worker.get_finished()
def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
assert self.connector_worker is not None
if self.mode == MoRIIOMode.WRITE:
if get_role() == ROLE.CONSUMER:
self.connector_worker.moriio_wrapper.async_wait_reqid()
assert self.connector_worker is not None
assert isinstance(self._connector_metadata, MoRIIOConnectorMetadata)
self.connector_worker.start_load_kv(self._connector_metadata)
@ -949,6 +950,9 @@ class MoRIIOConnector(KVConnectorBase_V1):
# Only producer/prefill saves KV Cache
if get_role() == ROLE.CONSUMER:
return
assert isinstance(self._connector_metadata, MoRIIOConnectorMetadata), (
"Connector metadata not initialized yet"
)
self.connector_worker.save_kv_layer(
self._connector_metadata, layer_name, kv_layer, attn_metadata, **kwargs
)
@ -1043,7 +1047,7 @@ class MoRIIOConnectorScheduler:
return len(request.prompt_token_ids) - 1 - num_computed_tokens, False
def send_notify_block(
self, req_id: str, block_notify_list: list[int] , host=None, port=None
self, req_id: str, block_notify_list: list[int], host=None, port=None
):
path = make_zmq_path("tcp", host, port)
if path not in self.paths:
@ -1137,6 +1141,12 @@ class MoRIIOConnectorScheduler:
for new_req in scheduler_output.scheduled_new_reqs:
red_id = new_req.req_id
local_block_ids = list(new_req.block_ids)
assert new_req.sampling_params is not None, (
f"sampling_params is None for req {new_req.req_id}"
)
assert hasattr(new_req.sampling_params, "extra_args"), (
f"sampling_params missing extra_args for req {new_req.req_id}"
)
kv_transfer_params = new_req.sampling_params.extra_args[
"kv_transfer_params"
]
@ -1391,8 +1401,6 @@ class MoRIIOConnectorWorker:
self.block_shape = None
self.kv_element_size = 0
# Map of engine_id -> {rank0: agent_name0, rank1: agent_name1..}.
self._remote_agents: dict[EngineId, dict[int, str]] = defaultdict(dict)