Signed-off-by: inkcherry <mingzhi.liu@amd.com>
This commit is contained in:
inkcherry 2025-12-02 06:44:36 +00:00
parent 532d8a7453
commit e52ba22429

View File

@ -138,6 +138,7 @@ class ROLE(Enum):
CONSUMER = "consumer"
NOTINIT = "notinit"
class MoRIIOAgentMetadata(
msgspec.Struct,
omit_defaults=True, # type: ignore[call-arg]
@ -437,7 +438,9 @@ class MoRIIOWriter:
)
# Get or create sessions
sessions, remote_moriio_meta = self.worker._get_built_session(task.dst_engine_id)
sessions, remote_moriio_meta = self.worker._get_built_session(
task.dst_engine_id
)
# Prepare transfer plan
plan = self._prepare_transfer_plan(task, request_info, remote_moriio_meta)
@ -449,7 +452,10 @@ class MoRIIOWriter:
self._finalize_if_complete(task, request_info)
def _prepare_transfer_plan(
self, task: WriteTask, request_info: RemoteAllocInfo, remote_moriio_meta: MoRIIOAgentMetadata
self,
task: WriteTask,
request_info: RemoteAllocInfo,
remote_moriio_meta: MoRIIOAgentMetadata,
) -> LayerTransferPlan:
"""Prepare the transfer plan for a layer.
@ -463,7 +469,10 @@ class MoRIIOWriter:
# Compute offsets if not cached
if request_info.transfer_offset is None:
offsets = self.worker._compute_block_transfer_offsets(
task.layer_name, task.local_block_ids, request_info.block_ids, remote_moriio_meta
task.layer_name,
task.local_block_ids,
request_info.block_ids,
remote_moriio_meta,
)
request_info.transfer_offset = offsets
@ -827,8 +836,6 @@ class MoRIIOWrapper:
self.paths.clear()
@dataclass
class ReqMeta:
"""Metadata for a single request."""
@ -1603,7 +1610,9 @@ class MoRIIOConnectorWorker:
)
)
self.built_write_session[remote_engine_id] = cur_remote_engine_sessions
return self.built_write_session[remote_engine_id], self.remote_moriio_metadata[remote_engine_id]
return self.built_write_session[remote_engine_id], self.remote_moriio_metadata[
remote_engine_id
]
def _ping(self, zmq_context):
http_request_address = f"http://{self.request_address}/v1/completions"
@ -1803,7 +1812,7 @@ class MoRIIOConnectorWorker:
self.layer_name_to_remote_kv_cache_metadata[expected_engine_id] = (
msgpack.loads(buf)
)
self.remote_moriio_metadata[expected_engine_id]=metadata
self.remote_moriio_metadata[expected_engine_id] = metadata
setup_agent_time = time.perf_counter()
logger.debug(
"MoRIIO handshake: add agent took: %s",
@ -2251,7 +2260,7 @@ class MoRIIOConnectorWorker:
_, blknum, blksize, hn, hs = self.kv_cache_shape
local_ktov_stride = stride[0]
block_stride = stride[1]
remote_ktov_stride = block_stride*remote_moriio_meta.num_blocks
remote_ktov_stride = block_stride * remote_moriio_meta.num_blocks
transfer_size_byte = blksize * hn * hs * sz
per_block = 1 if is_mla else 2