mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-03 09:17:03 +08:00
format
Signed-off-by: inkcherry <mingzhi.liu@amd.com>
This commit is contained in:
parent
532d8a7453
commit
e52ba22429
@ -138,6 +138,7 @@ class ROLE(Enum):
|
||||
CONSUMER = "consumer"
|
||||
NOTINIT = "notinit"
|
||||
|
||||
|
||||
class MoRIIOAgentMetadata(
|
||||
msgspec.Struct,
|
||||
omit_defaults=True, # type: ignore[call-arg]
|
||||
@ -437,7 +438,9 @@ class MoRIIOWriter:
|
||||
)
|
||||
|
||||
# Get or create sessions
|
||||
sessions, remote_moriio_meta = self.worker._get_built_session(task.dst_engine_id)
|
||||
sessions, remote_moriio_meta = self.worker._get_built_session(
|
||||
task.dst_engine_id
|
||||
)
|
||||
|
||||
# Prepare transfer plan
|
||||
plan = self._prepare_transfer_plan(task, request_info, remote_moriio_meta)
|
||||
@ -449,7 +452,10 @@ class MoRIIOWriter:
|
||||
self._finalize_if_complete(task, request_info)
|
||||
|
||||
def _prepare_transfer_plan(
|
||||
self, task: WriteTask, request_info: RemoteAllocInfo, remote_moriio_meta: MoRIIOAgentMetadata
|
||||
self,
|
||||
task: WriteTask,
|
||||
request_info: RemoteAllocInfo,
|
||||
remote_moriio_meta: MoRIIOAgentMetadata,
|
||||
) -> LayerTransferPlan:
|
||||
"""Prepare the transfer plan for a layer.
|
||||
|
||||
@ -463,7 +469,10 @@ class MoRIIOWriter:
|
||||
# Compute offsets if not cached
|
||||
if request_info.transfer_offset is None:
|
||||
offsets = self.worker._compute_block_transfer_offsets(
|
||||
task.layer_name, task.local_block_ids, request_info.block_ids, remote_moriio_meta
|
||||
task.layer_name,
|
||||
task.local_block_ids,
|
||||
request_info.block_ids,
|
||||
remote_moriio_meta,
|
||||
)
|
||||
request_info.transfer_offset = offsets
|
||||
|
||||
@ -827,8 +836,6 @@ class MoRIIOWrapper:
|
||||
self.paths.clear()
|
||||
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
class ReqMeta:
|
||||
"""Metadata for a single request."""
|
||||
@ -1603,7 +1610,9 @@ class MoRIIOConnectorWorker:
|
||||
)
|
||||
)
|
||||
self.built_write_session[remote_engine_id] = cur_remote_engine_sessions
|
||||
return self.built_write_session[remote_engine_id], self.remote_moriio_metadata[remote_engine_id]
|
||||
return self.built_write_session[remote_engine_id], self.remote_moriio_metadata[
|
||||
remote_engine_id
|
||||
]
|
||||
|
||||
def _ping(self, zmq_context):
|
||||
http_request_address = f"http://{self.request_address}/v1/completions"
|
||||
@ -1803,7 +1812,7 @@ class MoRIIOConnectorWorker:
|
||||
self.layer_name_to_remote_kv_cache_metadata[expected_engine_id] = (
|
||||
msgpack.loads(buf)
|
||||
)
|
||||
self.remote_moriio_metadata[expected_engine_id]=metadata
|
||||
self.remote_moriio_metadata[expected_engine_id] = metadata
|
||||
setup_agent_time = time.perf_counter()
|
||||
logger.debug(
|
||||
"MoRIIO handshake: add agent took: %s",
|
||||
@ -2251,7 +2260,7 @@ class MoRIIOConnectorWorker:
|
||||
_, blknum, blksize, hn, hs = self.kv_cache_shape
|
||||
local_ktov_stride = stride[0]
|
||||
block_stride = stride[1]
|
||||
remote_ktov_stride = block_stride*remote_moriio_meta.num_blocks
|
||||
remote_ktov_stride = block_stride * remote_moriio_meta.num_blocks
|
||||
|
||||
transfer_size_byte = blksize * hn * hs * sz
|
||||
per_block = 1 if is_mla else 2
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user