mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-08 07:17:03 +08:00
hacky
Signed-off-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
parent
684c9b7b6d
commit
7bb3861faf
@ -532,22 +532,29 @@ class NixlConnectorWorker:
|
||||
)
|
||||
return self.tp_size // remote_tp_size
|
||||
else:
|
||||
assert remote_tp_size % self.tp_size == 0, (
|
||||
f"Remote tensor parallel size {remote_tp_size} is not divisible "
|
||||
f"by local tensor parallel size {self.tp_size}."
|
||||
)
|
||||
# P TP > D TP case, return the ratio as negative
|
||||
return -remote_tp_size // self.tp_size
|
||||
|
||||
def is_kv_replicated(self, engine_id: EngineId) -> bool:
|
||||
def is_kv_replicated(self, engine_id: Optional[EngineId] = None, tp_size: Optional[int] = None) -> bool:
|
||||
"""
|
||||
Whether the KV cache is replicated across TP workers due to the
|
||||
number of TP workers being greater than the number of KV heads.
|
||||
"""
|
||||
tp_size = self.remote_tp_size[engine_id]
|
||||
if tp_size is None:
|
||||
assert engine_id is not None
|
||||
tp_size = self.remote_tp_size[engine_id]
|
||||
return tp_size // self.total_num_kv_heads >= 1
|
||||
|
||||
def replicates_kv_cache(self, remote_engine_id: EngineId) -> bool:
|
||||
|
||||
def replicates_kv_cache(self, remote_engine_id: Optional[EngineId] = None, remote_tp_size: Optional[int] = None) -> bool:
|
||||
# MLA is always replicated as the hidden dim can't be split.
|
||||
# TODO docs
|
||||
remote_tp_size = self.remote_tp_size[remote_engine_id]
|
||||
return self.is_mla or self.is_kv_replicated(remote_engine_id) or self.tp_size < remote_tp_size
|
||||
return (
|
||||
self.is_mla
|
||||
or self.is_kv_replicated(remote_engine_id, remote_tp_size)
|
||||
)
|
||||
|
||||
def get_target_remote_ranks(
|
||||
self,
|
||||
@ -564,7 +571,11 @@ class NixlConnectorWorker:
|
||||
else:
|
||||
# P TP > D TP case, D reads from |tp_ratio| remote workers.
|
||||
tp_ratio = -tp_ratio
|
||||
return [self.tp_rank*tp_ratio + i for i in range(tp_ratio)]
|
||||
if self.replicates_kv_cache(remote_engine_id, remote_tp_size):
|
||||
# When cache is replicated on remote, we only need to read
|
||||
# from one remote.
|
||||
return [self.tp_rank*tp_ratio]
|
||||
return [self.tp_rank * tp_ratio + i for i in range(tp_ratio)]
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, engine_id: str):
|
||||
if NixlWrapper is None:
|
||||
@ -650,7 +661,9 @@ class NixlConnectorWorker:
|
||||
|
||||
# Map of engine_id -> kv_caches_base_addr. For TP case, each local
|
||||
# rank may pull from multiple remote TP workers.
|
||||
self.kv_caches_base_addr: defaultdict[EngineId, dict[int, list[int]]] = defaultdict(dict)
|
||||
self.kv_caches_base_addr: defaultdict[EngineId, dict[int, list[int]]] = (
|
||||
defaultdict(dict)
|
||||
)
|
||||
|
||||
# Number of NIXL regions. Currently one region per cache
|
||||
# (so 1 per layer for MLA, otherwise 2 per layer)
|
||||
@ -659,9 +672,15 @@ class NixlConnectorWorker:
|
||||
|
||||
# nixl_prepped_dlist_handle.
|
||||
self.src_xfer_side_handle: int = 0
|
||||
# TODO flexible enough to handle different P TP destinations?
|
||||
# tp_ratio->handles
|
||||
# Only poulated during handshake when we read from multiple sources
|
||||
self.src_xfer_side_chunked_handles: dict[int, list[int]] = {}
|
||||
# Map of engine_id -> nixl_prepped_dlist_handle (int)].
|
||||
# TODO do I need tp_Ratio of this?
|
||||
self.dst_xfer_side_handles: defaultdict[EngineId, dict[int, int]] = defaultdict(dict)
|
||||
self.dst_xfer_side_handles: defaultdict[EngineId, dict[int, int]] = defaultdict(
|
||||
dict
|
||||
)
|
||||
|
||||
# Map of engine_id -> num_blocks. All ranks in the same deployment will
|
||||
# have the same number of blocks.
|
||||
@ -771,7 +790,7 @@ class NixlConnectorWorker:
|
||||
# When target instance TP > local TP, we need to perform multiple
|
||||
# handshakes. Do it in a single background job for simplicity.
|
||||
# Regardless, only handshake with the remote TP rank(s) that current
|
||||
# local rank will read from. Note that With homogeneous TP,
|
||||
# local rank will read from. Note that With homogeneous TP,
|
||||
# this happens to be the same single rank_i.
|
||||
p_remote_ranks = self.kv_info.get_target_remote_ranks(
|
||||
remote_tp_size=remote_tp_size
|
||||
@ -791,7 +810,8 @@ class NixlConnectorWorker:
|
||||
metadata = decoder.decode(metadata_bytes)
|
||||
got_metadata_time = time.perf_counter()
|
||||
logger.debug(
|
||||
"NIXL handshake: get metadata took: %s", got_metadata_time - start_time
|
||||
"NIXL handshake: get metadata took: %s",
|
||||
got_metadata_time - start_time,
|
||||
)
|
||||
|
||||
# Ensure engine id matches.
|
||||
@ -812,7 +832,6 @@ class NixlConnectorWorker:
|
||||
setup_agent_time - got_metadata_time,
|
||||
)
|
||||
remote_rank_to_agent_name[remote_rank] = remote_agent_name
|
||||
|
||||
return remote_rank_to_agent_name
|
||||
|
||||
def initialize_host_xfer_buffer(self, kv_caches: dict[str, torch.Tensor]) -> None:
|
||||
@ -983,7 +1002,7 @@ class NixlConnectorWorker:
|
||||
self.num_regions *= 2
|
||||
|
||||
# Register local/src descr for NIXL xfer.
|
||||
blocks_data = []
|
||||
self.src_blocks_data = []
|
||||
for i, base_addr in enumerate(seen_base_addresses):
|
||||
kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=i)
|
||||
# NOTE With heter-TP, more blocks are prepared than what are
|
||||
@ -995,7 +1014,7 @@ class NixlConnectorWorker:
|
||||
block_offset = block_id * self.block_len_per_layer[i]
|
||||
addr = base_addr + block_offset
|
||||
# (addr, len, device id)
|
||||
blocks_data.append((addr, kv_block_len, self.tp_rank))
|
||||
self.src_blocks_data.append((addr, kv_block_len, self.tp_rank))
|
||||
|
||||
if self._use_flashinfer:
|
||||
# Separate and interleave K/V regions to maintain the same
|
||||
@ -1006,15 +1025,15 @@ class NixlConnectorWorker:
|
||||
addr = base_addr + block_offset
|
||||
# Register addresses for V cache (K registered first).
|
||||
v_addr = addr + kv_block_len
|
||||
blocks_data.append((v_addr, kv_block_len, self.tp_rank))
|
||||
self.src_blocks_data.append((v_addr, kv_block_len, self.tp_rank))
|
||||
logger.debug(
|
||||
"Created %s blocks for src engine %s and rank %s",
|
||||
len(blocks_data),
|
||||
len(self.src_blocks_data),
|
||||
self.engine_id,
|
||||
self.tp_rank,
|
||||
)
|
||||
|
||||
descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type)
|
||||
descs = self.nixl_wrapper.get_xfer_descs(self.src_blocks_data, self.nixl_memory_type)
|
||||
# NIXL_INIT_AGENT to be used for preparations of local descs.
|
||||
self.src_xfer_side_handle = self.nixl_wrapper.prep_xfer_dlist(
|
||||
"NIXL_INIT_AGENT", descs
|
||||
@ -1125,24 +1144,49 @@ class NixlConnectorWorker:
|
||||
nixl_agent_meta.agent_metadata
|
||||
)
|
||||
|
||||
# Handle tp_size>num_kv_heads: replicate KV cache.
|
||||
replicates_kv_cache = self.kv_info.replicates_kv_cache(engine_id)
|
||||
print("REPLICATES KV CACHE", replicates_kv_cache, "\n")
|
||||
|
||||
# Create dst descs and xfer side handles. TP workers have same #blocks
|
||||
# so we only register once per engine_id.
|
||||
if engine_id not in self.dst_num_blocks:
|
||||
self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks
|
||||
|
||||
# Keep track of remote agent kv caches base addresses.
|
||||
self.kv_caches_base_addr[engine_id][self.tp_rank] = nixl_agent_meta.kv_caches_base_addr
|
||||
self.kv_caches_base_addr[engine_id][remote_tp_rank] = (
|
||||
nixl_agent_meta.kv_caches_base_addr
|
||||
)
|
||||
self._validate_remote_agent_handshake(nixl_agent_meta, remote_tp_size)
|
||||
|
||||
# Number of D TP workers reading from a single P TP worker. This is
|
||||
# 1 when P and D `--tensor-parallel-size` match. If P TP > D TP,
|
||||
# we don't need to use this for spliting the remote kv cache.
|
||||
# we don't need to use this for splitting the remote kv cache.
|
||||
tp_ratio = self.kv_info.tp_ratio(engine_id)
|
||||
|
||||
# Handle tp_size>num_kv_heads: replicate KV cache.
|
||||
indexes_into_remote = (not self.kv_info.replicates_kv_cache(engine_id) \
|
||||
and tp_ratio < 0)
|
||||
|
||||
# When you realize you're in P TP>DTP you have to split your regions
|
||||
if tp_ratio < 0 and tp_ratio not in self.src_xfer_side_chunked_handles:
|
||||
# TODO use positive tp_ratio value?
|
||||
self.src_xfer_side_chunked_handles[tp_ratio] = []
|
||||
# This is still needed even for MLA
|
||||
# TODO actually only needs one!!
|
||||
# Check if we have a split we can re-use, ie a remote P with same tp_ratio
|
||||
for i in range(-tp_ratio):
|
||||
blocks_data = []
|
||||
for memory_region in self.src_blocks_data:
|
||||
addr, local_block_len, own_tp_rank = memory_region
|
||||
# Computing block len layer by layer allow for different
|
||||
# block sizes per layer
|
||||
# TODO this needs to be an assert when validating
|
||||
# TODO is this the right dim we're splitting on? H?
|
||||
remote_block_len = local_block_len//(-tp_ratio)
|
||||
# Offset
|
||||
addr = addr + i * remote_block_len
|
||||
blocks_data.append((addr, remote_block_len, own_tp_rank)) # TODO same tp_rank?
|
||||
descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type)
|
||||
handle = self.nixl_wrapper.prep_xfer_dlist("NIXL_INIT_AGENT", descs)
|
||||
self.src_xfer_side_chunked_handles[tp_ratio].append(handle)
|
||||
|
||||
### Register remote agent memory regions
|
||||
blocks_data = []
|
||||
# With homogeneous TP, D pulls the whole kv cache from corresponding
|
||||
@ -1152,10 +1196,10 @@ class NixlConnectorWorker:
|
||||
|
||||
# Register all remote blocks, but only the corresponding kv heads.
|
||||
for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr):
|
||||
# TODO
|
||||
# TODO workaround
|
||||
kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=i) // 2
|
||||
rank_offset = (
|
||||
self.tp_rank % tp_ratio * kv_block_len if not replicates_kv_cache else 0
|
||||
self.tp_rank % tp_ratio * kv_block_len if indexes_into_remote else 0
|
||||
)
|
||||
for block_id in range(nixl_agent_meta.num_blocks):
|
||||
block_offset = block_id * nixl_agent_meta.block_lens[i]
|
||||
@ -1184,8 +1228,8 @@ class NixlConnectorWorker:
|
||||
|
||||
# Register with NIXL.
|
||||
descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type)
|
||||
self.dst_xfer_side_handles[engine_id][remote_tp_rank] = self.nixl_wrapper.prep_xfer_dlist(
|
||||
remote_agent_name, descs
|
||||
self.dst_xfer_side_handles[engine_id][remote_tp_rank] = (
|
||||
self.nixl_wrapper.prep_xfer_dlist(remote_agent_name, descs)
|
||||
)
|
||||
|
||||
return remote_agent_name
|
||||
@ -1447,7 +1491,7 @@ class NixlConnectorWorker:
|
||||
def _read_blocks_for_req(self, req_id: str, meta: ReqMeta):
|
||||
remote_ranks = self.kv_info.get_target_remote_ranks(meta.remote_engine_id)
|
||||
# D may perform multiple reads from different remote ranks.
|
||||
for remote_rank in remote_ranks:
|
||||
for i, remote_rank in enumerate(remote_ranks):
|
||||
logger.debug(
|
||||
"Remote agent %s available, calling _read_blocks"
|
||||
" on remote rank %s for req %s",
|
||||
@ -1455,6 +1499,12 @@ class NixlConnectorWorker:
|
||||
remote_rank,
|
||||
req_id,
|
||||
)
|
||||
# TODO refactor properly and ONLY DO THIS FOR PTP>DTP
|
||||
tp_ratio = self.kv_info.tp_ratio(meta.remote_engine_id)
|
||||
# Get nixl desc handles depending on whether we're reading from
|
||||
# multiple sources or we're reading a chunk of
|
||||
local_xfer_side_handle = self.src_xfer_side_chunked_handles[tp_ratio][i]
|
||||
remote_xfer_side_handle = self.dst_xfer_side_handles[meta.remote_engine_id][remote_rank]
|
||||
# TODO multiread; notifs to all twice?? SPLIT LOCAL BLOCKS!
|
||||
self._read_blocks(
|
||||
request_id=req_id,
|
||||
@ -1462,6 +1512,8 @@ class NixlConnectorWorker:
|
||||
local_block_ids=meta.local_block_ids,
|
||||
remote_block_ids=meta.remote_block_ids,
|
||||
remote_rank=remote_rank,
|
||||
local_xfer_side_handle=local_xfer_side_handle,
|
||||
remote_xfer_side_handle=remote_xfer_side_handle,
|
||||
)
|
||||
|
||||
def _read_blocks(
|
||||
@ -1471,6 +1523,8 @@ class NixlConnectorWorker:
|
||||
dst_engine_id: str,
|
||||
request_id: str,
|
||||
remote_rank: int,
|
||||
local_xfer_side_handle: int,
|
||||
remote_xfer_side_handle: int,
|
||||
):
|
||||
# NOTE(rob): having the staging blocks be on the READER side is
|
||||
# not going to work well (since we will have to call rearrange tensors).
|
||||
@ -1503,8 +1557,8 @@ class NixlConnectorWorker:
|
||||
remote_block_ids = remote_block_ids[-num_local_blocks:]
|
||||
|
||||
# Get side handles.
|
||||
local_xfer_side_handle = self.src_xfer_side_handle
|
||||
remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id][remote_rank]
|
||||
# local_xfer_side_handle = self.src_xfer_side_handle
|
||||
# remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id][remote_rank]
|
||||
|
||||
# NOTE (nicolo) With homogeneous TP, each TP worker loads KV from
|
||||
# corresponding rank. With heterogeneous TP, fixing D>P, the D tp
|
||||
@ -1640,6 +1694,11 @@ class NixlConnectorWorker:
|
||||
if self.src_xfer_side_handle:
|
||||
self.nixl_wrapper.release_dlist_handle(self.src_xfer_side_handle)
|
||||
self.src_xfer_side_handle = 0
|
||||
if self.src_xfer_side_chunked_handles:
|
||||
for handles in self.src_xfer_side_chunked_handles.values():
|
||||
for handle in handles:
|
||||
self.nixl_wrapper.release_dlist_handle(handle)
|
||||
self.src_xfer_side_chunked_handles.clear()
|
||||
for dst_xfer_side_handles in self.dst_xfer_side_handles.values():
|
||||
for dst_xfer_side_handle in dst_xfer_side_handles.values():
|
||||
self.nixl_wrapper.release_dlist_handle(dst_xfer_side_handle)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user