[NIXL] heterogeneous block_size support (#26759)

Signed-off-by: Chendi Xue <chendi.xue@intel.com>
Signed-off-by: Chendi.Xue <chendi.xue@intel.com>
Co-authored-by: Nicolò Lucchesi <nicolo.lucchesi@gmail.com>

commit c9e665852a (parent 363aaeef0f)
@@ -49,6 +49,8 @@ NUM_DECODE_INSTANCES=${NUM_DECODE_INSTANCES:-1} # Default to 1
 PREFILLER_TP_SIZE=${PREFILLER_TP_SIZE:-1}
 DECODER_TP_SIZE=${DECODER_TP_SIZE:-1}
 GPU_MEMORY_UTILIZATION=${GPU_MEMORY_UTILIZATION:-0.2}
+PREFILL_BLOCK_SIZE=${PREFILL_BLOCK_SIZE:-16}
+DECODE_BLOCK_SIZE=${DECODE_BLOCK_SIZE:-16}
 
 # Find the git repository root directory
 GIT_ROOT=$(git rev-parse --show-toplevel)
@@ -136,6 +138,7 @@ run_tests_for_model() {
     vllm serve $model_name \
     --port $PORT \
     --enforce-eager \
+    --block-size ${PREFILL_BLOCK_SIZE} \
     --gpu-memory-utilization $GPU_MEMORY_UTILIZATION \
     --tensor-parallel-size $PREFILLER_TP_SIZE \
     --kv-transfer-config '$KV_CONFIG'"
@@ -177,6 +180,7 @@ run_tests_for_model() {
     vllm serve $model_name \
     --port $PORT \
     --enforce-eager \
+    --block-size ${DECODE_BLOCK_SIZE} \
     --gpu-memory-utilization $GPU_MEMORY_UTILIZATION \
     --kv-transfer-config '$KV_CONFIG'"
@@ -407,6 +407,7 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
                 # `self.kv_cache_layout` is only forced to HND when the vllm
                 # engine is started. We mock HND here.
                 kv_cache_layout="HND",
+                block_size=self.block_size,
             ),
             remote_tp_size=remote_tp_size,
         )
@@ -652,6 +653,7 @@ class TestNixlHandshake:
             block_lens=worker.block_len_per_layer,
             attn_backend_name=worker.backend_name,
             kv_cache_layout=mismatched_layout,
+            block_size=worker.block_size,
         )

         with pytest.raises(RuntimeError):
@@ -706,6 +708,7 @@ class TestNixlHandshake:
             block_lens=[i * 2 for i in worker.block_len_per_layer],
             attn_backend_name=worker.backend_name,
             kv_cache_layout="HND",
+            block_size=worker.block_size,
         )

         # We don't check layout for homogeneous TP and MLA for now, as the
@@ -108,6 +108,7 @@ class NixlAgentMetadata(KVConnectorHandshakeMetadata):
     block_lens: list[int]
     attn_backend_name: str
     kv_cache_layout: str
+    block_size: int


 @dataclass
@@ -709,6 +710,9 @@ class NixlConnectorWorker:
             self.is_mla or self._use_pallas or self.is_kv_layout_blocks_first
         )

+    block_size: int
+    remote_block_size: dict[EngineId, int]
+
     def tp_ratio(
         self,
         remote_tp_size: int,
@@ -725,6 +729,19 @@ class NixlConnectorWorker:
         )
         return self.tp_size // remote_tp_size

+    def block_size_ratio(
+        self,
+        remote_block_size: int,
+    ) -> float:
+        """
+        Calculate the ratio between the local and remote block size.
+        """
+        assert self.block_size % remote_block_size == 0, (
+            f"Local block size {self.block_size} is not divisible "
+            f"by remote block size {remote_block_size}."
+        )
+        return self.block_size // remote_block_size
+
     def tp_ratio_from_engine_id(
         self,
         remote_engine_id: EngineId,
@@ -732,6 +749,13 @@ class NixlConnectorWorker:
         remote_tp_size = self.remote_tp_size[remote_engine_id]
         return self.tp_ratio(remote_tp_size)

+    def block_size_ratio_from_engine_id(
+        self,
+        remote_engine_id: EngineId,
+    ) -> float:
+        remote_block_size = self.remote_block_size[remote_engine_id]
+        return self.block_size_ratio(remote_block_size)
+
     def is_kv_replicated(self, engine_id: EngineId) -> bool:
         """
         Whether the KV cache is replicated across TP workers due to the
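In short, page sizes may only differ by an integer factor, with the decode (local) page the larger one. A standalone sketch of the arithmetic with made-up sizes (not the kv_topo API itself):

    # Minimal sketch of the ratio check above, with hypothetical sizes.
    def block_size_ratio(local_block_size: int, remote_block_size: int) -> int:
        # Local (decode) pages must be a whole multiple of remote (prefill) pages.
        assert local_block_size % remote_block_size == 0
        return local_block_size // remote_block_size

    print(block_size_ratio(16, 16))  # 1: homogeneous, the pre-existing path
    print(block_size_ratio(16, 4))   # 4: one decode block spans 4 prefill blocks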
@@ -866,6 +890,7 @@ class NixlConnectorWorker:

         # nixl_prepped_dlist_handle.
         self.src_xfer_side_handle: int = 0
+        self.src_xfer_side_handles: dict[int, int] = {}
         # Map of engine_id -> nixl_prepped_dlist_handle (int).
         self.dst_xfer_side_handles: dict[EngineId, int] = {}
@@ -925,6 +950,7 @@ class NixlConnectorWorker:
         logger.debug("Detected kv cache layout %s", self.kv_cache_layout)

         self._tp_size: dict[EngineId, int] = {self.engine_id: self.world_size}
+        self._block_size: dict[EngineId, int] = {self.engine_id: self.block_size}
         # With heterogeneous TP, P must wait for all assigned D TP workers to
         # finish reading before safely freeing the blocks.
         self.consumer_notification_counts_by_req = defaultdict[ReqId, int](int)
@@ -936,6 +962,8 @@ class NixlConnectorWorker:
             remote_tp_size=self._tp_size,  # shared state
             is_mla=self.use_mla,
             total_num_kv_heads=self.model_config.get_total_num_kv_heads(),
+            block_size=self.block_size,
+            remote_block_size=self._block_size,
             attn_backend=backend,
         )
         self._use_pallas = self.kv_topo._use_pallas
@@ -987,9 +1015,13 @@ class NixlConnectorWorker:
         )

         # Register Remote agent.
+        assert metadata.block_size <= self.block_size, (
+            "nP > nD is not supported yet."
+        )
         remote_agent_name = self.add_remote_agent(
             metadata, p_remote_rank, remote_tp_size
         )

         setup_agent_time = time.perf_counter()
         logger.debug(
             "NIXL handshake: add agent took: %s",
@@ -1217,43 +1249,10 @@ class NixlConnectorWorker:
             self.num_regions *= 2

-        # Register local/src descr for NIXL xfer.
-        blocks_data = []
-        for i, base_addr in enumerate(seen_base_addresses):
-            kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=i)
-            # NOTE With heter-TP, more blocks are prepared than what are
-            # needed as self.num_blocks >= nixl_agent_meta.num_blocks. We
-            # could create fewer, but then _get_block_descs_ids needs to
-            # select agent_meta.num_blocks instead of self.num_blocks for
-            # local descr, and that makes handling regular flow less clean.
-            for block_id in range(self.num_blocks):
-                block_offset = block_id * self.block_len_per_layer[i]
-                addr = base_addr + block_offset
-                # (addr, len, device id)
-                blocks_data.append((addr, kv_block_len, self.device_id))
-
-            if self.kv_topo.is_kv_layout_blocks_first:
-                # Separate and interleave K/V regions to maintain the same
-                # descs ordering. This is needed for selecting contiguous heads
-                # when split across TP ranks.
-                for block_id in range(self.num_blocks):
-                    block_offset = block_id * self.block_len_per_layer[i]
-                    addr = base_addr + block_offset
-                    # Register addresses for V cache (K registered first).
-                    v_addr = addr + kv_block_len
-                    blocks_data.append((v_addr, kv_block_len, self.device_id))
-        logger.debug(
-            "Created %s blocks for src engine %s and rank %s on device id %s",
-            len(blocks_data),
-            self.engine_id,
-            self.tp_rank,
-            self.device_id,
-        )
-
-        descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type)
-        # NIXL_INIT_AGENT to be used for preparations of local descs.
-        self.src_xfer_side_handle = self.nixl_wrapper.prep_xfer_dlist(
-            "NIXL_INIT_AGENT", descs
-        )
+        self.seen_base_addresses = seen_base_addresses
+        self.src_xfer_side_handle = self.register_local_xfer_handler(self.block_size)
+        self.src_xfer_side_handles[self.block_size] = self.src_xfer_side_handle

         # TODO(mgoin): Hybrid memory allocator is currently disabled for
         # models with local attention (Llama 4). Can remove this once enabled.
@@ -1289,8 +1288,62 @@ class NixlConnectorWorker:
             kv_cache_layout=self.kv_cache_layout
             if not self.use_host_buffer
             else self.host_buffer_kv_cache_layout,
+            block_size=self.block_size,
         )

+    def register_local_xfer_handler(
+        self,
+        block_size: int,
+    ) -> int:
+        """
+        Register a local xfer handler using either the local or a remote
+        block_size.
+
+        When the local block_size equals the remote block_size, the handler
+        registered at init with the local block_size is used.
+
+        When the remote block_size is smaller than the local block_size, we
+        register another local xfer handler using the remote block_len to
+        ensure data is copied correctly.
+        """
+        block_size_ratio = self.block_size // block_size
+        blocks_data = []
+        for i, base_addr in enumerate(self.seen_base_addresses):
+            # The per-descriptor block_len follows the (smaller) requested
+            # block_size; num_blocks grows by the same ratio.
+            kv_block_len = (
+                self.get_backend_aware_kv_block_len(layer_idx=i) // block_size_ratio
+            )
+            block_len_per_layer = self.block_len_per_layer[i] // block_size_ratio
+            num_blocks = self.num_blocks * block_size_ratio
+            for block_id in range(num_blocks):
+                block_offset = block_id * block_len_per_layer
+                addr = base_addr + block_offset
+                # (addr, len, device id)
+                blocks_data.append((addr, kv_block_len, self.device_id))
+
+            if self.kv_topo.is_kv_layout_blocks_first:
+                # Separate and interleave K/V regions to maintain the same
+                # descs ordering. This is needed for selecting contiguous heads
+                # when split across TP ranks.
+                for block_id in range(num_blocks):
+                    block_offset = block_id * block_len_per_layer
+                    addr = base_addr + block_offset
+                    # Register addresses for V cache (K registered first).
+                    v_addr = addr + kv_block_len
+                    blocks_data.append((v_addr, kv_block_len, self.device_id))
+        logger.debug(
+            "Created %s blocks for src engine %s and rank %s on device id %s",
+            len(blocks_data),
+            self.engine_id,
+            self.tp_rank,
+            self.device_id,
+        )
+
+        descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type)
+        # NIXL_INIT_AGENT to be used for preparations of local descs.
+        return self.nixl_wrapper.prep_xfer_dlist("NIXL_INIT_AGENT", descs)
+
     def add_remote_agent(
         self,
         nixl_agent_meta: NixlAgentMetadata,
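For intuition, re-registering with a smaller block_size carves the same address range into ratio-times more, ratio-times shorter descriptors; the total bytes covered are unchanged. A standalone sketch with hypothetical numbers (plain Python, not the NIXL API):

    # 2 local blocks of 4096 bytes, re-registered at a 4x smaller block_size.
    base_addr, num_blocks, block_len, ratio = 0x1000, 2, 4096, 4

    coarse = [(base_addr + b * block_len, block_len) for b in range(num_blocks)]
    fine = [
        (base_addr + b * (block_len // ratio), block_len // ratio)
        for b in range(num_blocks * ratio)
    ]

    # Same bytes covered, just a finer-grained descriptor list.
    assert sum(length for _, length in coarse) == sum(length for _, length in fine)
    print(coarse)    # [(4096, 4096), (8192, 4096)]
    print(fine[:3])  # [(4096, 1024), (5120, 1024), (6144, 1024)]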
@@ -1349,6 +1402,8 @@ class NixlConnectorWorker:
         ### Register remote agent metadata
         if engine_id not in self._tp_size:
             self._tp_size[engine_id] = remote_tp_size
+        if engine_id not in self._block_size:
+            self._block_size[engine_id] = nixl_agent_meta.block_size

         remote_agent_name = self.nixl_wrapper.add_remote_agent(
             nixl_agent_meta.agent_metadata
@@ -1359,6 +1414,13 @@ class NixlConnectorWorker:

         # Create dst descs and xfer side handles. TP workers have same #blocks
         # so we only register once per engine_id.
+        # Example:
+        # block_size_ratio > 1 (one local block spans several remote blocks):
+        #   remote:       | 0| 1| 2| 3| 4| 5| 6| 7| 8| 9|10|11|12|
+        #   local origin: |     0     |     1     |     2     |     3     |
+        #   local mapped: | 0| 1| 2| 3| 4| 5| 6| 7| 8| 9|10|11|12|13|14|15|
+        block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(engine_id)

         if engine_id not in self.dst_num_blocks:
             self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks
@@ -1381,8 +1443,14 @@ class NixlConnectorWorker:
         # Register all remote blocks, but only the corresponding kv heads.
         for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr):
             kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=i)
+            remote_kv_block_len = kv_block_len // block_size_ratio
+            if block_size_ratio > 1:
+                # Use the remote kv_block_len as the transfer unit.
+                kv_block_len = remote_kv_block_len
             rank_offset = (
-                self.tp_rank % tp_ratio * kv_block_len if not replicates_kv_cache else 0
+                self.tp_rank % tp_ratio * remote_kv_block_len
+                if not replicates_kv_cache
+                else 0
             )
             for block_id in range(nixl_agent_meta.num_blocks):
                 block_offset = block_id * nixl_agent_meta.block_lens[i]
@@ -1417,6 +1485,13 @@ class NixlConnectorWorker:
             remote_agent_name, descs
         )

+        if block_size_ratio > 1:
+            # When the prefiller runs with a smaller block_size, we need to
+            # init a new local handler with a matching block_len.
+            self.src_xfer_side_handles[nixl_agent_meta.block_size] = (
+                self.register_local_xfer_handler(nixl_agent_meta.block_size)
+            )
+
         return remote_agent_name

     def _validate_remote_agent_handshake(
@@ -1433,6 +1508,9 @@ class NixlConnectorWorker:
         assert nixl_agent_meta.attn_backend_name == self.backend_name

         tp_ratio = self.kv_topo.tp_ratio_from_engine_id(remote_engine_id)
+        block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(
+            remote_engine_id
+        )
         assert tp_ratio > 0, "Decode TP cannot be smaller than prefill TP"
         assert not self._use_pallas or tp_ratio == 1, (
             "TPU (pallas_v1) DOES NOT support heterogeneous TP yet."
@@ -1463,33 +1541,26 @@ class NixlConnectorWorker:
         remote_block_len = nixl_agent_meta.block_lens[0]
         if self.use_mla or self.kv_topo.is_kv_replicated(remote_engine_id):
             # With replicated KV cache, only the number of blocks can differ.
-            assert self.block_len_per_layer == nixl_agent_meta.block_lens, (
-                "KV cache sizes must match between P and D when replicated"
-            )
-            remote_block_size = remote_block_len // (self.slot_size_per_layer[0])
+            for i in range(len(self.block_len_per_layer)):
+                assert (
+                    self.block_len_per_layer[i] // block_size_ratio
+                    == nixl_agent_meta.block_lens[i]
+                ), "KV cache sizes must match between P and D when replicated"
         else:
             # When MLA is not used, this is a list of the same block length
             for block_len in nixl_agent_meta.block_lens:
                 assert block_len == remote_block_len, (
                     "All remote layers must have the same block size"
                 )
-            remote_block_size = remote_block_len // (
-                self.slot_size_per_layer[0] * tp_ratio
-            )
-            if self.kv_topo.is_kv_layout_blocks_first:
-                # With flashinfer, KV are sent in the same message.
-                remote_block_size //= 2

-            assert remote_block_len == self.block_len_per_layer[0] * tp_ratio, (
+            assert (
+                remote_block_len
+                == (self.block_len_per_layer[0] * tp_ratio) // block_size_ratio
+            ), (
                 "Remote P worker KV layer cache must be of shape [2, N, "
                 "local_kv_heads*tp_ratio, block_size, head_dim] and same dtype."
             )

-        assert self.block_size == remote_block_size, (
-            "Remote P worker with different page/block size is not supported "
-            f"{self.block_size=}, {remote_block_size=}"
-        )
-
         # TP workers have same #blocks.
         assert self.dst_num_blocks[remote_engine_id] == nixl_agent_meta.num_blocks
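To see why the reworked length check divides by block_size_ratio, here is the arithmetic with hypothetical sizes (slot sizes in bytes; none of these numbers come from the diff):

    # D runs TP=2 with block_size 16; P runs TP=1 with block_size 4.
    local_slot_size = 512        # bytes per token slot per layer on this D worker
    tp_ratio = 2                 # each P worker holds tp_ratio x the kv heads
    block_size_ratio = 16 // 4   # local / remote block size

    local_block_len = local_slot_size * 16               # 8192
    remote_block_len = (local_slot_size * tp_ratio) * 4  # 4096

    assert remote_block_len == (local_block_len * tp_ratio) // block_size_ratio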
@@ -1576,6 +1647,56 @@ class NixlConnectorWorker:
             )
             cache.index_copy_(0, indices, permuted_blocks)

+    def blocksize_post_process(self, block_ids_per_ratio: dict[float, list[list[int]]]):
+        def _process_local_gt_remote(blocks_to_update, block_size_ratio):
+            n_kv_heads, block_size, head_size = blocks_to_update.shape[1:]
+            remote_block_size = block_size // block_size_ratio
+            n_blocks = block_size_ratio
+            # The permute converts the remote layout to the local layout when
+            # the local block size is larger than the remote one.
+            # E.g. local block_size = 16 tokens, remote block_size = 4 tokens:
+            # local block[0] = remote blocks[0, 1, 2, 3]
+            # remote is |h0-b0|h1-b0|h2-b0|h3-b0|h0-b1|h1-b1|h2-b1|h3-b1|...
+            # local  is |h0-b0..................|h1-b0..................|...
+            # Steps:
+            # 1. view    => view remote as n_blocks * remote shape (H, remoteN, D)
+            # 2. permute => (H, n_blocks, remoteN, D)
+            # 3. flatten => (H, localN, D)
+            permuted_blocks = (
+                blocks_to_update.reshape(
+                    -1, n_blocks, n_kv_heads, remote_block_size, head_size
+                )
+                .permute(0, 2, 1, 3, 4)
+                .flatten(2, 3)
+            )
+            return permuted_blocks
+
+        if len(self.device_kv_caches) == 0:
+            return
+        split_k_and_v = not (
+            self.use_mla or self._use_pallas or self.kv_topo.is_kv_layout_blocks_first
+        )
+        sample_cache = list(self.device_kv_caches.values())[0][0]
+        for block_size_ratio, block_ids_list in block_ids_per_ratio.items():
+            assert block_size_ratio > 1, "Only nP < nD is supported currently."
+            block_ids_list = [[item for sublist in block_ids_list for item in sublist]]
+
+            for block_ids in block_ids_list:
+                indices = torch.tensor(block_ids, device=sample_cache.device)
+
+                for _, cache_or_caches in self.device_kv_caches.items():
+                    cache_list = cache_or_caches if split_k_and_v else [cache_or_caches]
+                    for cache in cache_list:
+                        blocks_to_update = cache.index_select(0, indices)
+                        # The kv_cache always uses the original NHD layout as
+                        # its virtual shape, while the stride can be either
+                        # HND or NHD at initialization, so we first take the
+                        # physical (HND) view of the tensor.
+                        permuted_blocks = _process_local_gt_remote(
+                            blocks_to_update.permute(0, 2, 1, 3), block_size_ratio
+                        ).permute(0, 2, 1, 3)
+                        cache.index_copy_(0, indices, permuted_blocks)
+
     def get_finished(self) -> tuple[set[str], set[str]]:
         """
         Get requests that are done sending or recving on this specific worker.
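The core of _process_local_gt_remote is a view/permute/flatten that regroups per-remote-block head data into per-head local-block data. A tiny standalone sketch (hypothetical sizes, plain torch, no connector state):

    import torch

    # Hypothetical sizes: 1 local block of 4 tokens received as 2 remote blocks
    # of 2 tokens each; 2 kv heads; head_size 1; HND physical layout.
    B, H, local_N, D, ratio = 1, 2, 4, 1, 2
    remote_N = local_N // ratio

    # Received bytes arrive remote-block-major: |h0-b0|h1-b0|h0-b1|h1-b1|
    recv = torch.arange(B * H * local_N * D, dtype=torch.float32)

    # 1. view as (B, ratio, H, remote_N, D)
    # 2. permute to (B, H, ratio, remote_N, D)
    # 3. flatten to (B, H, local_N, D) so each head's tokens become contiguous
    fixed = (
        recv.reshape(B, ratio, H, remote_N, D)
        .permute(0, 2, 1, 3, 4)
        .flatten(2, 3)
    )
    print(fixed.reshape(B, H, local_N))
    # tensor([[[0., 1., 4., 5.],
    #          [2., 3., 6., 7.]]])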
@@ -1599,6 +1720,7 @@ class NixlConnectorWorker:
         )

         block_ids_to_permute = []
+        block_ids_for_blocksize_post_process = defaultdict(list)
         for req_id in done_recving:
             # clean up metadata for completed requests
             meta = self._recving_metadata.pop(req_id, None)
@@ -1607,6 +1729,20 @@ class NixlConnectorWorker:
                 self.sync_recved_kv_to_device(req_id, meta)
             if self.enable_permute_local_kv:
                 block_ids_to_permute += meta.local_physical_block_ids

+            # Post-processing for heterogeneous block_size.
+            block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(
+                meta.remote_engine_id
+            )
+            if (
+                not self.use_mla
+                and block_size_ratio > 1
+                and self.kv_cache_layout == "HND"
+            ):
+                block_ids_for_blocksize_post_process[block_size_ratio].append(
+                    meta.local_block_ids
+                )
+        self.blocksize_post_process(block_ids_for_blocksize_post_process)
         if len(block_ids_to_permute) > 0:
             self.permute_device_kv(block_ids_to_permute)
@@ -1781,6 +1917,24 @@ class NixlConnectorWorker:
         dst_engine_id: str,
         request_id: str,
     ):
+        block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(dst_engine_id)
+        if block_size_ratio > 1:
+            local_block_ids = self.get_mapped_blocks(
+                np.asarray(local_block_ids), block_size_ratio
+            )
+            if len(local_block_ids) > len(remote_block_ids):
+                # NOTE:
+                # get_mapped_blocks always expands the block_ids n times, e.g.
+                # prefill block_ids with block_size 4:
+                #   [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
+                # local decode block_ids with block_size 16: [1, 2, 3]
+                # get_mapped_blocks expands the decode block_ids from [1, 2, 3]
+                # to [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
+                # Then we clip the local ids to align with prefill:
+                #   [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12] ->
+                #   [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
+                local_block_ids = local_block_ids[: len(remote_block_ids)]
         # NOTE(rob): having the staging blocks be on the READER side is
         # not going to work well (since we will have to call rearrange tensors).
         # after we detect the txn is complete (which means we cannot make the
@@ -1823,7 +1977,10 @@ class NixlConnectorWorker:
             remote_block_ids = remote_block_ids[-num_local_blocks:]

         # Get side handles.
-        local_xfer_side_handle = self.src_xfer_side_handle
+        remote_block_size = self.kv_topo.remote_block_size[dst_engine_id]
+        local_xfer_side_handle = self.src_xfer_side_handles.get(
+            remote_block_size, self.src_xfer_side_handle
+        )
         remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id]

         # NOTE (nicolo) With homogeneous TP, each TP worker loads KV from
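The lookup falls back to the init-time handle whenever the remote runs the same block size as the local worker. A trivial sketch of that keying (hypothetical handle values, not real NIXL handles):

    # Handles keyed by the block_size they were registered with.
    src_xfer_side_handles = {16: 101}  # registered at init for local bs=16
    src_xfer_side_handles[4] = 202     # added when a bs=4 prefiller handshakes

    default_handle = src_xfer_side_handles[16]
    for remote_bs in (16, 4, 8):
        handle = src_xfer_side_handles.get(remote_bs, default_handle)
        print(remote_bs, handle)  # 16 -> 101, 4 -> 202, 8 -> 101 (fallback)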
@@ -1833,13 +1990,17 @@ class NixlConnectorWorker:
         # Get descs ids.
         local_block_descs_ids: np.ndarray
         remote_block_descs_ids: np.ndarray

         if not self.block_window_per_layer:
             # Default case: assume global attention
             remote_block_descs_ids = self._get_block_descs_ids(
-                dst_engine_id, remote_block_ids
+                dst_engine_id,
+                remote_block_ids,
             )
             local_block_descs_ids = self._get_block_descs_ids(
-                self.engine_id, local_block_ids
+                self.engine_id,
+                local_block_ids,
+                block_size_ratio=block_size_ratio,
             )
         else:
             # TODO(mgoin): remove this once we have hybrid memory allocator
@@ -1860,10 +2021,15 @@ class NixlConnectorWorker:

                 # Get descs ids for the layer.
-                layer_local_desc_ids = self._get_block_descs_ids(
-                    self.engine_id, layer_local_block_ids, layer_idx
-                )
-                layer_remote_desc_ids = self._get_block_descs_ids(
-                    dst_engine_id, layer_remote_block_ids, layer_idx
-                )
+                layer_local_desc_ids = self._get_block_descs_ids(
+                    self.engine_id,
+                    layer_local_block_ids,
+                    layer_idx,
+                    block_size_ratio=block_size_ratio,
+                )
+                layer_remote_desc_ids = self._get_block_descs_ids(
+                    dst_engine_id,
+                    layer_remote_block_ids,
+                    layer_idx,
+                )

                 local_descs_list.append(layer_local_desc_ids)
@@ -1905,8 +2071,31 @@ class NixlConnectorWorker:
             self.nixl_wrapper.release_xfer_handle(handle)
             self._failed_recv_reqs.add(request_id)

+    def get_mapped_blocks(self, block_ids, block_size_ratio):
+        """
+        Calculate the new set of block IDs by mapping every element
+        in the (potentially sparse) input array.
+
+        Example: block_ids=[0, 2], block_size_ratio=2
+          local block ids:  0     [1]     2
+          mapped block ids: 0 1   [2 3]   4 5
+        so block_ids=[0, 2] maps to [0, 1, 4, 5].
+        """
+        if block_ids.size == 0:
+            return np.array([], dtype=np.int64)
+
+        start_ids = block_ids * block_size_ratio
+        offsets = np.arange(block_size_ratio)
+        mapped_2d = start_ids[:, None] + offsets[None, :]
+
+        return mapped_2d.flatten().astype(np.int64)
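For a concrete check of the mapping, a standalone numpy reproduction with the docstring's values:

    import numpy as np

    def get_mapped_blocks(block_ids: np.ndarray, block_size_ratio: int) -> np.ndarray:
        # Same arithmetic as the connector method, reproduced standalone.
        if block_ids.size == 0:
            return np.array([], dtype=np.int64)
        start_ids = block_ids * block_size_ratio
        offsets = np.arange(block_size_ratio)
        return (start_ids[:, None] + offsets[None, :]).flatten().astype(np.int64)

    print(get_mapped_blocks(np.array([0, 2]), 2))  # [0 1 4 5]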
     def _get_block_descs_ids(
-        self, engine_id: str, block_ids: list[int], layer_idx: int | None = None
+        self,
+        engine_id: str,
+        block_ids: list[int],
+        layer_idx: int | None = None,
+        block_size_ratio: float | None = None,
     ) -> np.ndarray:
         """
         Get the descs ids for a set of block ids.
@@ -1929,6 +2118,8 @@ class NixlConnectorWorker:
             region_ids = np.arange(layer_idx, layer_idx + 1)

         num_blocks = self.dst_num_blocks[engine_id]
+        if block_size_ratio is not None:
+            num_blocks = int(num_blocks * block_size_ratio)

         # Compute the desc ids for each block.
         region_ids = region_ids[:, None]