[Nixl] Heterogeneous TP support FlashInfer (#20189)
Signed-off-by: NickLucche <nlucches@redhat.com>
Parent: f38035c123
Commit: f0c503f66e
@@ -715,7 +715,7 @@ class NixlConnectorWorker:
         # are non-contiguous (it's not locally guaranteed that they will be)
         # Disadvantage is that the encoded NixlAgentMetadata is now larger
         # (roughly 8KB vs 5KB).
-        # Conversely for FlashInfer, K and V are transferred in the same tensor
+        # Conversely for FlashInfer, K and V are registered in the same region
         # to better exploit the memory layout (ie num_blocks is the first dim).
         split_k_and_v = not (self.use_mla or self._use_pallas_v1
                              or self._use_flashinfer)
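
For context, FlashInfer's joint layout can be pictured with a small tensor sketch; the shape below is an assumption for illustration, not something stated in the diff:

import torch

# Assumed FlashInfer-style cache shape (illustrative only): one tensor per
# layer, num_blocks leading, K and V stacked on dim 1.
num_blocks, block_size, num_kv_heads, head_dim = 4, 16, 8, 128
kv_cache = torch.zeros(num_blocks, 2, block_size, num_kv_heads, head_dim)

# A single contiguous region covers K and V for all blocks, so this backend
# does not need K and V split into separate regions.
k_block0, v_block0 = kv_cache[0, 0], kv_cache[0, 1]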
@@ -758,12 +758,21 @@ class NixlConnectorWorker:
         assert tensor_size_bytes % self.num_blocks == 0
         self.block_len = tensor_size_bytes // self.num_blocks
         self.slot_size_bytes = self.block_len // self.block_size
+        self.device_kv_caches = kv_caches
+        self.dst_num_blocks[self.engine_id] = self.num_blocks
         if self._use_flashinfer:
             assert self.slot_size_bytes % 2 == 0
             self.slot_size_bytes /= 2
-        self.device_kv_caches = kv_caches
-        self.dst_num_blocks[self.engine_id] = self.num_blocks
+
+            # NOTE (NickLucche) When FlashInfer is used, memory is registered
+            # with joint KV for each block. This minimizes the overhead in
+            # registerMem allowing faster descs queries. In order to be able to
+            # split on kv_heads dim as required by heterogeneous TP, one must
+            # be able to index K/V separately. Hence we double the number
+            # of 'virtual' regions here and halve `block_len` below.
+            self.num_regions *= 2
 
+        kv_block_len = self.get_backend_aware_kv_block_len()
         # Register local/src descr for NIXL xfer.
         blocks_data = []
         for base_addr in seen_base_addresses:
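
A rough sketch of the sizing arithmetic this hunk sets up, with made-up constants (none of these numbers appear in the commit):

# Illustrative constants, not from the commit.
num_blocks, block_size = 1024, 16
num_kv_heads, head_dim, dtype_bytes = 8, 128, 2   # fp16

# Joint K+V tensor: [num_blocks, 2, block_size, num_kv_heads, head_dim].
tensor_size_bytes = (num_blocks * 2 * block_size
                     * num_kv_heads * head_dim * dtype_bytes)
block_len = tensor_size_bytes // num_blocks     # joint K+V bytes per block
slot_size_bytes = block_len // block_size       # joint K+V bytes per token slot
assert slot_size_bytes % 2 == 0
slot_size_bytes //= 2                           # bytes of just K (or just V)

# Doubling num_regions while halving the per-descriptor length lets NIXL
# address the K half and the V half of each registered region independently:
kv_block_len = block_len // 2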
@@ -776,8 +785,18 @@ class NixlConnectorWorker:
                 block_offset = block_id * self.block_len
                 addr = base_addr + block_offset
                 # (addr, len, device id)
-                # TODO: does device_id matter to DRAM?
-                blocks_data.append((addr, self.block_len, self.tp_rank))
+                blocks_data.append((addr, kv_block_len, self.tp_rank))
+
+            if self._use_flashinfer:
+                # Separate and interleave K/V regions to maintain the same
+                # descs ordering. This is needed for selecting contiguous heads
+                # when split across TP ranks.
+                for block_id in range(self.num_blocks):
+                    block_offset = block_id * self.block_len
+                    addr = base_addr + block_offset
+                    # Register addresses for V cache (K registered first).
+                    v_addr = addr + kv_block_len
+                    blocks_data.append((v_addr, kv_block_len, self.tp_rank))
 
         logger.debug("Created %s blocks for src engine %s and rank %s",
                      len(blocks_data), self.engine_id, self.tp_rank)
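
The ordering the new loop produces can be checked with a toy enumeration (addresses and sizes are hypothetical):

# Hypothetical values for a single registered region.
base_addr, num_blocks, block_len, tp_rank = 0x1000, 4, 512, 0
kv_block_len = block_len // 2    # just the K (or just the V) half

blocks_data = []
# K halves first, one descriptor per block...
for block_id in range(num_blocks):
    blocks_data.append((base_addr + block_id * block_len,
                        kv_block_len, tp_rank))
# ...then the matching V halves, in the same block order.
for block_id in range(num_blocks):
    v_addr = base_addr + block_id * block_len + kv_block_len
    blocks_data.append((v_addr, kv_block_len, tp_rank))

# Descriptors i and i + num_blocks now address the K and V halves of the
# same block, so K/V stay aligned when descriptors are selected by index.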
@@ -903,7 +922,7 @@ class NixlConnectorWorker:
             remote_block_size = nixl_agent_meta.block_len // (
                 self.slot_size_bytes * tp_ratio)
             if self._use_flashinfer:
-                # Account for joint KV in FlashInfer.
+                # With flashinfer, KV are sent in the same message.
                 remote_block_size //= 2
             if tp_ratio > 1:
                 # Heterogeneous TP expects same kv_cache_layout.
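
Worked numerically, with assumed sizes (tp_ratio is the local/remote TP size ratio):

# Assumed sizes, for illustration only.
tp_ratio = 2              # e.g. decode TP2 pulling from prefill TP1
slot_size_bytes = 512     # local per-slot bytes of just K (already halved)
remote_block_len = 32768  # remote joint K+V bytes per block

remote_block_size = remote_block_len // (slot_size_bytes * tp_ratio)  # 32
# block_len counts K and V jointly under FlashInfer, so halve it to get
# the actual tokens per block:
remote_block_size //= 2   # 16, expected to match the local block_size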
@@ -929,10 +948,10 @@ class NixlConnectorWorker:
         # rank. With heterogeneous TP, prepare the descriptors by splitting the
         # P KV cache along kv_head dim, of D worker's kv_head size (D>P).
         # Eg. PTP1 DTP2 => P0 KV:[block0-KV_0 | block0-KV_1..].
-        # Only register the remote's descriptors if current rank pulls from it.
         self.kv_caches_base_addr[
             engine_id] = nixl_agent_meta.kv_caches_base_addr
-        rank_offset = self.tp_rank % tp_ratio * self.block_len \
+        kv_block_len = self.get_backend_aware_kv_block_len()
+        rank_offset = self.tp_rank % tp_ratio * kv_block_len \
             if not (self.use_mla or is_kv_replicated) else 0
         # Register all remote blocks, but only the corresponding kv heads.
         for base_addr in nixl_agent_meta.kv_caches_base_addr:
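
A minimal sketch of the resulting head split for the PTP1/DTP2 example in the comment (byte sizes assumed):

# PTP1 -> DTP2: two decode ranks each pull half the heads of one prefill rank.
tp_ratio = 2           # local TP size / remote TP size
kv_block_len = 1024    # assumed bytes of one rank's K (or V) head slice

for tp_rank in (0, 1):
    rank_offset = tp_rank % tp_ratio * kv_block_len
    # rank 0 -> offset 0, rank 1 -> offset 1024: adjacent, non-overlapping
    # head slices of each remote block.
    print(tp_rank, rank_offset)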
@@ -943,7 +962,16 @@ class NixlConnectorWorker:
                 # self.block_len == remote_block_len//tp_ratio bytes.
                 addr = base_addr + block_offset + rank_offset
                 # (addr, len, device id)
-                blocks_data.append((addr, self.block_len, remote_tp_rank))
+                blocks_data.append((addr, kv_block_len, remote_tp_rank))
+
+            if self._use_flashinfer:
+                # With FlashInfer, index V separately to allow head splitting.
+                for block_id in range(nixl_agent_meta.num_blocks):
+                    block_offset = block_id * nixl_agent_meta.block_len
+                    addr = base_addr + block_offset + rank_offset
+                    v_addr = addr + nixl_agent_meta.block_len // 2
+                    blocks_data.append((v_addr, kv_block_len, remote_tp_rank))
+
         logger.debug(
             "Created %s blocks for dst engine %s with remote rank %s and "
             "local rank %s", len(blocks_data), engine_id, remote_tp_rank,
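
The V-side arithmetic mirrors the local registration; a toy consistency check (all values assumed):

# Assumed values: remote joint block of 2048 bytes, tp_ratio 2, rank 1.
remote_block_len = 2048    # joint K+V bytes per remote block
kv_block_len = 512         # this rank's K (or V) share per block
rank_offset = 512          # tp_rank 1 of tp_ratio 2

base_addr, block_id = 0x4000, 3
block_offset = block_id * remote_block_len
k_addr = base_addr + block_offset + rank_offset
v_addr = k_addr + remote_block_len // 2   # V half starts at the block midpoint

# (k_addr, kv_block_len) falls inside the K half of the block and
# (v_addr, kv_block_len) inside the V half, selecting the same head slice.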
@@ -1249,6 +1277,22 @@ class NixlConnectorWorker:
                 descs_ids.append(reg_id * num_blocks + block_id)
         return descs_ids
 
+    def get_backend_aware_kv_block_len(self):
+        """
+        Get the block length for one K/V element (K and V have the same size).
+
+        For FA and other backends, this is equal to the length of the whole
+        block, as K and V are in separate regions.
+        For FlashInfer, this is half the length of the whole block, as K and V
+        share the same region.
+        """
+        if self._use_flashinfer:
+            # For indexing only half (either just the K or V part).
+            block_len = self.block_len // 2
+        else:
+            block_len = self.block_len
+        return block_len
+
 
 @contextlib.contextmanager
 def zmq_ctx(socket_type: Any, addr: str) -> Iterator[zmq.Socket]:
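
For illustration, the helper's behavior in isolation (a standalone mock with assumed block_len values, not the connector class itself):

# Standalone mock of the new helper.
class _Mock:
    def __init__(self, block_len: int, use_flashinfer: bool):
        self.block_len = block_len
        self._use_flashinfer = use_flashinfer

    def get_backend_aware_kv_block_len(self) -> int:
        # FlashInfer: index only half the region (just K or just V).
        return self.block_len // 2 if self._use_flashinfer else self.block_len

assert _Mock(2048, True).get_backend_aware_kv_block_len() == 1024
assert _Mock(2048, False).get_backend_aware_kv_block_len() == 2048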