diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py index 7c566bc8c30ba..e0e60683960ae 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py @@ -138,6 +138,7 @@ class ROLE(Enum): CONSUMER = "consumer" NOTINIT = "notinit" + class MoRIIOAgentMetadata( msgspec.Struct, omit_defaults=True, # type: ignore[call-arg] @@ -437,7 +438,9 @@ class MoRIIOWriter: ) # Get or create sessions - sessions, remote_moriio_meta = self.worker._get_built_session(task.dst_engine_id) + sessions, remote_moriio_meta = self.worker._get_built_session( + task.dst_engine_id + ) # Prepare transfer plan plan = self._prepare_transfer_plan(task, request_info, remote_moriio_meta) @@ -449,7 +452,10 @@ class MoRIIOWriter: self._finalize_if_complete(task, request_info) def _prepare_transfer_plan( - self, task: WriteTask, request_info: RemoteAllocInfo, remote_moriio_meta: MoRIIOAgentMetadata + self, + task: WriteTask, + request_info: RemoteAllocInfo, + remote_moriio_meta: MoRIIOAgentMetadata, ) -> LayerTransferPlan: """Prepare the transfer plan for a layer. @@ -463,7 +469,10 @@ class MoRIIOWriter: # Compute offsets if not cached if request_info.transfer_offset is None: offsets = self.worker._compute_block_transfer_offsets( - task.layer_name, task.local_block_ids, request_info.block_ids, remote_moriio_meta + task.layer_name, + task.local_block_ids, + request_info.block_ids, + remote_moriio_meta, ) request_info.transfer_offset = offsets @@ -827,8 +836,6 @@ class MoRIIOWrapper: self.paths.clear() - - @dataclass class ReqMeta: """Metadata for a single request.""" @@ -1603,7 +1610,9 @@ class MoRIIOConnectorWorker: ) ) self.built_write_session[remote_engine_id] = cur_remote_engine_sessions - return self.built_write_session[remote_engine_id], self.remote_moriio_metadata[remote_engine_id] + return self.built_write_session[remote_engine_id], self.remote_moriio_metadata[ + remote_engine_id + ] def _ping(self, zmq_context): http_request_address = f"http://{self.request_address}/v1/completions" @@ -1803,7 +1812,7 @@ class MoRIIOConnectorWorker: self.layer_name_to_remote_kv_cache_metadata[expected_engine_id] = ( msgpack.loads(buf) ) - self.remote_moriio_metadata[expected_engine_id]=metadata + self.remote_moriio_metadata[expected_engine_id] = metadata setup_agent_time = time.perf_counter() logger.debug( "MoRIIO handshake: add agent took: %s", @@ -2251,7 +2260,7 @@ class MoRIIOConnectorWorker: _, blknum, blksize, hn, hs = self.kv_cache_shape local_ktov_stride = stride[0] block_stride = stride[1] - remote_ktov_stride = block_stride*remote_moriio_meta.num_blocks + remote_ktov_stride = block_stride * remote_moriio_meta.num_blocks transfer_size_byte = blksize * hn * hs * sz per_block = 1 if is_mla else 2