[NIXL] heterogeneous block_size support (#26759)

Signed-off-by: Chendi Xue <chendi.xue@intel.com>
Signed-off-by: Chendi.Xue <chendi.xue@intel.com>
Co-authored-by: Nicolò Lucchesi <nicolo.lucchesi@gmail.com>
Author: Chendi.Xue, 2025-11-14 23:51:32 -06:00 (committed by GitHub)
Commit: c9e665852a (parent: 363aaeef0f)
3 changed files with 257 additions and 59 deletions


@ -49,6 +49,8 @@ NUM_DECODE_INSTANCES=${NUM_DECODE_INSTANCES:-1} # Default to 1
PREFILLER_TP_SIZE=${PREFILLER_TP_SIZE:-1}
DECODER_TP_SIZE=${DECODER_TP_SIZE:-1}
GPU_MEMORY_UTILIZATION=${GPU_MEMORY_UTILIZATION:-0.2}
PREFILL_BLOCK_SIZE=${PREFILL_BLOCK_SIZE:-16}
DECODE_BLOCK_SIZE=${DECODE_BLOCK_SIZE:-16}
# Find the git repository root directory
GIT_ROOT=$(git rev-parse --show-toplevel)
@ -136,6 +138,7 @@ run_tests_for_model() {
vllm serve $model_name \
--port $PORT \
--enforce-eager \
--block-size ${PREFILL_BLOCK_SIZE} \
--gpu-memory-utilization $GPU_MEMORY_UTILIZATION \
--tensor-parallel-size $PREFILLER_TP_SIZE \
--kv-transfer-config '$KV_CONFIG'"
@ -177,6 +180,7 @@ run_tests_for_model() {
vllm serve $model_name \
--port $PORT \
--enforce-eager \
--block-size ${DECODE_BLOCK_SIZE} \
--gpu-memory-utilization $GPU_MEMORY_UTILIZATION \
--kv-transfer-config '$KV_CONFIG'"


@ -407,6 +407,7 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
# `self.kv_cache_layout` is only forced to HND when vllm engine
# is started. We mock HND here.
kv_cache_layout="HND",
block_size=self.block_size,
),
remote_tp_size=remote_tp_size,
)
@ -652,6 +653,7 @@ class TestNixlHandshake:
block_lens=worker.block_len_per_layer,
attn_backend_name=worker.backend_name,
kv_cache_layout=mismatched_layout,
block_size=worker.block_size,
)
with pytest.raises(RuntimeError):
@ -706,6 +708,7 @@ class TestNixlHandshake:
block_lens=[i * 2 for i in worker.block_len_per_layer],
attn_backend_name=worker.backend_name,
kv_cache_layout="HND",
block_size=worker.block_size,
)
# We don't check layout for homogeneous TP and MLA for now, as the


@ -108,6 +108,7 @@ class NixlAgentMetadata(KVConnectorHandshakeMetadata):
block_lens: list[int]
attn_backend_name: str
kv_cache_layout: str
block_size: int
@dataclass
@ -709,6 +710,9 @@ class NixlConnectorWorker:
self.is_mla or self._use_pallas or self.is_kv_layout_blocks_first
)
block_size: int
remote_block_size: dict[EngineId, int]
def tp_ratio(
self,
remote_tp_size: int,
@ -725,6 +729,19 @@ class NixlConnectorWorker:
)
return self.tp_size // remote_tp_size
def block_size_ratio(
self,
remote_block_size: int,
) -> float:
"""
Calculate the ratio between the local and remote block sizes.
"""
assert self.block_size % remote_block_size == 0, (
f"Local block size {self.block_size} must be an integer "
f"multiple of remote block size {remote_block_size}."
)
return self.block_size // remote_block_size
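To make the contract concrete, here is a minimal standalone sketch (illustration only, not part of the diff) of the divisibility rule that block_size_ratio enforces; `_ratio` is a hypothetical stand-in for the method:

# Illustration only: the divisibility contract enforced by block_size_ratio.
def _ratio(local_block_size: int, remote_block_size: int) -> int:
    assert local_block_size % remote_block_size == 0
    return local_block_size // remote_block_size

assert _ratio(16, 4) == 4    # decode uses 16-token blocks, prefill 4-token blocks
assert _ratio(16, 16) == 1   # homogeneous block sizes
# _ratio(16, 5) would raise: the remote block size must evenly divide the local one.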
def tp_ratio_from_engine_id(
self,
remote_engine_id: EngineId,
@ -732,6 +749,13 @@ class NixlConnectorWorker:
remote_tp_size = self.remote_tp_size[remote_engine_id]
return self.tp_ratio(remote_tp_size)
def block_size_ratio_from_engine_id(
self,
remote_engine_id: EngineId,
) -> float:
remote_block_size = self.remote_block_size[remote_engine_id]
return self.block_size_ratio(remote_block_size)
def is_kv_replicated(self, engine_id: EngineId) -> bool:
"""
Whether the KV cache is replicated across TP workers due to the
@ -866,6 +890,7 @@ class NixlConnectorWorker:
# nixl_prepped_dlist_handle.
self.src_xfer_side_handle: int = 0
self.src_xfer_side_handles: dict[int, int] = {}
# Map of engine_id -> nixl_prepped_dlist_handle (int).
self.dst_xfer_side_handles: dict[EngineId, int] = {}
@ -925,6 +950,7 @@ class NixlConnectorWorker:
logger.debug("Detected kv cache layout %s", self.kv_cache_layout)
self._tp_size: dict[EngineId, int] = {self.engine_id: self.world_size}
self._block_size: dict[EngineId, int] = {self.engine_id: self.block_size}
# With heterogeneous TP, P must wait for all assigned D TP workers to
# finish reading before safely freeing the blocks.
self.consumer_notification_counts_by_req = defaultdict[ReqId, int](int)
@ -936,6 +962,8 @@ class NixlConnectorWorker:
remote_tp_size=self._tp_size, # shared state
is_mla=self.use_mla,
total_num_kv_heads=self.model_config.get_total_num_kv_heads(),
block_size=self.block_size,
remote_block_size=self._block_size,
attn_backend=backend,
)
self._use_pallas = self.kv_topo._use_pallas
@ -987,9 +1015,13 @@ class NixlConnectorWorker:
)
# Register Remote agent.
assert metadata.block_size <= self.block_size, (
"A prefill block_size larger than the decode block_size "
"(nP > nD) is not supported yet."
)
remote_agent_name = self.add_remote_agent(
metadata, p_remote_rank, remote_tp_size
)
setup_agent_time = time.perf_counter()
logger.debug(
"NIXL handshake: add agent took: %s",
@ -1217,43 +1249,10 @@ class NixlConnectorWorker:
self.num_regions *= 2
# Register local/src descr for NIXL xfer.
blocks_data = []
for i, base_addr in enumerate(seen_base_addresses):
kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=i)
# NOTE With heter-TP, more blocks are prepared than what are
# needed as self.num_blocks >= nixl_agent_meta.num_blocks. We
# could create fewer, but then _get_block_descs_ids needs to
# select agent_meta.num_blocks instead of self.num_blocks for
# local descr, and that makes handling regular flow less clean.
for block_id in range(self.num_blocks):
block_offset = block_id * self.block_len_per_layer[i]
addr = base_addr + block_offset
# (addr, len, device id)
blocks_data.append((addr, kv_block_len, self.device_id))
self.seen_base_addresses = seen_base_addresses
self.src_xfer_side_handle = self.register_local_xfer_handler(self.block_size)
if self.kv_topo.is_kv_layout_blocks_first:
# Separate and interleave K/V regions to maintain the same
# descs ordering. This is needed for selecting contiguous heads
# when split across TP ranks.
for block_id in range(self.num_blocks):
block_offset = block_id * self.block_len_per_layer[i]
addr = base_addr + block_offset
# Register addresses for V cache (K registered first).
v_addr = addr + kv_block_len
blocks_data.append((v_addr, kv_block_len, self.device_id))
logger.debug(
"Created %s blocks for src engine %s and rank %s on device id %s",
len(blocks_data),
self.engine_id,
self.tp_rank,
self.device_id,
)
descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type)
# NIXL_INIT_AGENT to be used for preparations of local descs.
self.src_xfer_side_handle = self.nixl_wrapper.prep_xfer_dlist(
"NIXL_INIT_AGENT", descs
)
self.src_xfer_side_handles[self.block_size] = self.src_xfer_side_handle
# TODO(mgoin): Hybrid memory allocator is currently disabled for
# models with local attention (Llama 4). Can remove this once enabled.
@ -1289,8 +1288,62 @@ class NixlConnectorWorker:
kv_cache_layout=self.kv_cache_layout
if not self.use_host_buffer
else self.host_buffer_kv_cache_layout,
block_size=self.block_size,
)
def register_local_xfer_handler(
self,
block_size: int,
) -> int:
"""
Register a local xfer handler using either the local block_size or a
remote block_size.
When the local and remote block sizes match, the handler registered at
init with the local block_size is used.
When the remote block size is smaller than the local one, another
local xfer handler must be registered with the remote block length to
guarantee correct data copies.
"""
block_size_ratio = self.block_size // block_size
blocks_data = []
for i, base_addr in enumerate(self.seen_base_addresses):
# The transfer unit is the remote (prefill) block_len, so the number
# of local blocks is multiplied by the block-size ratio.
kv_block_len = (
self.get_backend_aware_kv_block_len(layer_idx=i) // block_size_ratio
)
block_len_per_layer = self.block_len_per_layer[i] // block_size_ratio
num_blocks = self.num_blocks * block_size_ratio
for block_id in range(num_blocks):
block_offset = block_id * block_len_per_layer
addr = base_addr + block_offset
# (addr, len, device id)
blocks_data.append((addr, kv_block_len, self.device_id))
if self.kv_topo.is_kv_layout_blocks_first:
# Separate and interleave K/V regions to maintain the same
# descs ordering. This is needed for selecting contiguous heads
# when split across TP ranks.
for block_id in range(num_blocks):
block_offset = block_id * block_len_per_layer
addr = base_addr + block_offset
# Register addresses for V cache (K registered first).
v_addr = addr + kv_block_len
blocks_data.append((v_addr, kv_block_len, self.device_id))
logger.debug(
"Created %s blocks for src engine %s and rank %s on device id %s",
len(blocks_data),
self.engine_id,
self.tp_rank,
self.device_id,
)
descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type)
# NIXL_INIT_AGENT to be used for preparations of local descs.
return self.nixl_wrapper.prep_xfer_dlist("NIXL_INIT_AGENT", descs)
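For intuition, a toy sketch (illustration only, not part of the diff; the byte counts are assumed, not taken from the code) of the descriptor list this registration produces for one layer: each local block is split into `ratio` remote-sized sub-regions, so the per-layer descriptor count grows by the same factor.

# Illustration only: descriptors for one layer when re-registering with a
# smaller (remote) block size. All byte counts are toy values.
base_addr = 0x1000
num_blocks = 2              # local blocks registered at init
block_len = 4096            # bytes per local block for this layer
kv_block_len = 2048         # K (or V) portion of a local block
ratio = 4                   # local block_size // remote block_size

sub_block_len = block_len // ratio      # 1024
sub_kv_len = kv_block_len // ratio      # 512
descs = [
    (base_addr + sub_id * sub_block_len, sub_kv_len)
    for sub_id in range(num_blocks * ratio)
]
assert len(descs) == num_blocks * ratio          # 8 descriptors instead of 2
assert descs[1] == (0x1000 + 1024, 512)          # each covers a remote-sized chunk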
def add_remote_agent(
self,
nixl_agent_meta: NixlAgentMetadata,
@ -1349,6 +1402,8 @@ class NixlConnectorWorker:
### Register remote agent metadata
if engine_id not in self._tp_size:
self._tp_size[engine_id] = remote_tp_size
if engine_id not in self._block_size:
self._block_size[engine_id] = nixl_agent_meta.block_size
remote_agent_name = self.nixl_wrapper.add_remote_agent(
nixl_agent_meta.agent_metadata
@ -1359,6 +1414,13 @@ class NixlConnectorWorker:
# Create dst descs and xfer side handles. TP workers have same #blocks
# so we only register once per engine_id.
# Example:
# block_size_ratio > 1 (here 4, each local block spans 4 remote blocks):
# remote:       | 0| 1| 2| 3| 4| 5| 6| 7| 8| 9|10|11|12|
# local origin: | 0| 4| 8|12|
# local mapped: | 0| 1| 2| 3| 4| 5| 6| 7| 8| 9|10|11|12|13|14|15|
block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(engine_id)
if engine_id not in self.dst_num_blocks:
self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks
@ -1381,8 +1443,14 @@ class NixlConnectorWorker:
# Register all remote blocks, but only the corresponding kv heads.
for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr):
kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=i)
remote_kv_block_len = kv_block_len // block_size_ratio
if block_size_ratio > 1:
# using remote kv_block_len as transfer unit
kv_block_len = remote_kv_block_len
rank_offset = (
self.tp_rank % tp_ratio * kv_block_len if not replicates_kv_cache else 0
self.tp_rank % tp_ratio * remote_kv_block_len
if not replicates_kv_cache
else 0
)
for block_id in range(nixl_agent_meta.num_blocks):
block_offset = block_id * nixl_agent_meta.block_lens[i]
@ -1417,6 +1485,13 @@ class NixlConnectorWorker:
remote_agent_name, descs
)
if block_size_ratio > 1:
# When the prefill side uses a smaller block_size, register an
# additional local handler with a matching block_len.
self.src_xfer_side_handles[nixl_agent_meta.block_size] = (
self.register_local_xfer_handler(nixl_agent_meta.block_size)
)
return remote_agent_name
def _validate_remote_agent_handshake(
@ -1433,6 +1508,9 @@ class NixlConnectorWorker:
assert nixl_agent_meta.attn_backend_name == self.backend_name
tp_ratio = self.kv_topo.tp_ratio_from_engine_id(remote_engine_id)
block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(
remote_engine_id
)
assert tp_ratio > 0, "Decode TP cannot be smaller than prefill TP"
assert not self._use_pallas or tp_ratio == 1, (
"TPU (pallas_v1) DOES NOT support heterogeneous TP yet."
@ -1463,33 +1541,26 @@ class NixlConnectorWorker:
remote_block_len = nixl_agent_meta.block_lens[0]
if self.use_mla or self.kv_topo.is_kv_replicated(remote_engine_id):
# With replicated KV cache, only the number of blocks can differ.
assert self.block_len_per_layer == nixl_agent_meta.block_lens, (
"KV cache sizes must match between P and D when replicated"
)
remote_block_size = remote_block_len // (self.slot_size_per_layer[0])
for i in range(len(self.block_len_per_layer)):
assert (
self.block_len_per_layer[i] // block_size_ratio
== nixl_agent_meta.block_lens[i]
), "KV cache sizes must match between P and D when replicated"
else:
# When MLA is not used, this is a list of the same block length
for block_len in nixl_agent_meta.block_lens:
assert block_len == remote_block_len, (
"All remote layers must have the same block size"
)
remote_block_size = remote_block_len // (
self.slot_size_per_layer[0] * tp_ratio
)
if self.kv_topo.is_kv_layout_blocks_first:
# With flashinfer, KV are sent in the same message.
remote_block_size //= 2
assert remote_block_len == self.block_len_per_layer[0] * tp_ratio, (
assert (
remote_block_len
== (self.block_len_per_layer[0] * tp_ratio) // block_size_ratio
), (
"Remote P worker KV layer cache must be of shape [2, N, "
"local_kv_heads*tp_ratio, block_size, head_dim] and same dtype."
)
assert self.block_size == remote_block_size, (
"Remote P worker with different page/block size is not supported "
f"{self.block_size=}, {remote_block_size=}"
)
# TP workers have same #blocks.
assert self.dst_num_blocks[remote_engine_id] == nixl_agent_meta.num_blocks
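To see the block-length relation checked above in numbers, a worked example with assumed toy dimensions (fp16 KV cache, head_dim 128; none of these values come from the diff):

# Illustration only: the remote/local block_len relation on toy numbers.
local_kv_heads = 4                      # KV heads held by this decode TP rank
tp_ratio = 2                            # decode TP is 2x prefill TP
local_block_size, remote_block_size = 16, 4
block_size_ratio = local_block_size // remote_block_size   # 4

bytes_per_token_per_head = 128 * 2      # head_dim 128, fp16
local_block_len = local_kv_heads * local_block_size * bytes_per_token_per_head
remote_block_len = (
    local_kv_heads * tp_ratio * remote_block_size * bytes_per_token_per_head
)
# The prefill block is smaller in tokens but wider in heads:
assert remote_block_len == (local_block_len * tp_ratio) // block_size_ratio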
@ -1576,6 +1647,56 @@ class NixlConnectorWorker:
)
cache.index_copy_(0, indices, permuted_blocks)
def blocksize_post_process(self, block_ids_per_ratio: dict[float, list[list[int]]]):
def _process_local_gt_remote(blocks_to_update, block_size_ratio):
n_kv_heads, block_size, head_size = blocks_to_update.shape[1:]
remote_block_size = block_size // block_size_ratio
n_blocks = block_size_ratio
# The permute converts data written with the remote block size into the
# local layout when local blocksize > remote blocksize.
# ex: local blocksize = 16 tokens, remote blocksize = 4 tokens
# => local block[0] holds remote blocks [0, 1, 2, 3]
# remote order: |h0-b0|h1-b0|h2-b0|h3-b0|h0-b1|h1-b1|h2-b1|h3-b1|...
# local order:  |h0-b0..................|h1-b0..................|...
# The permute is:
# 1. view    => view the buffer as n_blocks * remote shape (H, remoteN, D)
# 2. permute => (H, n_blocks, remoteN, D)
# 3. flatten => (H, localN, D)
permuted_blocks = (
blocks_to_update.reshape(
-1, n_blocks, n_kv_heads, remote_block_size, head_size
)
.permute(0, 2, 1, 3, 4)
.flatten(2, 3)
)
return permuted_blocks
if len(self.device_kv_caches) == 0:
return
split_k_and_v = not (
self.use_mla or self._use_pallas or self.kv_topo.is_kv_layout_blocks_first
)
sample_cache = list(self.device_kv_caches.values())[0][0]
for block_size_ratio, block_ids_list in block_ids_per_ratio.items():
assert block_size_ratio > 1, "Only nP < nD supported currently."
block_ids_list = [[item for sublist in block_ids_list for item in sublist]]
for block_ids in block_ids_list:
indices = torch.tensor(block_ids, device=sample_cache.device)
for _, cache_or_caches in self.device_kv_caches.items():
cache_list = cache_or_caches if split_k_and_v else [cache_or_caches]
for cache in cache_list:
blocks_to_update = cache.index_select(0, indices)
# The kv_cache tensor always exposes the logical NHD shape, while its
# strides may correspond to either HND or NHD depending on
# initialization. Permute first to obtain the physical view of the
# tensor before regrouping blocks.
permuted_blocks = _process_local_gt_remote(
blocks_to_update.permute(0, 2, 1, 3), block_size_ratio
).permute(0, 2, 1, 3)
cache.index_copy_(0, indices, permuted_blocks)
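A self-contained torch sketch (illustration only, not part of the diff) of the reshape/permute/flatten done in `_process_local_gt_remote`, on toy sizes, checking that tokens from each remote sub-block land at the expected positions inside the local block:

import torch

# One local 8-token block receives block_size_ratio=2 remote 4-token blocks,
# each laid out as (H, remote_N, D) in HND order.
H, D = 2, 3
remote_n, ratio = 4, 2
local_n = remote_n * ratio  # 8

remote_blocks = torch.arange(ratio * H * remote_n * D).reshape(ratio, H, remote_n, D)
# Raw local buffer (batch of 1 selected block), then the fix-up permute.
raw = remote_blocks.reshape(1, H * local_n * D)
fixed = (
    raw.reshape(-1, ratio, H, remote_n, D)
    .permute(0, 2, 1, 3, 4)   # (B, H, ratio, remote_N, D)
    .flatten(2, 3)            # (B, H, local_N, D)
)
# Token t of remote block b must land at local token b * remote_N + t, per head.
assert torch.equal(fixed[0, 1, remote_n:, :], remote_blocks[1, 1])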
def get_finished(self) -> tuple[set[str], set[str]]:
"""
Get requests that are done sending or recving on this specific worker.
@ -1599,6 +1720,7 @@ class NixlConnectorWorker:
)
block_ids_to_permute = []
block_ids_for_blocksize_post_process = defaultdict(list)
for req_id in done_recving:
# clean up metadata for completed requests
meta = self._recving_metadata.pop(req_id, None)
@ -1607,6 +1729,20 @@ class NixlConnectorWorker:
self.sync_recved_kv_to_device(req_id, meta)
if self.enable_permute_local_kv:
block_ids_to_permute += meta.local_physical_block_ids
# Post-processing for heterogeneous block sizes.
block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(
meta.remote_engine_id
)
if (
not self.use_mla
and block_size_ratio > 1
and self.kv_cache_layout == "HND"
):
block_ids_for_blocksize_post_process[block_size_ratio].append(
meta.local_block_ids
)
self.blocksize_post_process(block_ids_for_blocksize_post_process)
if len(block_ids_to_permute) > 0:
self.permute_device_kv(block_ids_to_permute)
@ -1781,6 +1917,24 @@ class NixlConnectorWorker:
dst_engine_id: str,
request_id: str,
):
block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(dst_engine_id)
if block_size_ratio > 1:
local_block_ids = self.get_mapped_blocks(
np.asarray(local_block_ids), block_size_ratio
)
if len(local_block_ids) > len(remote_block_ids):
# NOTE:
# get_mapped_blocks always expands the local block_ids n times
# (n = block_size_ratio).
# ex: prefill block_size 4, decode block_size 16 (ratio 4):
# the prefill side sends 10 block ids, while the 3 local decode
# block ids expand to 3 * 4 = 12 remote-sized sub-block ids.
# We then clip the local list to the first 10 entries so that it
# aligns with the prefill block ids.
local_block_ids = local_block_ids[: len(remote_block_ids)]
# NOTE(rob): having the staging blocks be on the READER side is
# not going to work well (since we will have to call rearrange tensors).
# after we detect the txn is complete (which means we cannot make the
@ -1823,7 +1977,10 @@ class NixlConnectorWorker:
remote_block_ids = remote_block_ids[-num_local_blocks:]
# Get side handles.
local_xfer_side_handle = self.src_xfer_side_handle
remote_block_size = self.kv_topo.remote_block_size[dst_engine_id]
local_xfer_side_handle = self.src_xfer_side_handles.get(
remote_block_size, self.src_xfer_side_handle
)
remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id]
# NOTE (nicolo) With homogeneous TP, each TP worker loads KV from
@ -1833,13 +1990,17 @@ class NixlConnectorWorker:
# Get descs ids.
local_block_descs_ids: np.ndarray
remote_block_descs_ids: np.ndarray
if not self.block_window_per_layer:
# Default case: assume global attention
remote_block_descs_ids = self._get_block_descs_ids(
dst_engine_id, remote_block_ids
dst_engine_id,
remote_block_ids,
)
local_block_descs_ids = self._get_block_descs_ids(
self.engine_id, local_block_ids
self.engine_id,
local_block_ids,
block_size_ratio=block_size_ratio,
)
else:
# TODO(mgoin): remove this once we have hybrid memory allocator
@ -1860,10 +2021,15 @@ class NixlConnectorWorker:
# Get descs ids for the layer.
layer_local_desc_ids = self._get_block_descs_ids(
self.engine_id, layer_local_block_ids, layer_idx
dst_engine_id,
layer_local_block_ids,
layer_idx,
)
layer_remote_desc_ids = self._get_block_descs_ids(
dst_engine_id, layer_remote_block_ids, layer_idx
self.engine_id,
layer_remote_block_ids,
layer_idx,
block_size_ratio=block_size_ratio,
)
local_descs_list.append(layer_local_desc_ids)
@ -1905,8 +2071,31 @@ class NixlConnectorWorker:
self.nixl_wrapper.release_xfer_handle(handle)
self._failed_recv_reqs.add(request_id)
def get_mapped_blocks(self, block_ids, block_size_ratio):
"""
Calculate the new set of block IDs by expanding every element of the
(potentially sparse) input array into `block_size_ratio` remote-sized
sub-block IDs.

Example: block_ids=[0, 2], block_size_ratio=2 -> [0, 1, 4, 5]
(local block 0 covers sub-blocks 0-1; local block 2 covers sub-blocks 4-5)
"""
if block_ids.size == 0:
return np.array([], dtype=np.int64)
start_ids = block_ids * block_size_ratio
offsets = np.arange(block_size_ratio)
mapped_2d = start_ids[:, None] + offsets[None, :]
return mapped_2d.flatten().astype(np.int64)
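A standalone numpy sketch (illustration only, not part of the diff) of this expansion, together with the clipping applied in the read path when the expansion overshoots the remote block list:

import numpy as np

block_ids = np.asarray([1, 2, 3])        # local (decode) block ids
ratio = 4                                # local block_size // remote block_size
mapped = (block_ids[:, None] * ratio + np.arange(ratio)[None, :]).flatten()
assert mapped.tolist() == [4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
num_remote_blocks = 10                   # prefill side sent 10 block ids
clipped = mapped[:num_remote_blocks]     # align lengths before building descs
assert clipped.tolist() == [4, 5, 6, 7, 8, 9, 10, 11, 12, 13]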
def _get_block_descs_ids(
self, engine_id: str, block_ids: list[int], layer_idx: int | None = None
self,
engine_id: str,
block_ids: list[int],
layer_idx: int | None = None,
block_size_ratio: float | None = None,
) -> np.ndarray:
"""
Get the descs ids for a set of block ids.
@ -1929,6 +2118,8 @@ class NixlConnectorWorker:
region_ids = np.arange(layer_idx, layer_idx + 1)
num_blocks = self.dst_num_blocks[engine_id]
if block_size_ratio is not None:
num_blocks = int(num_blocks * block_size_ratio)
# Compute the desc ids for each block.
region_ids = region_ids[:, None]
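Finally, a small numpy sketch (illustration only, assuming the region-major descriptor layout used here) of why `num_blocks` must be scaled by `block_size_ratio` when computing local desc ids:

import numpy as np

num_blocks = 4                         # local blocks per region at init
block_size_ratio = 2                   # local descs were re-registered 2x denser
block_ids = np.asarray([2, 3])         # already expanded via get_mapped_blocks
num_blocks = int(num_blocks * block_size_ratio)   # per-region stride is now 8
region_ids = np.arange(2)[:, None]     # two regions (e.g. two layers)
descs_ids = (region_ids * num_blocks + block_ids[None, :]).flatten()
assert descs_ids.tolist() == [2, 3, 10, 11]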