Signed-off-by: inkcherry <mingzhi.liu@amd.com>
This commit is contained in:
inkcherry 2025-11-20 11:30:48 +00:00
parent a0d74ebf7f
commit 9b90f5ddb2

View File

@ -784,7 +784,7 @@ class ReqMeta:
remote_host: str remote_host: str
remote_port: int remote_port: int
remote_handshake_port: int remote_handshake_port: int
remote_notify_port: int | None remote_notify_port: int
remote_engine_id: str remote_engine_id: str
tp_size: int tp_size: int
remote_dp_size: int remote_dp_size: int
@ -821,7 +821,7 @@ class MoRIIOConnectorMetadata(KVConnectorMetadata):
remote_host=kv_transfer_params["remote_host"], remote_host=kv_transfer_params["remote_host"],
remote_port=kv_transfer_params["remote_port"], remote_port=kv_transfer_params["remote_port"],
remote_handshake_port=kv_transfer_params["remote_handshake_port"], remote_handshake_port=kv_transfer_params["remote_handshake_port"],
remote_notify_port=kv_transfer_params.get("remote_notify_port"), remote_notify_port=kv_transfer_params["remote_notify_port"],
tp_size=kv_transfer_params.get("tp_size", 1), tp_size=kv_transfer_params.get("tp_size", 1),
remote_dp_size=kv_transfer_params.get("remote_dp_size", 1), remote_dp_size=kv_transfer_params.get("remote_dp_size", 1),
) )
@ -950,6 +950,10 @@ class MoRIIOConnector(KVConnectorBase_V1):
# Only producer/prefill saves KV Cache # Only producer/prefill saves KV Cache
if get_role() == ROLE.CONSUMER: if get_role() == ROLE.CONSUMER:
return return
assert self.connector_worker is not None, (
"save_kv_layer called on scheduler role"
)
assert isinstance(self._connector_metadata, MoRIIOConnectorMetadata), ( assert isinstance(self._connector_metadata, MoRIIOConnectorMetadata), (
"Connector metadata not initialized yet" "Connector metadata not initialized yet"
) )
@ -1674,7 +1678,7 @@ class MoRIIOConnectorWorker:
remote_tp_size: int, remote_tp_size: int,
expected_engine_id: str, expected_engine_id: str,
remote_dp_rank: int = 0, remote_dp_rank: int = 0,
) -> dict[int, str]: ) -> set[str]:
"""Do a MoRIIO handshake with a remote instance.""" """Do a MoRIIO handshake with a remote instance."""
start_time = time.perf_counter() start_time = time.perf_counter()
@ -2085,6 +2089,7 @@ class MoRIIOConnectorWorker:
def _write_blocks_for_req(self, req_id: str, meta: ReqMeta, layer_name, kv_layer): def _write_blocks_for_req(self, req_id: str, meta: ReqMeta, layer_name, kv_layer):
# logger.debug(f"write block for req {req_id} to remote engine " # logger.debug(f"write block for req {req_id} to remote engine "
# f"{meta.remote_engine_id}") # f"{meta.remote_engine_id}")
self.schedule_write_blocks( self.schedule_write_blocks(
request_id=req_id, request_id=req_id,
dst_engine_id=meta.remote_engine_id, dst_engine_id=meta.remote_engine_id,
@ -2189,6 +2194,7 @@ class MoRIIOConnectorWorker:
Returns: Returns:
Tuple of (local_offsets, remote_offsets, transfer_sizes) Tuple of (local_offsets, remote_offsets, transfer_sizes)
""" """
assert self.kv_cache_shape is not None, "KV caches shape not initialized"
is_mla = len(self.kv_cache_shape) == 3 is_mla = len(self.kv_cache_shape) == 3
stride = self.kv_caches[layer_name].stride() stride = self.kv_caches[layer_name].stride()
sz = self.kv_caches[layer_name].element_size() sz = self.kv_caches[layer_name].element_size()