Fix remote KV-cache transfer offsets when num_blocks differs between the local and remote engines in non-MLA scenarios

Signed-off-by: inkcherry <mingzhi.liu@amd.com>
This commit is contained in:
inkcherry 2025-12-02 06:08:03 +00:00
parent 8a2c136c8f
commit 532d8a7453

View File

@ -138,6 +138,19 @@ class ROLE(Enum):
CONSUMER = "consumer" CONSUMER = "consumer"
NOTINIT = "notinit" NOTINIT = "notinit"
class MoRIIOAgentMetadata(
    msgspec.Struct,
    omit_defaults=True, # type: ignore[call-arg]
    # dict=True gives instances a __dict__, which @cached_property requires.
    dict=True,
):
    """Metadata describing one MoRIIO agent, exchanged between engines.

    The connector worker stores the remote side's instance per engine id
    during the MoRIIO handshake and later hands it to the block-offset
    computation, which reads ``num_blocks`` to derive the remote K-to-V
    stride for non-MLA KV caches.

    Field order is part of the struct's positional/encoded layout; do not
    reorder fields.
    """

    # Engine this metadata belongs to (used as the lookup key by the worker).
    engine_id: str
    # Opaque serialized agent descriptor.
    # NOTE(review): presumably produced by the MoRIIO wrapper -- confirm.
    agent_metadata: bytes
    # NOTE(review): presumably one base address per registered KV-cache
    # region -- confirm against the code that fills this in.
    kv_caches_base_addr: list[int]
    # Number of KV-cache blocks on the described engine; may differ from the
    # local engine's count, which changes the remote K-to-V stride.
    num_blocks: int
    # NOTE(review): presumably the size of one block; unit (bytes vs.
    # elements) is not visible here -- confirm.
    block_len: int
    # Name of the attention backend used by the described engine.
    attn_backend_name: str
class RoleManager: class RoleManager:
"""Manages role state across the connector.""" """Manages role state across the connector."""
@ -424,10 +437,10 @@ class MoRIIOWriter:
) )
# Get or create sessions # Get or create sessions
sessions = self.worker._get_built_session(task.dst_engine_id) sessions, remote_moriio_meta = self.worker._get_built_session(task.dst_engine_id)
# Prepare transfer plan # Prepare transfer plan
plan = self._prepare_transfer_plan(task, request_info) plan = self._prepare_transfer_plan(task, request_info, remote_moriio_meta)
# Execute transfer # Execute transfer
self._do_layer_write(plan, sessions) self._do_layer_write(plan, sessions)
@ -436,7 +449,7 @@ class MoRIIOWriter:
self._finalize_if_complete(task, request_info) self._finalize_if_complete(task, request_info)
def _prepare_transfer_plan( def _prepare_transfer_plan(
self, task: WriteTask, request_info: RemoteAllocInfo self, task: WriteTask, request_info: RemoteAllocInfo, remote_moriio_meta: MoRIIOAgentMetadata
) -> LayerTransferPlan: ) -> LayerTransferPlan:
"""Prepare the transfer plan for a layer. """Prepare the transfer plan for a layer.
@ -450,7 +463,7 @@ class MoRIIOWriter:
# Compute offsets if not cached # Compute offsets if not cached
if request_info.transfer_offset is None: if request_info.transfer_offset is None:
offsets = self.worker._compute_block_transfer_offsets( offsets = self.worker._compute_block_transfer_offsets(
task.layer_name, task.local_block_ids, request_info.block_ids task.layer_name, task.local_block_ids, request_info.block_ids, remote_moriio_meta
) )
request_info.transfer_offset = offsets request_info.transfer_offset = offsets
@ -814,18 +827,6 @@ class MoRIIOWrapper:
self.paths.clear() self.paths.clear()
class MoRIIOAgentMetadata(
    msgspec.Struct,
    omit_defaults=True, # type: ignore[call-arg]
    # dict=True gives instances a __dict__, which @cached_property requires.
    dict=True,
):
    """Metadata describing one MoRIIO agent, exchanged between engines.

    The connector worker stores the remote side's instance per engine id
    during the MoRIIO handshake and later hands it to the block-offset
    computation, which reads ``num_blocks`` to derive the remote K-to-V
    stride for non-MLA KV caches.

    Field order is part of the struct's positional/encoded layout; do not
    reorder fields.
    """

    # Engine this metadata belongs to (used as the lookup key by the worker).
    engine_id: str
    # Opaque serialized agent descriptor.
    # NOTE(review): presumably produced by the MoRIIO wrapper -- confirm.
    agent_metadata: bytes
    # NOTE(review): presumably one base address per registered KV-cache
    # region -- confirm against the code that fills this in.
    kv_caches_base_addr: list[int]
    # Number of KV-cache blocks on the described engine; may differ from the
    # local engine's count, which changes the remote K-to-V stride.
    num_blocks: int
    # NOTE(review): presumably the size of one block; unit (bytes vs.
    # elements) is not visible here -- confirm.
    block_len: int
    # Name of the attention backend used by the described engine.
    attn_backend_name: str
@dataclass @dataclass
@ -1461,6 +1462,7 @@ class MoRIIOConnectorWorker:
self.layer_name_to_remote_kv_cache_metadata: dict[str, dict[str, list[Any]]] = ( self.layer_name_to_remote_kv_cache_metadata: dict[str, dict[str, list[Any]]] = (
dict() dict()
) )
self.remote_moriio_metadata: dict[EngineId, MoRIIOAgentMetadata] = {}
self.slot_size_bytes = 0 self.slot_size_bytes = 0
self.load_ready_flag: dict[str, bool] = {} self.load_ready_flag: dict[str, bool] = {}
@ -1585,10 +1587,10 @@ class MoRIIOConnectorWorker:
if remote_engine_id not in self.built_write_session: if remote_engine_id not in self.built_write_session:
cur_remote_engine_sessions = [] cur_remote_engine_sessions = []
for ln, local_meta in self.layer_name_to_local_kv_cache_metadata.items(): for ln, local_meta in self.layer_name_to_local_kv_cache_metadata.items():
unpcaked_local_memory_meta = ( unpacked_local_memory_meta = (
self.moriio_wrapper.get_unpack_memory_metadata(local_meta[0]) self.moriio_wrapper.get_unpack_memory_metadata(local_meta[0])
) )
unpcaked_remote_memory_meta = ( unpacked_remote_memory_meta = (
self.moriio_wrapper.get_unpack_memory_metadata( self.moriio_wrapper.get_unpack_memory_metadata(
self.layer_name_to_remote_kv_cache_metadata[remote_engine_id][ self.layer_name_to_remote_kv_cache_metadata[remote_engine_id][
ln ln
@ -1597,11 +1599,11 @@ class MoRIIOConnectorWorker:
) )
cur_remote_engine_sessions.append( cur_remote_engine_sessions.append(
self.moriio_wrapper.build_session( self.moriio_wrapper.build_session(
unpcaked_local_memory_meta, unpcaked_remote_memory_meta unpacked_local_memory_meta, unpacked_remote_memory_meta
) )
) )
self.built_write_session[remote_engine_id] = cur_remote_engine_sessions self.built_write_session[remote_engine_id] = cur_remote_engine_sessions
return self.built_write_session[remote_engine_id] return self.built_write_session[remote_engine_id], self.remote_moriio_metadata[remote_engine_id]
def _ping(self, zmq_context): def _ping(self, zmq_context):
http_request_address = f"http://{self.request_address}/v1/completions" http_request_address = f"http://{self.request_address}/v1/completions"
@ -1801,7 +1803,7 @@ class MoRIIOConnectorWorker:
self.layer_name_to_remote_kv_cache_metadata[expected_engine_id] = ( self.layer_name_to_remote_kv_cache_metadata[expected_engine_id] = (
msgpack.loads(buf) msgpack.loads(buf)
) )
self.remote_moriio_metadata[expected_engine_id]=metadata
setup_agent_time = time.perf_counter() setup_agent_time = time.perf_counter()
logger.debug( logger.debug(
"MoRIIO handshake: add agent took: %s", "MoRIIO handshake: add agent took: %s",
@ -2225,6 +2227,7 @@ class MoRIIOConnectorWorker:
layer_name: str, layer_name: str,
local_block_ids: list[int], local_block_ids: list[int],
remote_block_ids: list[int], remote_block_ids: list[int],
remote_moriio_meta: MoRIIOAgentMetadata,
) -> tuple[list[int], list[int], list[int]]: ) -> tuple[list[int], list[int], list[int]]:
"""Compute transfer offsets for block data. """Compute transfer offsets for block data.
@ -2232,7 +2235,7 @@ class MoRIIOConnectorWorker:
layer_name: Name of the layer to transfer layer_name: Name of the layer to transfer
local_block_ids: IDs of local blocks local_block_ids: IDs of local blocks
remote_block_ids: IDs of remote blocks remote_block_ids: IDs of remote blocks
remote_moriio_meta: Metadata of the remote MoRIIO agent
Returns: Returns:
Tuple of (local_offsets, remote_offsets, transfer_sizes) Tuple of (local_offsets, remote_offsets, transfer_sizes)
""" """
@ -2246,8 +2249,9 @@ class MoRIIOConnectorWorker:
block_stride = stride[0] block_stride = stride[0]
else: else:
_, blknum, blksize, hn, hs = self.kv_cache_shape _, blknum, blksize, hn, hs = self.kv_cache_shape
ktov_stride = stride[0] local_ktov_stride = stride[0]
block_stride = stride[1] block_stride = stride[1]
remote_ktov_stride = block_stride*remote_moriio_meta.num_blocks
transfer_size_byte = blksize * hn * hs * sz transfer_size_byte = blksize * hn * hs * sz
per_block = 1 if is_mla else 2 per_block = 1 if is_mla else 2
@ -2265,8 +2269,11 @@ class MoRIIOConnectorWorker:
w += 1 w += 1
if not is_mla: if not is_mla:
# V # V
offset_local[w] = sz * (1 * ktov_stride + lb * block_stride) # Handle num_block variations originating from PD (different kv strides)
offset_remote[w] = sz * (1 * ktov_stride + rb * block_stride) # TODO: address block_sz differences in heterogeneous TP scenarios
# In MLA, we don't need to consider these two cases.
offset_local[w] = sz * (1 * local_ktov_stride + lb * block_stride)
offset_remote[w] = sz * (1 * remote_ktov_stride + rb * block_stride)
w += 1 w += 1
merged_l, merged_r, merged_s = self.merge_contiguous_blocks( merged_l, merged_r, merged_s = self.merge_contiguous_blocks(
@ -2287,11 +2294,11 @@ class MoRIIOConnectorWorker:
return return
dp0_engine_id = self.get_engine_name_with_dp(dst_engine_id, 0) dp0_engine_id = self.get_engine_name_with_dp(dst_engine_id, 0)
sessions = self._get_built_session(dp0_engine_id) sessions, remote_moriio_meta = self._get_built_session(dp0_engine_id)
first_layer = list(self.layer_name_to_local_kv_cache_metadata.keys())[0] first_layer = list(self.layer_name_to_local_kv_cache_metadata.keys())[0]
offs = self._compute_block_transfer_offsets( offs = self._compute_block_transfer_offsets(
first_layer, local_block_ids, remote_block_ids first_layer, local_block_ids, remote_block_ids, remote_moriio_meta
) )
for layer_name in self.layer_name_to_local_kv_cache_metadata: for layer_name in self.layer_name_to_local_kv_cache_metadata: