mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-17 09:22:23 +08:00
Fix num_blocks inconsistency between local and remote engines in non-MLA scenarios
Signed-off-by: inkcherry <mingzhi.liu@amd.com>
This commit is contained in:
parent
8a2c136c8f
commit
532d8a7453
@ -138,6 +138,19 @@ class ROLE(Enum):
|
|||||||
CONSUMER = "consumer"
|
CONSUMER = "consumer"
|
||||||
NOTINIT = "notinit"
|
NOTINIT = "notinit"
|
||||||
|
|
||||||
|
class MoRIIOAgentMetadata(
|
||||||
|
msgspec.Struct,
|
||||||
|
omit_defaults=True, # type: ignore[call-arg]
|
||||||
|
# required for @cached_property.
|
||||||
|
dict=True,
|
||||||
|
):
|
||||||
|
engine_id: str
|
||||||
|
agent_metadata: bytes
|
||||||
|
kv_caches_base_addr: list[int]
|
||||||
|
num_blocks: int
|
||||||
|
block_len: int
|
||||||
|
attn_backend_name: str
|
||||||
|
|
||||||
|
|
||||||
class RoleManager:
|
class RoleManager:
|
||||||
"""Manages role state across the connector."""
|
"""Manages role state across the connector."""
|
||||||
@ -424,10 +437,10 @@ class MoRIIOWriter:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Get or create sessions
|
# Get or create sessions
|
||||||
sessions = self.worker._get_built_session(task.dst_engine_id)
|
sessions, remote_moriio_meta = self.worker._get_built_session(task.dst_engine_id)
|
||||||
|
|
||||||
# Prepare transfer plan
|
# Prepare transfer plan
|
||||||
plan = self._prepare_transfer_plan(task, request_info)
|
plan = self._prepare_transfer_plan(task, request_info, remote_moriio_meta)
|
||||||
|
|
||||||
# Execute transfer
|
# Execute transfer
|
||||||
self._do_layer_write(plan, sessions)
|
self._do_layer_write(plan, sessions)
|
||||||
@ -436,7 +449,7 @@ class MoRIIOWriter:
|
|||||||
self._finalize_if_complete(task, request_info)
|
self._finalize_if_complete(task, request_info)
|
||||||
|
|
||||||
def _prepare_transfer_plan(
|
def _prepare_transfer_plan(
|
||||||
self, task: WriteTask, request_info: RemoteAllocInfo
|
self, task: WriteTask, request_info: RemoteAllocInfo, remote_moriio_meta: MoRIIOAgentMetadata
|
||||||
) -> LayerTransferPlan:
|
) -> LayerTransferPlan:
|
||||||
"""Prepare the transfer plan for a layer.
|
"""Prepare the transfer plan for a layer.
|
||||||
|
|
||||||
@ -450,7 +463,7 @@ class MoRIIOWriter:
|
|||||||
# Compute offsets if not cached
|
# Compute offsets if not cached
|
||||||
if request_info.transfer_offset is None:
|
if request_info.transfer_offset is None:
|
||||||
offsets = self.worker._compute_block_transfer_offsets(
|
offsets = self.worker._compute_block_transfer_offsets(
|
||||||
task.layer_name, task.local_block_ids, request_info.block_ids
|
task.layer_name, task.local_block_ids, request_info.block_ids, remote_moriio_meta
|
||||||
)
|
)
|
||||||
request_info.transfer_offset = offsets
|
request_info.transfer_offset = offsets
|
||||||
|
|
||||||
@ -814,18 +827,6 @@ class MoRIIOWrapper:
|
|||||||
self.paths.clear()
|
self.paths.clear()
|
||||||
|
|
||||||
|
|
||||||
class MoRIIOAgentMetadata(
|
|
||||||
msgspec.Struct,
|
|
||||||
omit_defaults=True, # type: ignore[call-arg]
|
|
||||||
# required for @cached_property.
|
|
||||||
dict=True,
|
|
||||||
):
|
|
||||||
engine_id: str
|
|
||||||
agent_metadata: bytes
|
|
||||||
kv_caches_base_addr: list[int]
|
|
||||||
num_blocks: int
|
|
||||||
block_len: int
|
|
||||||
attn_backend_name: str
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -1461,6 +1462,7 @@ class MoRIIOConnectorWorker:
|
|||||||
self.layer_name_to_remote_kv_cache_metadata: dict[str, dict[str, list[Any]]] = (
|
self.layer_name_to_remote_kv_cache_metadata: dict[str, dict[str, list[Any]]] = (
|
||||||
dict()
|
dict()
|
||||||
)
|
)
|
||||||
|
self.remote_moriio_metadata: dict[EngineId, MoRIIOAgentMetadata] = {}
|
||||||
self.slot_size_bytes = 0
|
self.slot_size_bytes = 0
|
||||||
|
|
||||||
self.load_ready_flag: dict[str, bool] = {}
|
self.load_ready_flag: dict[str, bool] = {}
|
||||||
@ -1585,10 +1587,10 @@ class MoRIIOConnectorWorker:
|
|||||||
if remote_engine_id not in self.built_write_session:
|
if remote_engine_id not in self.built_write_session:
|
||||||
cur_remote_engine_sessions = []
|
cur_remote_engine_sessions = []
|
||||||
for ln, local_meta in self.layer_name_to_local_kv_cache_metadata.items():
|
for ln, local_meta in self.layer_name_to_local_kv_cache_metadata.items():
|
||||||
unpcaked_local_memory_meta = (
|
unpacked_local_memory_meta = (
|
||||||
self.moriio_wrapper.get_unpack_memory_metadata(local_meta[0])
|
self.moriio_wrapper.get_unpack_memory_metadata(local_meta[0])
|
||||||
)
|
)
|
||||||
unpcaked_remote_memory_meta = (
|
unpacked_remote_memory_meta = (
|
||||||
self.moriio_wrapper.get_unpack_memory_metadata(
|
self.moriio_wrapper.get_unpack_memory_metadata(
|
||||||
self.layer_name_to_remote_kv_cache_metadata[remote_engine_id][
|
self.layer_name_to_remote_kv_cache_metadata[remote_engine_id][
|
||||||
ln
|
ln
|
||||||
@ -1597,11 +1599,11 @@ class MoRIIOConnectorWorker:
|
|||||||
)
|
)
|
||||||
cur_remote_engine_sessions.append(
|
cur_remote_engine_sessions.append(
|
||||||
self.moriio_wrapper.build_session(
|
self.moriio_wrapper.build_session(
|
||||||
unpcaked_local_memory_meta, unpcaked_remote_memory_meta
|
unpacked_local_memory_meta, unpacked_remote_memory_meta
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
self.built_write_session[remote_engine_id] = cur_remote_engine_sessions
|
self.built_write_session[remote_engine_id] = cur_remote_engine_sessions
|
||||||
return self.built_write_session[remote_engine_id]
|
return self.built_write_session[remote_engine_id], self.remote_moriio_metadata[remote_engine_id]
|
||||||
|
|
||||||
def _ping(self, zmq_context):
|
def _ping(self, zmq_context):
|
||||||
http_request_address = f"http://{self.request_address}/v1/completions"
|
http_request_address = f"http://{self.request_address}/v1/completions"
|
||||||
@ -1801,7 +1803,7 @@ class MoRIIOConnectorWorker:
|
|||||||
self.layer_name_to_remote_kv_cache_metadata[expected_engine_id] = (
|
self.layer_name_to_remote_kv_cache_metadata[expected_engine_id] = (
|
||||||
msgpack.loads(buf)
|
msgpack.loads(buf)
|
||||||
)
|
)
|
||||||
|
self.remote_moriio_metadata[expected_engine_id]=metadata
|
||||||
setup_agent_time = time.perf_counter()
|
setup_agent_time = time.perf_counter()
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"MoRIIO handshake: add agent took: %s",
|
"MoRIIO handshake: add agent took: %s",
|
||||||
@ -2225,6 +2227,7 @@ class MoRIIOConnectorWorker:
|
|||||||
layer_name: str,
|
layer_name: str,
|
||||||
local_block_ids: list[int],
|
local_block_ids: list[int],
|
||||||
remote_block_ids: list[int],
|
remote_block_ids: list[int],
|
||||||
|
remote_moriio_meta: MoRIIOAgentMetadata,
|
||||||
) -> tuple[list[int], list[int], list[int]]:
|
) -> tuple[list[int], list[int], list[int]]:
|
||||||
"""Compute transfer offsets for block data.
|
"""Compute transfer offsets for block data.
|
||||||
|
|
||||||
@ -2232,7 +2235,7 @@ class MoRIIOConnectorWorker:
|
|||||||
layer_name: Name of the layer to transfer
|
layer_name: Name of the layer to transfer
|
||||||
local_block_ids: IDs of local blocks
|
local_block_ids: IDs of local blocks
|
||||||
remote_block_ids: IDs of remote blocks
|
remote_block_ids: IDs of remote blocks
|
||||||
|
remote_moriio_meta: Metadata of the remote MoRIIO agent
|
||||||
Returns:
|
Returns:
|
||||||
Tuple of (local_offsets, remote_offsets, transfer_sizes)
|
Tuple of (local_offsets, remote_offsets, transfer_sizes)
|
||||||
"""
|
"""
|
||||||
@ -2246,8 +2249,9 @@ class MoRIIOConnectorWorker:
|
|||||||
block_stride = stride[0]
|
block_stride = stride[0]
|
||||||
else:
|
else:
|
||||||
_, blknum, blksize, hn, hs = self.kv_cache_shape
|
_, blknum, blksize, hn, hs = self.kv_cache_shape
|
||||||
ktov_stride = stride[0]
|
local_ktov_stride = stride[0]
|
||||||
block_stride = stride[1]
|
block_stride = stride[1]
|
||||||
|
remote_ktov_stride = block_stride*remote_moriio_meta.num_blocks
|
||||||
|
|
||||||
transfer_size_byte = blksize * hn * hs * sz
|
transfer_size_byte = blksize * hn * hs * sz
|
||||||
per_block = 1 if is_mla else 2
|
per_block = 1 if is_mla else 2
|
||||||
@ -2265,8 +2269,11 @@ class MoRIIOConnectorWorker:
|
|||||||
w += 1
|
w += 1
|
||||||
if not is_mla:
|
if not is_mla:
|
||||||
# V
|
# V
|
||||||
offset_local[w] = sz * (1 * ktov_stride + lb * block_stride)
|
# The local and remote engines (prefill/decode) may have different num_blocks, so the K-to-V stride is computed separately for each side
|
||||||
offset_remote[w] = sz * (1 * ktov_stride + rb * block_stride)
|
# TODO: address block_sz differences in heterogeneous TP scenarios
|
||||||
|
# MLA writes a single entry per block (no separate V offset), so neither case applies to it.
|
||||||
|
offset_local[w] = sz * (1 * local_ktov_stride + lb * block_stride)
|
||||||
|
offset_remote[w] = sz * (1 * remote_ktov_stride + rb * block_stride)
|
||||||
w += 1
|
w += 1
|
||||||
|
|
||||||
merged_l, merged_r, merged_s = self.merge_contiguous_blocks(
|
merged_l, merged_r, merged_s = self.merge_contiguous_blocks(
|
||||||
@ -2287,11 +2294,11 @@ class MoRIIOConnectorWorker:
|
|||||||
return
|
return
|
||||||
|
|
||||||
dp0_engine_id = self.get_engine_name_with_dp(dst_engine_id, 0)
|
dp0_engine_id = self.get_engine_name_with_dp(dst_engine_id, 0)
|
||||||
sessions = self._get_built_session(dp0_engine_id)
|
sessions, remote_moriio_meta = self._get_built_session(dp0_engine_id)
|
||||||
|
|
||||||
first_layer = list(self.layer_name_to_local_kv_cache_metadata.keys())[0]
|
first_layer = list(self.layer_name_to_local_kv_cache_metadata.keys())[0]
|
||||||
offs = self._compute_block_transfer_offsets(
|
offs = self._compute_block_transfer_offsets(
|
||||||
first_layer, local_block_ids, remote_block_ids
|
first_layer, local_block_ids, remote_block_ids, remote_moriio_meta
|
||||||
)
|
)
|
||||||
|
|
||||||
for layer_name in self.layer_name_to_local_kv_cache_metadata:
|
for layer_name in self.layer_name_to_local_kv_cache_metadata:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user