Signed-off-by: inkcherry <mingzhi.liu@amd.com>
This commit is contained in:
inkcherry 2025-11-20 11:30:48 +00:00
parent a0d74ebf7f
commit 9b90f5ddb2

View File

@ -784,7 +784,7 @@ class ReqMeta:
remote_host: str
remote_port: int
remote_handshake_port: int
remote_notify_port: int | None
remote_notify_port: int
remote_engine_id: str
tp_size: int
remote_dp_size: int
@ -821,7 +821,7 @@ class MoRIIOConnectorMetadata(KVConnectorMetadata):
remote_host=kv_transfer_params["remote_host"],
remote_port=kv_transfer_params["remote_port"],
remote_handshake_port=kv_transfer_params["remote_handshake_port"],
remote_notify_port=kv_transfer_params.get("remote_notify_port"),
remote_notify_port=kv_transfer_params["remote_notify_port"],
tp_size=kv_transfer_params.get("tp_size", 1),
remote_dp_size=kv_transfer_params.get("remote_dp_size", 1),
)
@ -950,6 +950,10 @@ class MoRIIOConnector(KVConnectorBase_V1):
# Only producer/prefill saves KV Cache
if get_role() == ROLE.CONSUMER:
return
assert self.connector_worker is not None, (
"save_kv_layer called on scheduler role"
)
assert isinstance(self._connector_metadata, MoRIIOConnectorMetadata), (
"Connector metadata not initialized yet"
)
@ -1674,7 +1678,7 @@ class MoRIIOConnectorWorker:
remote_tp_size: int,
expected_engine_id: str,
remote_dp_rank: int = 0,
) -> dict[int, str]:
) -> set[str]:
"""Do a MoRIIO handshake with a remote instance."""
start_time = time.perf_counter()
@ -2085,6 +2089,7 @@ class MoRIIOConnectorWorker:
def _write_blocks_for_req(self, req_id: str, meta: ReqMeta, layer_name, kv_layer):
# logger.debug(f"write block for req {req_id} to remote engine "
# f"{meta.remote_engine_id}")
self.schedule_write_blocks(
request_id=req_id,
dst_engine_id=meta.remote_engine_id,
@ -2189,6 +2194,7 @@ class MoRIIOConnectorWorker:
Returns:
Tuple of (local_offsets, remote_offsets, transfer_sizes)
"""
assert self.kv_cache_shape is not None, "KV caches shape not initialized"
is_mla = len(self.kv_cache_shape) == 3
stride = self.kv_caches[layer_name].stride()
sz = self.kv_caches[layer_name].element_size()