From 532d8a7453415040f998c4e91fc87841ca1cff89 Mon Sep 17 00:00:00 2001 From: inkcherry Date: Tue, 2 Dec 2025 06:08:03 +0000 Subject: [PATCH] Fix the issue of num_block inconsistency in non-MLA scenarios Signed-off-by: inkcherry --- .../kv_connector/v1/moriio_connector.py | 61 +++++++++++-------- 1 file changed, 34 insertions(+), 27 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py index 0debaca514d17..7c566bc8c30ba 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py @@ -138,6 +138,19 @@ class ROLE(Enum): CONSUMER = "consumer" NOTINIT = "notinit" +class MoRIIOAgentMetadata( + msgspec.Struct, + omit_defaults=True, # type: ignore[call-arg] + # required for @cached_property.d + dict=True, +): + engine_id: str + agent_metadata: bytes + kv_caches_base_addr: list[int] + num_blocks: int + block_len: int + attn_backend_name: str + class RoleManager: """Manages role state across the connector.""" @@ -424,10 +437,10 @@ class MoRIIOWriter: ) # Get or create sessions - sessions = self.worker._get_built_session(task.dst_engine_id) + sessions, remote_moriio_meta = self.worker._get_built_session(task.dst_engine_id) # Prepare transfer plan - plan = self._prepare_transfer_plan(task, request_info) + plan = self._prepare_transfer_plan(task, request_info, remote_moriio_meta) # Execute transfer self._do_layer_write(plan, sessions) @@ -436,7 +449,7 @@ class MoRIIOWriter: self._finalize_if_complete(task, request_info) def _prepare_transfer_plan( - self, task: WriteTask, request_info: RemoteAllocInfo + self, task: WriteTask, request_info: RemoteAllocInfo, remote_moriio_meta: MoRIIOAgentMetadata ) -> LayerTransferPlan: """Prepare the transfer plan for a layer. @@ -450,7 +463,7 @@ class MoRIIOWriter: # Compute offsets if not cached if request_info.transfer_offset is None: offsets = self.worker._compute_block_transfer_offsets( - task.layer_name, task.local_block_ids, request_info.block_ids + task.layer_name, task.local_block_ids, request_info.block_ids, remote_moriio_meta ) request_info.transfer_offset = offsets @@ -814,18 +827,6 @@ class MoRIIOWrapper: self.paths.clear() -class MoRIIOAgentMetadata( - msgspec.Struct, - omit_defaults=True, # type: ignore[call-arg] - # required for @cached_property.d - dict=True, -): - engine_id: str - agent_metadata: bytes - kv_caches_base_addr: list[int] - num_blocks: int - block_len: int - attn_backend_name: str @dataclass @@ -1461,6 +1462,7 @@ class MoRIIOConnectorWorker: self.layer_name_to_remote_kv_cache_metadata: dict[str, dict[str, list[Any]]] = ( dict() ) + self.remote_moriio_metadata: dict[EngineId, MoRIIOAgentMetadata] = {} self.slot_size_bytes = 0 self.load_ready_flag: dict[str, bool] = {} @@ -1585,10 +1587,10 @@ class MoRIIOConnectorWorker: if remote_engine_id not in self.built_write_session: cur_remote_engine_sessions = [] for ln, local_meta in self.layer_name_to_local_kv_cache_metadata.items(): - unpcaked_local_memory_meta = ( + unpacked_local_memory_meta = ( self.moriio_wrapper.get_unpack_memory_metadata(local_meta[0]) ) - unpcaked_remote_memory_meta = ( + unpacked_remote_memory_meta = ( self.moriio_wrapper.get_unpack_memory_metadata( self.layer_name_to_remote_kv_cache_metadata[remote_engine_id][ ln @@ -1597,11 +1599,11 @@ class MoRIIOConnectorWorker: ) cur_remote_engine_sessions.append( self.moriio_wrapper.build_session( - unpcaked_local_memory_meta, unpcaked_remote_memory_meta + unpacked_local_memory_meta, unpacked_remote_memory_meta ) ) self.built_write_session[remote_engine_id] = cur_remote_engine_sessions - return self.built_write_session[remote_engine_id] + return self.built_write_session[remote_engine_id], self.remote_moriio_metadata[remote_engine_id] def _ping(self, zmq_context): http_request_address = f"http://{self.request_address}/v1/completions" @@ -1801,7 +1803,7 @@ class MoRIIOConnectorWorker: self.layer_name_to_remote_kv_cache_metadata[expected_engine_id] = ( msgpack.loads(buf) ) - + self.remote_moriio_metadata[expected_engine_id]=metadata setup_agent_time = time.perf_counter() logger.debug( "MoRIIO handshake: add agent took: %s", @@ -2225,6 +2227,7 @@ class MoRIIOConnectorWorker: layer_name: str, local_block_ids: list[int], remote_block_ids: list[int], + remote_moriio_meta: MoRIIOAgentMetadata, ) -> tuple[list[int], list[int], list[int]]: """Compute transfer offsets for block data. @@ -2232,7 +2235,7 @@ class MoRIIOConnectorWorker: layer_name: Name of the layer to transfer local_block_ids: IDs of local blocks remote_block_ids: IDs of remote blocks - + remote_moriio_meta: Metadata of the remote MoRIIO agent Returns: Tuple of (local_offsets, remote_offsets, transfer_sizes) """ @@ -2246,8 +2249,9 @@ class MoRIIOConnectorWorker: block_stride = stride[0] else: _, blknum, blksize, hn, hs = self.kv_cache_shape - ktov_stride = stride[0] + local_ktov_stride = stride[0] block_stride = stride[1] + remote_ktov_stride = block_stride*remote_moriio_meta.num_blocks transfer_size_byte = blksize * hn * hs * sz per_block = 1 if is_mla else 2 @@ -2265,8 +2269,11 @@ class MoRIIOConnectorWorker: w += 1 if not is_mla: # V - offset_local[w] = sz * (1 * ktov_stride + lb * block_stride) - offset_remote[w] = sz * (1 * ktov_stride + rb * block_stride) + # Handle num_block variations originating from PD (different kv strides) + # TODO: address block_sz differences in heterogeneous TP scenarios + # In MLA, we don't need to consider these two cases. + offset_local[w] = sz * (1 * local_ktov_stride + lb * block_stride) + offset_remote[w] = sz * (1 * remote_ktov_stride + rb * block_stride) w += 1 merged_l, merged_r, merged_s = self.merge_contiguous_blocks( @@ -2287,11 +2294,11 @@ class MoRIIOConnectorWorker: return dp0_engine_id = self.get_engine_name_with_dp(dst_engine_id, 0) - sessions = self._get_built_session(dp0_engine_id) + sessions, remote_moriio_meta = self._get_built_session(dp0_engine_id) first_layer = list(self.layer_name_to_local_kv_cache_metadata.keys())[0] offs = self._compute_block_transfer_offsets( - first_layer, local_block_ids, remote_block_ids + first_layer, local_block_ids, remote_block_ids, remote_moriio_meta ) for layer_name in self.layer_name_to_local_kv_cache_metadata: