[Nixl] Heterogeneous TP support FlashInfer (#20189)
Signed-off-by: NickLucche <nlucches@redhat.com>
parent f38035c123 · commit f0c503f66e
@@ -715,7 +715,7 @@ class NixlConnectorWorker:
         # are non-contiguous (it's not locally guaranteed that they will be)
         # Disadvantage is that the encoded NixlAgentMetadata is now larger
         # (roughly 8KB vs 5KB).
-        # Conversely for FlashInfer, K and V are transferred in the same tensor
+        # Conversely for FlashInfer, K and V are registered in the same region
         # to better exploit the memory layout (ie num_blocks is the first dim).
         split_k_and_v = not (self.use_mla or self._use_pallas_v1
                              or self._use_flashinfer)
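Illustration (standalone sketch, boolean values invented; only the expression mirrors the diff): backends whose K and V live in one joint region, MLA, Pallas and FlashInfer, skip the per-tensor K/V split.

# Standalone sketch of the flag computed above.
use_mla, use_pallas_v1, use_flashinfer = False, False, True
split_k_and_v = not (use_mla or use_pallas_v1 or use_flashinfer)
print(split_k_and_v)  # False: K and V stay together in one registered region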
@@ -758,12 +758,21 @@ class NixlConnectorWorker:
         assert tensor_size_bytes % self.num_blocks == 0
         self.block_len = tensor_size_bytes // self.num_blocks
         self.slot_size_bytes = self.block_len // self.block_size
+        self.device_kv_caches = kv_caches
+        self.dst_num_blocks[self.engine_id] = self.num_blocks
         if self._use_flashinfer:
             assert self.slot_size_bytes % 2 == 0
             self.slot_size_bytes /= 2
-        self.device_kv_caches = kv_caches
-        self.dst_num_blocks[self.engine_id] = self.num_blocks
+
+            # NOTE (NickLucche) When FlashInfer is used, memory is registered
+            # with joint KV for each block. This minimizes the overhead in
+            # registerMem allowing faster descs queries. In order to be able to
+            # split on kv_heads dim as required by heterogeneous TP, one must
+            # be able to index K/V separately. Hence we double the number
+            # of 'virtual' regions here and halve `block_len` below.
+            self.num_regions *= 2
 
+        kv_block_len = self.get_backend_aware_kv_block_len()
         # Register local/src descr for NIXL xfer.
         blocks_data = []
         for base_addr in seen_base_addresses:
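Illustration (standalone sketch, all dimensions invented; only the relationships mirror the diff): with a joint-KV layout the per-token slot is halved to describe K or V alone, the registered regions are doubled into K/V "virtual" halves, and get_backend_aware_kv_block_len() returns half a block.

# Sizing arithmetic for a FlashInfer-style joint-KV cache, with made-up numbers.
num_blocks, block_size = 1024, 16                 # cache blocks, tokens per block
num_kv_heads, head_size, elem_size = 8, 128, 2    # fp16

# Joint KV tensor: [num_blocks, 2, block_size, num_kv_heads, head_size]
tensor_size_bytes = num_blocks * 2 * block_size * num_kv_heads * head_size * elem_size
block_len = tensor_size_bytes // num_blocks       # bytes covering K *and* V of a block
slot_size_bytes = block_len // block_size         # bytes of K and V for one token
assert slot_size_bytes % 2 == 0
slot_size_bytes //= 2                             # bytes of K (or V) alone per token

num_regions = 2                                   # e.g. one registered region per layer
num_regions *= 2                                  # split each into K and V "virtual" regions
kv_block_len = block_len // 2                     # what get_backend_aware_kv_block_len() returns
print(block_len, slot_size_bytes, kv_block_len)   # 65536 2048 32768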
@@ -776,8 +785,18 @@ class NixlConnectorWorker:
                 block_offset = block_id * self.block_len
                 addr = base_addr + block_offset
                 # (addr, len, device id)
                 # TODO: does device_id matter to DRAM?
-                blocks_data.append((addr, self.block_len, self.tp_rank))
+                blocks_data.append((addr, kv_block_len, self.tp_rank))
 
+            if self._use_flashinfer:
+                # Separate and interleave K/V regions to maintain the same
+                # descs ordering. This is needed for selecting contiguous heads
+                # when split across TP ranks.
+                for block_id in range(self.num_blocks):
+                    block_offset = block_id * self.block_len
+                    addr = base_addr + block_offset
+                    # Register addresses for V cache (K registered first).
+                    v_addr = addr + kv_block_len
+                    blocks_data.append((v_addr, kv_block_len, self.tp_rank))
         logger.debug("Created %s blocks for src engine %s and rank %s",
                      len(blocks_data), self.engine_id, self.tp_rank)
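Illustration (standalone sketch, addresses and sizes invented): for each registered base address the K descriptors of all blocks are appended first, then the matching V descriptors, so K and V keep the same per-block ordering and V simply starts kv_block_len bytes into each block.

# Descriptor ordering on the local/src side with a joint-KV layout (made-up values).
base_addrs = [0x10000, 0x90000]        # one base address per registered region
num_blocks, block_len = 4, 0x200
kv_block_len = block_len // 2
tp_rank = 0

blocks_data = []
for base_addr in base_addrs:
    for block_id in range(num_blocks):              # K entries first
        addr = base_addr + block_id * block_len
        blocks_data.append((addr, kv_block_len, tp_rank))
    for block_id in range(num_blocks):              # then the matching V entries
        v_addr = base_addr + block_id * block_len + kv_block_len
        blocks_data.append((v_addr, kv_block_len, tp_rank))

assert len(blocks_data) == 2 * len(base_addrs) * num_blocks   # 16 descriptors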
@@ -903,7 +922,7 @@ class NixlConnectorWorker:
             remote_block_size = nixl_agent_meta.block_len // (
                 self.slot_size_bytes * tp_ratio)
             if self._use_flashinfer:
-                # Account for joint KV in FlashInfer.
+                # With flashinfer, KV are sent in the same message.
                 remote_block_size //= 2
             if tp_ratio > 1:
                 # Heterogeneous TP expects same kv_cache_layout.
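Worked example (standalone, sizes invented): the remote block length covers tp_ratio times the local heads, and with a joint-KV layout it covers K and V together, so it must be halved once more to recover the number of tokens per block.

# Remote block-size check for a decode worker at TP=2 pulling from prefill at TP=1.
block_size = 16                     # tokens per KV block, same on both sides
local_slot_size_bytes = 2048        # bytes of K (or V) for one token, local heads only
tp_ratio = 2                        # local TP / remote TP

# Remote block stores K and V jointly for tp_ratio x the local heads.
remote_block_len = 2 * local_slot_size_bytes * tp_ratio * block_size

remote_block_size = remote_block_len // (local_slot_size_bytes * tp_ratio)
use_flashinfer = True
if use_flashinfer:
    # KV travel in the same message, so halve to count tokens rather than K+V pairs.
    remote_block_size //= 2
assert remote_block_size == block_size   # layouts agree: same tokens per block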
@@ -929,10 +948,10 @@ class NixlConnectorWorker:
         # rank. With heterogeneous TP, prepare the descriptors by splitting the
         # P KV cache along kv_head dim, of D worker's kv_head size (D>P).
         # Eg. PTP1 DTP2 => P0 KV:[block0-KV_0 | block0-KV_1..].
         # Only register the remote's descriptors if current rank pulls from it.
         self.kv_caches_base_addr[
             engine_id] = nixl_agent_meta.kv_caches_base_addr
-        rank_offset = self.tp_rank % tp_ratio * self.block_len \
+        kv_block_len = self.get_backend_aware_kv_block_len()
+        rank_offset = self.tp_rank % tp_ratio * kv_block_len \
             if not (self.use_mla or is_kv_replicated) else 0
         # Register all remote blocks, but only the corresponding kv heads.
         for base_addr in nixl_agent_meta.kv_caches_base_addr:
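Worked example (standalone, sizes invented): each decode rank offsets its reads into the remote block by its own share of the heads, now expressed in units of kv_block_len so that the joint-KV case indexes the K half correctly.

# rank_offset for PTP1 DTP2 (tp_ratio = 2): rank 0 takes the first half of the
# heads, rank 1 the second half.
kv_block_len = 0x100                    # local K (or V) bytes per block
tp_ratio = 2
for tp_rank in range(2):
    rank_offset = tp_rank % tp_ratio * kv_block_len
    print(f"decode rank {tp_rank} starts at byte {rank_offset:#x} of each remote block")
# decode rank 0 starts at byte 0x0 of each remote block
# decode rank 1 starts at byte 0x100 of each remote block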
@@ -943,7 +962,16 @@ class NixlConnectorWorker:
                 # self.block_len == remote_block_len//tp_ratio bytes.
                 addr = base_addr + block_offset + rank_offset
                 # (addr, len, device id)
-                blocks_data.append((addr, self.block_len, remote_tp_rank))
+                blocks_data.append((addr, kv_block_len, remote_tp_rank))
+
+            if self._use_flashinfer:
+                # With FlashInfer index V separately to allow head splitting.
+                for block_id in range(nixl_agent_meta.num_blocks):
+                    block_offset = block_id * nixl_agent_meta.block_len
+                    addr = base_addr + block_offset + rank_offset
+                    v_addr = addr + nixl_agent_meta.block_len // 2
+                    blocks_data.append((v_addr, kv_block_len, remote_tp_rank))
+
         logger.debug(
             "Created %s blocks for dst engine %s with remote rank %s and "
             "local rank %s", len(blocks_data), engine_id, remote_tp_rank,
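Illustration (standalone sketch, addresses and sizes invented): on the remote side the V half of every block starts at remote block_len // 2, and each rank applies the same rank_offset inside both the K and the V half, mirroring the K-then-V ordering built on the source side.

# Remote descriptors pulled by decode rank 1 with tp_ratio = 2 (made-up values).
remote_block_len = 0x400                   # K+V bytes of one remote block
remote_num_blocks = 4
base_addr = 0x200000
kv_block_len = remote_block_len // 2 // 2  # local share: half the heads of K (or V)
tp_rank, tp_ratio, remote_tp_rank = 1, 2, 0

rank_offset = tp_rank % tp_ratio * kv_block_len
blocks_data = []
for block_id in range(remote_num_blocks):  # K descriptors
    addr = base_addr + block_id * remote_block_len + rank_offset
    blocks_data.append((addr, kv_block_len, remote_tp_rank))
for block_id in range(remote_num_blocks):  # V descriptors, offset by half a remote block
    addr = base_addr + block_id * remote_block_len + rank_offset
    blocks_data.append((addr + remote_block_len // 2, kv_block_len, remote_tp_rank))

assert len(blocks_data) == 2 * remote_num_blocks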
@@ -1249,6 +1277,22 @@ class NixlConnectorWorker:
             descs_ids.append(reg_id * num_blocks + block_id)
         return descs_ids
 
+    def get_backend_aware_kv_block_len(self):
+        """
+        Get the block length for one K/V element (K and V have the same size).
+
+        For FA and other backends, this is equal to the length of the whole
+        block, as K and V are in separate regions.
+        For FlashInfer, this is half the length of the whole block, as K and V
+        share the same region.
+        """
+        if self._use_flashinfer:
+            # For indexing only half (either just the K or V part).
+            block_len = self.block_len // 2
+        else:
+            block_len = self.block_len
+        return block_len
+
 
 @contextlib.contextmanager
 def zmq_ctx(socket_type: Any, addr: str) -> Iterator[zmq.Socket]:
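Quick contract check (standalone, values invented): whatever the backend, the returned length times the number of per-block descriptors tiles a full block exactly.

block_len = 65536
for use_flashinfer in (True, False):
    kv_block_len = block_len // 2 if use_flashinfer else block_len
    descs_per_block = 2 if use_flashinfer else 1
    assert kv_block_len * descs_per_block == block_len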