Fix the issue of num_block inconsistency in non-MLA scenarios

Signed-off-by: inkcherry <mingzhi.liu@amd.com>
This commit is contained in:
inkcherry 2025-12-02 06:08:03 +00:00
parent 8a2c136c8f
commit 532d8a7453

View File

@ -138,6 +138,19 @@ class ROLE(Enum):
CONSUMER = "consumer"
NOTINIT = "notinit"
class MoRIIOAgentMetadata(
msgspec.Struct,
omit_defaults=True, # type: ignore[call-arg]
# required for @cached_property.d
dict=True,
):
engine_id: str
agent_metadata: bytes
kv_caches_base_addr: list[int]
num_blocks: int
block_len: int
attn_backend_name: str
class RoleManager:
"""Manages role state across the connector."""
@ -424,10 +437,10 @@ class MoRIIOWriter:
)
# Get or create sessions
sessions = self.worker._get_built_session(task.dst_engine_id)
sessions, remote_moriio_meta = self.worker._get_built_session(task.dst_engine_id)
# Prepare transfer plan
plan = self._prepare_transfer_plan(task, request_info)
plan = self._prepare_transfer_plan(task, request_info, remote_moriio_meta)
# Execute transfer
self._do_layer_write(plan, sessions)
@ -436,7 +449,7 @@ class MoRIIOWriter:
self._finalize_if_complete(task, request_info)
def _prepare_transfer_plan(
self, task: WriteTask, request_info: RemoteAllocInfo
self, task: WriteTask, request_info: RemoteAllocInfo, remote_moriio_meta: MoRIIOAgentMetadata
) -> LayerTransferPlan:
"""Prepare the transfer plan for a layer.
@ -450,7 +463,7 @@ class MoRIIOWriter:
# Compute offsets if not cached
if request_info.transfer_offset is None:
offsets = self.worker._compute_block_transfer_offsets(
task.layer_name, task.local_block_ids, request_info.block_ids
task.layer_name, task.local_block_ids, request_info.block_ids, remote_moriio_meta
)
request_info.transfer_offset = offsets
@ -814,18 +827,6 @@ class MoRIIOWrapper:
self.paths.clear()
class MoRIIOAgentMetadata(
msgspec.Struct,
omit_defaults=True, # type: ignore[call-arg]
# required for @cached_property.d
dict=True,
):
engine_id: str
agent_metadata: bytes
kv_caches_base_addr: list[int]
num_blocks: int
block_len: int
attn_backend_name: str
@dataclass
@ -1461,6 +1462,7 @@ class MoRIIOConnectorWorker:
self.layer_name_to_remote_kv_cache_metadata: dict[str, dict[str, list[Any]]] = (
dict()
)
self.remote_moriio_metadata: dict[EngineId, MoRIIOAgentMetadata] = {}
self.slot_size_bytes = 0
self.load_ready_flag: dict[str, bool] = {}
@ -1585,10 +1587,10 @@ class MoRIIOConnectorWorker:
if remote_engine_id not in self.built_write_session:
cur_remote_engine_sessions = []
for ln, local_meta in self.layer_name_to_local_kv_cache_metadata.items():
unpcaked_local_memory_meta = (
unpacked_local_memory_meta = (
self.moriio_wrapper.get_unpack_memory_metadata(local_meta[0])
)
unpcaked_remote_memory_meta = (
unpacked_remote_memory_meta = (
self.moriio_wrapper.get_unpack_memory_metadata(
self.layer_name_to_remote_kv_cache_metadata[remote_engine_id][
ln
@ -1597,11 +1599,11 @@ class MoRIIOConnectorWorker:
)
cur_remote_engine_sessions.append(
self.moriio_wrapper.build_session(
unpcaked_local_memory_meta, unpcaked_remote_memory_meta
unpacked_local_memory_meta, unpacked_remote_memory_meta
)
)
self.built_write_session[remote_engine_id] = cur_remote_engine_sessions
return self.built_write_session[remote_engine_id]
return self.built_write_session[remote_engine_id], self.remote_moriio_metadata[remote_engine_id]
def _ping(self, zmq_context):
http_request_address = f"http://{self.request_address}/v1/completions"
@ -1801,7 +1803,7 @@ class MoRIIOConnectorWorker:
self.layer_name_to_remote_kv_cache_metadata[expected_engine_id] = (
msgpack.loads(buf)
)
self.remote_moriio_metadata[expected_engine_id]=metadata
setup_agent_time = time.perf_counter()
logger.debug(
"MoRIIO handshake: add agent took: %s",
@ -2225,6 +2227,7 @@ class MoRIIOConnectorWorker:
layer_name: str,
local_block_ids: list[int],
remote_block_ids: list[int],
remote_moriio_meta: MoRIIOAgentMetadata,
) -> tuple[list[int], list[int], list[int]]:
"""Compute transfer offsets for block data.
@ -2232,7 +2235,7 @@ class MoRIIOConnectorWorker:
layer_name: Name of the layer to transfer
local_block_ids: IDs of local blocks
remote_block_ids: IDs of remote blocks
remote_moriio_meta: Metadata of the remote MoRIIO agent
Returns:
Tuple of (local_offsets, remote_offsets, transfer_sizes)
"""
@ -2246,8 +2249,9 @@ class MoRIIOConnectorWorker:
block_stride = stride[0]
else:
_, blknum, blksize, hn, hs = self.kv_cache_shape
ktov_stride = stride[0]
local_ktov_stride = stride[0]
block_stride = stride[1]
remote_ktov_stride = block_stride*remote_moriio_meta.num_blocks
transfer_size_byte = blksize * hn * hs * sz
per_block = 1 if is_mla else 2
@ -2265,8 +2269,11 @@ class MoRIIOConnectorWorker:
w += 1
if not is_mla:
# V
offset_local[w] = sz * (1 * ktov_stride + lb * block_stride)
offset_remote[w] = sz * (1 * ktov_stride + rb * block_stride)
# Handle num_block variations originating from PD (different kv strides)
# TODO: address block_sz differences in heterogeneous TP scenarios
# In MLA, we don't need to consider these two cases.
offset_local[w] = sz * (1 * local_ktov_stride + lb * block_stride)
offset_remote[w] = sz * (1 * remote_ktov_stride + rb * block_stride)
w += 1
merged_l, merged_r, merged_s = self.merge_contiguous_blocks(
@ -2287,11 +2294,11 @@ class MoRIIOConnectorWorker:
return
dp0_engine_id = self.get_engine_name_with_dp(dst_engine_id, 0)
sessions = self._get_built_session(dp0_engine_id)
sessions, remote_moriio_meta = self._get_built_session(dp0_engine_id)
first_layer = list(self.layer_name_to_local_kv_cache_metadata.keys())[0]
offs = self._compute_block_transfer_offsets(
first_layer, local_block_ids, remote_block_ids
first_layer, local_block_ids, remote_block_ids, remote_moriio_meta
)
for layer_name in self.layer_name_to_local_kv_cache_metadata: