mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-03 09:17:03 +08:00
Fix the issue of num_block inconsistency in non-MLA scenarios
Signed-off-by: inkcherry <mingzhi.liu@amd.com>
This commit is contained in:
parent
8a2c136c8f
commit
532d8a7453
@ -138,6 +138,19 @@ class ROLE(Enum):
|
||||
CONSUMER = "consumer"
|
||||
NOTINIT = "notinit"
|
||||
|
||||
class MoRIIOAgentMetadata(
|
||||
msgspec.Struct,
|
||||
omit_defaults=True, # type: ignore[call-arg]
|
||||
# required for @cached_property.d
|
||||
dict=True,
|
||||
):
|
||||
engine_id: str
|
||||
agent_metadata: bytes
|
||||
kv_caches_base_addr: list[int]
|
||||
num_blocks: int
|
||||
block_len: int
|
||||
attn_backend_name: str
|
||||
|
||||
|
||||
class RoleManager:
|
||||
"""Manages role state across the connector."""
|
||||
@ -424,10 +437,10 @@ class MoRIIOWriter:
|
||||
)
|
||||
|
||||
# Get or create sessions
|
||||
sessions = self.worker._get_built_session(task.dst_engine_id)
|
||||
sessions, remote_moriio_meta = self.worker._get_built_session(task.dst_engine_id)
|
||||
|
||||
# Prepare transfer plan
|
||||
plan = self._prepare_transfer_plan(task, request_info)
|
||||
plan = self._prepare_transfer_plan(task, request_info, remote_moriio_meta)
|
||||
|
||||
# Execute transfer
|
||||
self._do_layer_write(plan, sessions)
|
||||
@ -436,7 +449,7 @@ class MoRIIOWriter:
|
||||
self._finalize_if_complete(task, request_info)
|
||||
|
||||
def _prepare_transfer_plan(
|
||||
self, task: WriteTask, request_info: RemoteAllocInfo
|
||||
self, task: WriteTask, request_info: RemoteAllocInfo, remote_moriio_meta: MoRIIOAgentMetadata
|
||||
) -> LayerTransferPlan:
|
||||
"""Prepare the transfer plan for a layer.
|
||||
|
||||
@ -450,7 +463,7 @@ class MoRIIOWriter:
|
||||
# Compute offsets if not cached
|
||||
if request_info.transfer_offset is None:
|
||||
offsets = self.worker._compute_block_transfer_offsets(
|
||||
task.layer_name, task.local_block_ids, request_info.block_ids
|
||||
task.layer_name, task.local_block_ids, request_info.block_ids, remote_moriio_meta
|
||||
)
|
||||
request_info.transfer_offset = offsets
|
||||
|
||||
@ -814,18 +827,6 @@ class MoRIIOWrapper:
|
||||
self.paths.clear()
|
||||
|
||||
|
||||
class MoRIIOAgentMetadata(
|
||||
msgspec.Struct,
|
||||
omit_defaults=True, # type: ignore[call-arg]
|
||||
# required for @cached_property.d
|
||||
dict=True,
|
||||
):
|
||||
engine_id: str
|
||||
agent_metadata: bytes
|
||||
kv_caches_base_addr: list[int]
|
||||
num_blocks: int
|
||||
block_len: int
|
||||
attn_backend_name: str
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -1461,6 +1462,7 @@ class MoRIIOConnectorWorker:
|
||||
self.layer_name_to_remote_kv_cache_metadata: dict[str, dict[str, list[Any]]] = (
|
||||
dict()
|
||||
)
|
||||
self.remote_moriio_metadata: dict[EngineId, MoRIIOAgentMetadata] = {}
|
||||
self.slot_size_bytes = 0
|
||||
|
||||
self.load_ready_flag: dict[str, bool] = {}
|
||||
@ -1585,10 +1587,10 @@ class MoRIIOConnectorWorker:
|
||||
if remote_engine_id not in self.built_write_session:
|
||||
cur_remote_engine_sessions = []
|
||||
for ln, local_meta in self.layer_name_to_local_kv_cache_metadata.items():
|
||||
unpcaked_local_memory_meta = (
|
||||
unpacked_local_memory_meta = (
|
||||
self.moriio_wrapper.get_unpack_memory_metadata(local_meta[0])
|
||||
)
|
||||
unpcaked_remote_memory_meta = (
|
||||
unpacked_remote_memory_meta = (
|
||||
self.moriio_wrapper.get_unpack_memory_metadata(
|
||||
self.layer_name_to_remote_kv_cache_metadata[remote_engine_id][
|
||||
ln
|
||||
@ -1597,11 +1599,11 @@ class MoRIIOConnectorWorker:
|
||||
)
|
||||
cur_remote_engine_sessions.append(
|
||||
self.moriio_wrapper.build_session(
|
||||
unpcaked_local_memory_meta, unpcaked_remote_memory_meta
|
||||
unpacked_local_memory_meta, unpacked_remote_memory_meta
|
||||
)
|
||||
)
|
||||
self.built_write_session[remote_engine_id] = cur_remote_engine_sessions
|
||||
return self.built_write_session[remote_engine_id]
|
||||
return self.built_write_session[remote_engine_id], self.remote_moriio_metadata[remote_engine_id]
|
||||
|
||||
def _ping(self, zmq_context):
|
||||
http_request_address = f"http://{self.request_address}/v1/completions"
|
||||
@ -1801,7 +1803,7 @@ class MoRIIOConnectorWorker:
|
||||
self.layer_name_to_remote_kv_cache_metadata[expected_engine_id] = (
|
||||
msgpack.loads(buf)
|
||||
)
|
||||
|
||||
self.remote_moriio_metadata[expected_engine_id]=metadata
|
||||
setup_agent_time = time.perf_counter()
|
||||
logger.debug(
|
||||
"MoRIIO handshake: add agent took: %s",
|
||||
@ -2225,6 +2227,7 @@ class MoRIIOConnectorWorker:
|
||||
layer_name: str,
|
||||
local_block_ids: list[int],
|
||||
remote_block_ids: list[int],
|
||||
remote_moriio_meta: MoRIIOAgentMetadata,
|
||||
) -> tuple[list[int], list[int], list[int]]:
|
||||
"""Compute transfer offsets for block data.
|
||||
|
||||
@ -2232,7 +2235,7 @@ class MoRIIOConnectorWorker:
|
||||
layer_name: Name of the layer to transfer
|
||||
local_block_ids: IDs of local blocks
|
||||
remote_block_ids: IDs of remote blocks
|
||||
|
||||
remote_moriio_meta: Metadata of the remote MoRIIO agent
|
||||
Returns:
|
||||
Tuple of (local_offsets, remote_offsets, transfer_sizes)
|
||||
"""
|
||||
@ -2246,8 +2249,9 @@ class MoRIIOConnectorWorker:
|
||||
block_stride = stride[0]
|
||||
else:
|
||||
_, blknum, blksize, hn, hs = self.kv_cache_shape
|
||||
ktov_stride = stride[0]
|
||||
local_ktov_stride = stride[0]
|
||||
block_stride = stride[1]
|
||||
remote_ktov_stride = block_stride*remote_moriio_meta.num_blocks
|
||||
|
||||
transfer_size_byte = blksize * hn * hs * sz
|
||||
per_block = 1 if is_mla else 2
|
||||
@ -2265,8 +2269,11 @@ class MoRIIOConnectorWorker:
|
||||
w += 1
|
||||
if not is_mla:
|
||||
# V
|
||||
offset_local[w] = sz * (1 * ktov_stride + lb * block_stride)
|
||||
offset_remote[w] = sz * (1 * ktov_stride + rb * block_stride)
|
||||
# Handle num_block variations originating from PD (different kv strides)
|
||||
# TODO: address block_sz differences in heterogeneous TP scenarios
|
||||
# In MLA, we don't need to consider these two cases.
|
||||
offset_local[w] = sz * (1 * local_ktov_stride + lb * block_stride)
|
||||
offset_remote[w] = sz * (1 * remote_ktov_stride + rb * block_stride)
|
||||
w += 1
|
||||
|
||||
merged_l, merged_r, merged_s = self.merge_contiguous_blocks(
|
||||
@ -2287,11 +2294,11 @@ class MoRIIOConnectorWorker:
|
||||
return
|
||||
|
||||
dp0_engine_id = self.get_engine_name_with_dp(dst_engine_id, 0)
|
||||
sessions = self._get_built_session(dp0_engine_id)
|
||||
sessions, remote_moriio_meta = self._get_built_session(dp0_engine_id)
|
||||
|
||||
first_layer = list(self.layer_name_to_local_kv_cache_metadata.keys())[0]
|
||||
offs = self._compute_block_transfer_offsets(
|
||||
first_layer, local_block_ids, remote_block_ids
|
||||
first_layer, local_block_ids, remote_block_ids, remote_moriio_meta
|
||||
)
|
||||
|
||||
for layer_name in self.layer_name_to_local_kv_cache_metadata:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user