clean up
Signed-off-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
parent 7bb3861faf
commit 9f38fed93c
@@ -520,7 +520,9 @@ class NixlConnectorWorker:
Calculate the tensor parallel ratio between local and remote TP.
We can think of it as the number of local TP workers-per-remote TP
workers. Local workers will read from the same remote TP worker in
groups of size `tp_ratio`.
groups of size `tp_ratio`. If remote tp_size > local tp_size, the
ratio is flipped (remote_size/local_size) and the returned value is
negative.
"""
if remote_tp_size is None:
assert remote_engine_id is not None
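For readers skimming the diff, here is a minimal standalone sketch of the sign convention the new docstring describes (a toy `tp_ratio` helper with the sizes passed explicitly, not the connector's actual method):

```python
def tp_ratio(local_tp_size: int, remote_tp_size: int) -> int:
    """Positive: local workers per remote worker; negative when remote TP is larger."""
    if local_tp_size >= remote_tp_size:
        return local_tp_size // remote_tp_size
    # Remote tp_size > local tp_size: ratio is flipped and returned as negative.
    return -(remote_tp_size // local_tp_size)


assert tp_ratio(local_tp_size=4, remote_tp_size=2) == 2    # two D workers share one P worker
assert tp_ratio(local_tp_size=2, remote_tp_size=2) == 1    # homogeneous TP
assert tp_ratio(local_tp_size=2, remote_tp_size=8) == -4   # each D worker reads from 4 P workers
```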
@@ -539,7 +541,9 @@ class NixlConnectorWorker:
# P TP > D TP case, return the ratio as negative
return -remote_tp_size // self.tp_size

def is_kv_replicated(self, engine_id: Optional[EngineId] = None, tp_size: Optional[int] = None) -> bool:
def is_kv_replicated(
self, engine_id: Optional[EngineId] = None, tp_size: Optional[int] = None
) -> bool:
"""
Whether the KV cache is replicated across TP workers due to the
number of TP workers being greater than the number of KV heads.
@@ -549,11 +553,14 @@ class NixlConnectorWorker:
tp_size = self.remote_tp_size[engine_id]
return tp_size // self.total_num_kv_heads >= 1

def replicates_kv_cache(self, remote_engine_id: Optional[EngineId] = None, remote_tp_size: Optional[int] = None) -> bool:
def replicates_kv_cache(
self,
remote_engine_id: Optional[EngineId] = None,
remote_tp_size: Optional[int] = None,
) -> bool:
# MLA is always replicated as the hidden dim can't be split.
return (
self.is_mla
or self.is_kv_replicated(remote_engine_id, remote_tp_size)
return self.is_mla or self.is_kv_replicated(
remote_engine_id, remote_tp_size
)

def get_target_remote_ranks(
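A compact sketch of the two predicates reshaped here, with the relevant sizes passed explicitly (toy functions, not the `NixlConnectorWorker` methods themselves):

```python
def is_kv_replicated(tp_size: int, total_num_kv_heads: int) -> bool:
    # Same condition as above: at least as many TP workers as KV heads.
    return tp_size // total_num_kv_heads >= 1


def replicates_kv_cache(is_mla: bool, tp_size: int, total_num_kv_heads: int) -> bool:
    # MLA is always treated as replicated since its latent KV can't be split.
    return is_mla or is_kv_replicated(tp_size, total_num_kv_heads)


assert not replicates_kv_cache(is_mla=False, tp_size=4, total_num_kv_heads=8)
assert replicates_kv_cache(is_mla=False, tp_size=16, total_num_kv_heads=8)
assert replicates_kv_cache(is_mla=True, tp_size=2, total_num_kv_heads=32)
```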
@@ -563,7 +570,8 @@ class NixlConnectorWorker:
) -> list[int]:
"""
Get the remote TP rank (on P) that the current local TP rank
(on D) will read from.
(on D) will read from. When remote tp_size > local tp_size, we
read from multiple remote ranks.
"""
tp_ratio = self.tp_ratio(remote_engine_id, remote_tp_size)
if tp_ratio > 0:
@@ -573,8 +581,8 @@ class NixlConnectorWorker:
tp_ratio = -tp_ratio
if self.replicates_kv_cache(remote_engine_id, remote_tp_size):
# When cache is replicated on remote, we only need to read
# from one remote.
return [self.tp_rank*tp_ratio]
# from one remote (they all have the same cache).
return [self.tp_rank * tp_ratio]
return [self.tp_rank * tp_ratio + i for i in range(tp_ratio)]

def __init__(self, vllm_config: VllmConfig, engine_id: str):
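Putting the branches above together, a hedged sketch of the rank-selection rule (the positive-`tp_ratio` branch is not visible in this hunk; the `// tp_ratio` grouping below is assumed from the docstring):

```python
def target_remote_ranks(local_rank: int, tp_ratio: int, kv_replicated: bool) -> list[int]:
    if tp_ratio > 0:
        # Groups of `tp_ratio` local workers read from the same remote worker.
        return [local_rank // tp_ratio]
    tp_ratio = -tp_ratio
    if kv_replicated:
        # Remote copies are identical, so one read is enough.
        return [local_rank * tp_ratio]
    # Remote TP is larger: read from `tp_ratio` consecutive remote ranks.
    return [local_rank * tp_ratio + i for i in range(tp_ratio)]


# D tp=2 reading from P tp=8 (tp_ratio=-4): local rank 1 reads P ranks 4..7.
assert target_remote_ranks(1, -4, kv_replicated=False) == [4, 5, 6, 7]
# Same topology with a replicated cache: a single read from P rank 4 suffices.
assert target_remote_ranks(1, -4, kv_replicated=True) == [4]
# D tp=4 reading from P tp=2 (tp_ratio=2): ranks 2 and 3 both read P rank 1.
assert target_remote_ranks(3, 2, kv_replicated=False) == [1]
```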
@@ -672,12 +680,10 @@ class NixlConnectorWorker:

# nixl_prepped_dlist_handle.
self.src_xfer_side_handle: int = 0
# TODO flexible enough to handle different P TP destinations?
# tp_ratio->handles
# Only poulated during handshake when we read from multiple sources
# Populated dynamically during handshake based on remote configuration.
# Keep track of regions at different tp_ratio values. tp_ratio->handles
self.src_xfer_side_chunked_handles: dict[int, list[int]] = {}
# Map of engine_id -> nixl_prepped_dlist_handle (int)].
# TODO do I need tp_Ratio of this?
self.dst_xfer_side_handles: defaultdict[EngineId, dict[int, int]] = defaultdict(
dict
)
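To make the bookkeeping introduced here concrete, a small sketch of the container shapes (all handle values and the engine name below are made up; in the real worker they come from NIXL and the handshake):

```python
from collections import defaultdict

EngineId = str

# One prepped dlist handle covering the whole local region...
src_xfer_side_handle: int = 101
# ...plus, per negative tp_ratio seen during handshakes, one handle per chunk.
src_xfer_side_chunked_handles: dict[int, list[int]] = {
    -4: [201, 202, 203, 204],  # local region logically split into 4 chunks
}
# Remote engine -> remote TP rank -> prepped dlist handle.
dst_xfer_side_handles: defaultdict[EngineId, dict[int, int]] = defaultdict(dict)
dst_xfer_side_handles["prefill-engine"][4] = 301
dst_xfer_side_handles["prefill-engine"][5] = 302

assert len(src_xfer_side_chunked_handles[-4]) == 4
assert dst_xfer_side_handles["prefill-engine"][5] == 302
```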
@@ -1033,7 +1039,9 @@ class NixlConnectorWorker:
self.tp_rank,
)

descs = self.nixl_wrapper.get_xfer_descs(self.src_blocks_data, self.nixl_memory_type)
descs = self.nixl_wrapper.get_xfer_descs(
self.src_blocks_data, self.nixl_memory_type
)
# NIXL_INIT_AGENT to be used for preparations of local descs.
self.src_xfer_side_handle = self.nixl_wrapper.prep_xfer_dlist(
"NIXL_INIT_AGENT", descs
@@ -1093,10 +1101,12 @@ class NixlConnectorWorker:

In particular, handle both homogeneous and heterogeneous TP. The former
requires local rank_i to read from remote rank_i.
The latter, assuming D.world_size > P.world_size, requires that two or
more local TP worker share the xfer from a single TP worker.
The latter, in the case of D.world_size < P.world_size, requires that a
local (D) TP worker reads from multiple remote (P) TP workers.
Conversely, assuming D.world_size > P.world_size, two or more local TP
workers will read from a single remote TP worker.

Here's an example (non-MLA case):
Here's an example for the last case described above (non-MLA):

rank_offset p_remote_tp_rank
(kv split no)
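As a worked illustration of the D.world_size > P.world_size direction described in the reworded docstring, a toy mapping of each D rank to the P rank it reads from and the KV-head split it takes (a sketch mirroring the "rank_offset / p_remote_tp_rank" table, split index instead of byte offset; non-MLA, non-replicated):

```python
def d_to_p_assignment(d_tp: int, p_tp: int) -> dict[int, tuple[int, int]]:
    """For each D rank: (p_remote_tp_rank, kv split index within that P worker)."""
    assert d_tp >= p_tp
    tp_ratio = d_tp // p_tp
    return {d: (d // tp_ratio, d % tp_ratio) for d in range(d_tp)}


# D tp=4, P tp=2: D ranks 0,1 read halves 0,1 of P rank 0; D ranks 2,3 read P rank 1.
assert d_to_p_assignment(4, 2) == {0: (0, 0), 1: (0, 1), 2: (1, 0), 3: (1, 1)}
```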
@@ -1155,35 +1165,36 @@ class NixlConnectorWorker:
)
self._validate_remote_agent_handshake(nixl_agent_meta, remote_tp_size)

# Number of D TP workers reading from a single P TP worker. This is
# 1 when P and D `--tensor-parallel-size` match. If P TP > D TP,
# we don't need to use this for splitting the remote kv cache.
# This is 1 when P and D `--tensor-parallel-size` match. Otherwise,
# this is the ratio between the two sizes.
tp_ratio = self.kv_info.tp_ratio(engine_id)

# Handle tp_size>num_kv_heads: replicate KV cache.
indexes_into_remote = (not self.kv_info.replicates_kv_cache(engine_id) \
and tp_ratio < 0)
indexes_into_remote = (
not self.kv_info.replicates_kv_cache(engine_id) and tp_ratio < 0
)

# When you realize you're in P TP>DTP you have to split your regions
### (Optional) Register local agent memory regions
if tp_ratio < 0 and tp_ratio not in self.src_xfer_side_chunked_handles:
# TODO use positive tp_ratio value?
# Remote tp_size > local tp_size: read from multiple remote ranks.
# Logically "split" own regions into |tp_ratio| chunks. Mind that
# we only do this once per remote tp_size (replica-friendly).
self.src_xfer_side_chunked_handles[tp_ratio] = []
# This is still needed even for MLA
# TODO actually only needs one!!
# Check if we have a split we can re-use, ie a remote P with same tp_ratio
for i in range(-tp_ratio):
# MLA-optimization: only prepare one region.
# NOTE NickLucche: only a chunk of whole cache is used with MLA!
tp_ratio_opt = 1 if self.use_mla else -tp_ratio
for i in range(tp_ratio_opt):
blocks_data = []
for memory_region in self.src_blocks_data:
addr, local_block_len, own_tp_rank = memory_region
# Computing block len layer by layer allow for different
# block sizes per layer
# TODO this needs to be an assert when validating
# TODO is this the right dim we're splitting on? H?
remote_block_len = local_block_len//(-tp_ratio)
# Offset
# Computing block len layer by layer allows for different
# block sizes to be used.
remote_block_len = local_block_len // (-tp_ratio)
addr = addr + i * remote_block_len
blocks_data.append((addr, remote_block_len, own_tp_rank)) # TODO same tp_rank?
descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type)
blocks_data.append((addr, remote_block_len, own_tp_rank))
descs = self.nixl_wrapper.get_xfer_descs(
blocks_data, self.nixl_memory_type
)
handle = self.nixl_wrapper.prep_xfer_dlist("NIXL_INIT_AGENT", descs)
self.src_xfer_side_chunked_handles[tp_ratio].append(handle)
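The loop above logically splits each local region into |tp_ratio| chunks, one per remote rank that will be read (or a single chunk with MLA). A minimal standalone sketch of just that address arithmetic (toy numbers, no NIXL calls):

```python
def chunk_local_regions(
    regions: list[tuple[int, int, int]],  # (base_addr, block_len, tp_rank)
    tp_ratio: int,                        # negative: remote TP > local TP
    use_mla: bool,
) -> list[list[tuple[int, int, int]]]:
    """Split each local region into |tp_ratio| equal chunks; MLA needs only one."""
    assert tp_ratio < 0
    num_chunks = 1 if use_mla else -tp_ratio
    chunked = []
    for i in range(num_chunks):
        blocks_data = []
        for addr, local_block_len, rank in regions:
            remote_block_len = local_block_len // (-tp_ratio)
            # Chunk i receives the bytes read from the i-th remote rank.
            blocks_data.append((addr + i * remote_block_len, remote_block_len, rank))
        chunked.append(blocks_data)
    return chunked


chunks = chunk_local_regions([(0x1000, 4096, 0)], tp_ratio=-4, use_mla=False)
assert [c[0][:2] for c in chunks] == [
    (0x1000, 1024), (0x1400, 1024), (0x1800, 1024), (0x1C00, 1024)
]
```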
@@ -1196,8 +1207,11 @@ class NixlConnectorWorker:

# Register all remote blocks, but only the corresponding kv heads.
for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr):
# TODO workaround
kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=i) // 2
# Read our whole local region size from remote.
kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=i)
if tp_ratio < 0:
# Remote tp is bigger: read a chunk of local region from remote
kv_block_len = kv_block_len // (-tp_ratio)
rank_offset = (
self.tp_rank % tp_ratio * kv_block_len if indexes_into_remote else 0
)
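The per-layer length/offset arithmetic above, isolated (a sketch only; whether the offset applies is gated by the `indexes_into_remote` flag computed earlier in the diff, so it is passed in explicitly here):

```python
def remote_read_layout(
    kv_block_len: int, tp_ratio: int, tp_rank: int, indexes_into_remote: bool
) -> tuple[int, int]:
    """Return (length, rank_offset) used when registering one remote KV layer region."""
    if tp_ratio < 0:
        # Remote TP is bigger: each remote only holds a 1/|tp_ratio| slice of our heads.
        kv_block_len = kv_block_len // (-tp_ratio)
    rank_offset = tp_rank % tp_ratio * kv_block_len if indexes_into_remote else 0
    return kv_block_len, rank_offset


# D tp=2 reading from P tp=8 (tp_ratio=-4): quarter-length reads, no offset into remote.
assert remote_read_layout(2048, tp_ratio=-4, tp_rank=1, indexes_into_remote=False) == (512, 0)
# D tp=4 reading from P tp=2 (tp_ratio=2): rank 3 reads the second half of the remote block.
assert remote_read_layout(2048, tp_ratio=2, tp_rank=3, indexes_into_remote=True) == (2048, 2048)
```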
@@ -1262,7 +1276,7 @@ class NixlConnectorWorker:
)
remote_block_size = remote_block_len // (self.slot_size_per_layer[0])
else:
if tp_ratio > 1 and self.device_type == "xpu":
if tp_ratio != 1 and self.device_type == "xpu":
# XPU uses NHD, hence it does not support splitting on H
raise ValueError("Heterogeneous TP is not supported on XPU")
# When MLA is not used, this is a list of the same block length
@@ -1270,26 +1284,42 @@ class NixlConnectorWorker:
assert block_len == remote_block_len, (
"All remote layers must have the same block size"
)
remote_block_size = remote_block_len // (
self.slot_size_per_layer[0] * tp_ratio
)

if tp_ratio > 0:
# Remote NHD/H'D*tp_ratio=N -page_size-
remote_block_size = remote_block_len // (
self.slot_size_per_layer[0] * tp_ratio
)
# Remote tp is smaller: remote block_len size is bigger
assert remote_block_len == self.block_len_per_layer[0] * tp_ratio, (
"Remote P worker KV layer cache must be of shape [2, N, "
"local_kv_heads*tp_ratio, page_size, head_dim] and same dtype."
) # noqa: E501
else:
# Remote NHD/(H'D/tp_ratio)=N -page_size-
remote_block_size = remote_block_len // (
self.slot_size_per_layer[0] // (-tp_ratio)
)
# Remote tp is bigger: remote block_len size is smaller
assert remote_block_len == self.block_len_per_layer[0] // (-tp_ratio), (
"Remote P worker KV layer cache must be of shape [2, N, "
"local_kv_heads/tp_ratio, page_size, head_dim] and same dtype."
) # noqa: E501

if self._use_flashinfer:
# With flashinfer, KV are sent in the same message.
remote_block_size //= 2
# TODO add asserts for P TP > D TP
# assert remote_block_len == self.block_len_per_layer[0] * tp_ratio, (
# "Remote P worker KV layer cache must be of shape [2, N, "
# "local_kv_heads*tp_ratio, block_size, head_dim] and same dtype."
# )

# assert self.block_size == remote_block_size, (
# "Remote P worker with different page/block size is not supported "
# f"{self.block_size=}, {remote_block_size=}"
# )
# We may allow it in the future with logical kvcache manager block_size
assert self.block_size == remote_block_size, (
"Remote P worker with different page/block size is not supported "
f"{self.block_size=}, {remote_block_size=}"
)

# # TP workers have same #blocks.
# assert self.dst_num_blocks[remote_engine_id] == nixl_agent_meta.num_blocks
# assert len(nixl_agent_meta.kv_caches_base_addr) == len(self.block_len_per_layer)
# TP workers (handhshakes with same remote) have same #blocks.
assert self.dst_num_blocks[remote_engine_id] == nixl_agent_meta.num_blocks
# Same number of regions/~layers.
assert len(nixl_agent_meta.kv_caches_base_addr) == len(self.block_len_per_layer)

def sync_recved_kv_to_device(self, req_id: str, meta: ReqMeta):
"""copy recved kv from host buffer to device."""
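The new asserts encode how a remote block's byte length and page size must relate to the local layout in both directions. A small numeric sketch of that arithmetic (illustrative sizes only):

```python
def expected_remote_block_len(local_block_len: int, tp_ratio: int) -> int:
    """Byte length of one remote KV block implied by the shape asserts above."""
    if tp_ratio > 0:
        # Remote TP is smaller: each remote worker holds tp_ratio x our KV heads.
        return local_block_len * tp_ratio
    # Remote TP is bigger: each remote worker holds 1/|tp_ratio| of our KV heads.
    return local_block_len // (-tp_ratio)


def remote_block_size(remote_block_len: int, local_slot_size: int, tp_ratio: int) -> int:
    """Tokens per remote block, recovered from its byte length."""
    if tp_ratio > 0:
        return remote_block_len // (local_slot_size * tp_ratio)
    return remote_block_len // (local_slot_size // (-tp_ratio))


# Local layout: 16-token blocks, 4096 bytes per token for this rank's heads.
local_block_len, local_slot_size = 16 * 4096, 4096
for ratio in (2, -4):
    rbl = expected_remote_block_len(local_block_len, ratio)
    # Same page size on both sides, or the block-size assert above would fire.
    assert remote_block_size(rbl, local_slot_size, ratio) == 16
```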
@@ -1421,7 +1451,7 @@ class NixlConnectorWorker:
"""
done_req_ids: set[str] = set()
for req_id, handles in list(transfers.items()):
in_progress = False
in_progress = []
for handle, _xfer_stime in handles:
xfer_state = self.nixl_wrapper.check_xfer_state(handle)
if xfer_state == "DONE":
@@ -1430,13 +1460,16 @@ class NixlConnectorWorker:
self.xfer_stats.record_transfer(res)
self.nixl_wrapper.release_xfer_handle(handle)
elif xfer_state == "PROC":
in_progress = True
in_progress.append((handle, _xfer_stime))
continue
else:
raise RuntimeError("Transfer failed with state %s", xfer_state)
if not in_progress:
# Only report request as completed when all transfers are done.
done_req_ids.add(req_id)
del transfers[req_id]
else:
transfers[req_id] = in_progress
return done_req_ids

def start_load_kv(self, metadata: NixlConnectorMetadata):
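The switch from a boolean to a list lets a request with several outstanding reads (heterogeneous TP) be reported only once all of them finish, while the pending handles are carried over to the next poll. A toy version of that bookkeeping (handles are plain ints here; DONE handles are simply dropped rather than released, and failure states are ignored):

```python
def pop_done_transfers(transfers: dict[str, list[int]], xfer_state: dict[int, str]) -> set[str]:
    done_req_ids: set[str] = set()
    for req_id, handles in list(transfers.items()):
        # Keep only the handles that are still in flight.
        in_progress = [h for h in handles if xfer_state[h] == "PROC"]
        if not in_progress:
            # Only report the request once every transfer has completed.
            done_req_ids.add(req_id)
            del transfers[req_id]
        else:
            transfers[req_id] = in_progress
    return done_req_ids


transfers = {"req-a": [1, 2], "req-b": [3]}
states = {1: "DONE", 2: "PROC", 3: "DONE"}
assert pop_done_transfers(transfers, states) == {"req-b"}
assert transfers == {"req-a": [2]}  # still waiting on handle 2 at the next poll
```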
@@ -1490,7 +1523,8 @@ class NixlConnectorWorker:

def _read_blocks_for_req(self, req_id: str, meta: ReqMeta):
remote_ranks = self.kv_info.get_target_remote_ranks(meta.remote_engine_id)
# D may perform multiple reads from different remote ranks.
tp_ratio = self.kv_info.tp_ratio(meta.remote_engine_id)
# D may have to perform multiple reads from different remote ranks.
for i, remote_rank in enumerate(remote_ranks):
logger.debug(
"Remote agent %s available, calling _read_blocks"
@@ -1499,13 +1533,17 @@ class NixlConnectorWorker:
remote_rank,
req_id,
)
# TODO refactor properly and ONLY DO THIS FOR PTP>DTP
tp_ratio = self.kv_info.tp_ratio(meta.remote_engine_id)
# Get nixl desc handles depending on whether we're reading from
# multiple sources or we're reading a chunk of
local_xfer_side_handle = self.src_xfer_side_chunked_handles[tp_ratio][i]
remote_xfer_side_handle = self.dst_xfer_side_handles[meta.remote_engine_id][remote_rank]
# TODO multiread; notifs to all twice?? SPLIT LOCAL BLOCKS!
if tp_ratio < 0:
# Remote tp_size > local tp_size: we must perform multiple
# reads. Get the memory chunk onto which we will write to.
local_xfer_side_handle = self.src_xfer_side_chunked_handles[tp_ratio][i]
else:
# Single read from remote, we write to the whole memory region.
local_xfer_side_handle = self.src_xfer_side_handle
# Destination handle: remote_engine_id -> remote_rank -> handle.
remote_xfer_side_handle = self.dst_xfer_side_handles[meta.remote_engine_id][
remote_rank
]
self._read_blocks(
request_id=req_id,
dst_engine_id=meta.remote_engine_id,
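Handle selection in the rewritten loop, reduced to its core (a sketch with made-up handle ids):

```python
def pick_local_handle(
    tp_ratio: int,
    read_index: int,
    whole_region_handle: int,
    chunked_handles: dict[int, list[int]],
) -> int:
    if tp_ratio < 0:
        # Several reads from several remotes: the i-th read lands in the i-th local chunk.
        return chunked_handles[tp_ratio][read_index]
    # Single remote source: the whole local region is the destination.
    return whole_region_handle


chunked = {-2: [11, 12]}
assert pick_local_handle(-2, 1, whole_region_handle=10, chunked_handles=chunked) == 12
assert pick_local_handle(1, 0, whole_region_handle=10, chunked_handles=chunked) == 10
```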
@@ -1526,6 +1564,10 @@ class NixlConnectorWorker:
local_xfer_side_handle: int,
remote_xfer_side_handle: int,
):
"""
Post a READ xfer request from a single local worker to a single
remote worker.
"""
# NOTE(rob): having the staging blocks be on the READER side is
# not going to work well (since we will have to call rearrange tensors).
# after we detect the txn is complete (which means we cannot make the
@@ -1543,7 +1585,7 @@ class NixlConnectorWorker:
notif_id = f"{request_id}:{tp_ratio}".encode()

# Full prefix cache hit: do not need to read remote blocks,
# just notify P worker(s) that we have the blocks we need.
# just notify P worker that we have the blocks we need.
num_local_blocks = len(local_block_ids)
if num_local_blocks == 0:
agent_name = self._remote_agents[dst_engine_id][remote_rank]
@@ -1556,10 +1598,6 @@ class NixlConnectorWorker:
if num_local_blocks < num_remote_blocks:
remote_block_ids = remote_block_ids[-num_local_blocks:]

# Get side handles.
# local_xfer_side_handle = self.src_xfer_side_handle
# remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id][remote_rank]

# NOTE (nicolo) With homogeneous TP, each TP worker loads KV from
# corresponding rank. With heterogeneous TP, fixing D>P, the D tp
# workers will issue xfers to parts of the P worker remote kv caches.
@@ -1647,7 +1685,6 @@ class NixlConnectorWorker:
assert self.num_layers == self.num_regions
region_ids = np.arange(layer_idx, layer_idx + 1)

# TODO can this vary?
num_blocks = self.dst_num_blocks[engine_id]

# Compute the desc ids for each block.
@@ -1694,11 +1731,10 @@ class NixlConnectorWorker:
if self.src_xfer_side_handle:
self.nixl_wrapper.release_dlist_handle(self.src_xfer_side_handle)
self.src_xfer_side_handle = 0
if self.src_xfer_side_chunked_handles:
for handles in self.src_xfer_side_chunked_handles.values():
for handle in handles:
self.nixl_wrapper.release_dlist_handle(handle)
self.src_xfer_side_chunked_handles.clear()
for handles in self.src_xfer_side_chunked_handles.values():
for handle in handles:
self.nixl_wrapper.release_dlist_handle(handle)
self.src_xfer_side_chunked_handles.clear()
for dst_xfer_side_handles in self.dst_xfer_side_handles.values():
for dst_xfer_side_handle in dst_xfer_side_handles.values():
self.nixl_wrapper.release_dlist_handle(dst_xfer_side_handle)
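The simplified shutdown path just walks the per-`tp_ratio` chunk handles unconditionally; a toy equivalent of that cleanup (using a plain callback in place of `nixl_wrapper.release_dlist_handle`):

```python
def release_chunked_handles(release, chunked_handles: dict[int, list[int]]) -> None:
    # Release every chunk handle registered for any tp_ratio, then forget them.
    for handles in chunked_handles.values():
        for handle in handles:
            release(handle)
    chunked_handles.clear()


released: list[int] = []
chunked = {-2: [11, 12], -4: [21, 22, 23, 24]}
release_chunked_handles(released.append, chunked)
assert released == [11, 12, 21, 22, 23, 24]
assert chunked == {}
```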