[NIXL] Generalize block-first backend layouts (FlashInfer-like) (#28282)

Nicolò Lucchesi 2025-11-11 17:57:43 +01:00 committed by GitHub
parent f9a4087182
commit a7ef3eb0cd
2 changed files with 52 additions and 12 deletions


@@ -1096,7 +1096,8 @@ def _run_abort_timeout_test(llm: LLM, timeout: int):
llm.llm_engine.engine_core.shutdown()
def test_register_kv_caches(dist_init):
@pytest.mark.parametrize("attn_backend", ["FLASH_ATTN", "TRITON_ATTN"])
def test_register_kv_caches(dist_init, attn_backend, monkeypatch):
"""
Test that register_kv_caches() properly calls nixl_wrapper methods with
correct data.
@@ -1108,10 +1109,22 @@ def test_register_kv_caches(dist_init):
block layout info
"""
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
vllm_config = create_vllm_config()
# Import the appropriate backend based on the parameter
if attn_backend == "FLASH_ATTN":
from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
backend_cls = FlashAttentionBackend
else: # TRITON_ATTN
from vllm.v1.attention.backends.triton_attn import TritonAttentionBackend
backend_cls = TritonAttentionBackend
# Create test kv cache tensors using proper backend shape
kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape(
kv_cache_shape = backend_cls.get_kv_cache_shape(
num_blocks=2, block_size=16, num_kv_heads=4, head_size=64
)
shared_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
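
For context on what the new parametrization exercises: the two backends disagree on which dimension leads the KV cache. A minimal sketch of the two layouts in plain torch, not a call into the vLLM backends; the exact ordering of the trailing dimensions is assumed for illustration:

import torch

num_blocks, block_size, num_kv_heads, head_size = 2, 16, 4, 64

# K/V-first layout (e.g. FLASH_ATTN): dim 0 separates K from V.
kv_first = torch.zeros(2, num_blocks, block_size, num_kv_heads, head_size,
                       dtype=torch.float16)

# Block-first layout (e.g. FlashInfer): dim 0 is num_blocks, so one block's
# contiguous bytes already contain both its K and its V entries.
blocks_first = torch.zeros(num_blocks, 2, block_size, num_kv_heads, head_size,
                           dtype=torch.float16)

assert kv_first.numel() == blocks_first.numel()  # same element count, addressed differently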


@@ -21,6 +21,7 @@ import torch
import zmq
from vllm import envs
from vllm.attention import AttentionBackend
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.selector import get_attn_backend
from vllm.config import VllmConfig
@@ -669,6 +670,33 @@ class NixlConnectorWorker:
remote_tp_size: dict[EngineId, int]
is_mla: bool
total_num_kv_heads: int
attn_backend: type[AttentionBackend]
def __post_init__(self):
# Figure out whether the first dimension of the cache is K/V
# or num_blocks. This is used to register the memory regions correctly.
kv_cache_shape = self.attn_backend.get_kv_cache_shape(
num_blocks=1, block_size=16, num_kv_heads=1, head_size=1
)
# Non-MLA backend caches have 5 dims, with either the K/V split or num_blocks
# leading (e.g. [2, num_blocks, H, N, D] vs [num_blocks, 2, H, N, D]); we just
# mock num_blocks to 1 for the dimension check below.
self._is_kv_layout_blocks_first = (
len(kv_cache_shape) == 5 and kv_cache_shape[0] == 1
)
attn_backend = AttentionBackendEnum[self.attn_backend.get_name()]
self._use_pallas = attn_backend == AttentionBackendEnum.PALLAS
@property
def is_kv_layout_blocks_first(self) -> bool:
return self._is_kv_layout_blocks_first
@property
def split_k_and_v(self) -> bool:
# Whether to register regions for K and V separately (when present).
return not (
self.is_mla or self._use_pallas or self.is_kv_layout_blocks_first
)
def tp_ratio(
self,
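
A standalone sketch of the detection trick in __post_init__ above, with stand-in shape functions in place of real vLLM backends (flash_attn_like / flashinfer_like are illustrative names, not vLLM classes):

def is_blocks_first(get_kv_cache_shape) -> bool:
    # Ask for a shape with num_blocks mocked to 1: if dim 0 is that 1,
    # num_blocks leads and the layout is block-first.
    shape = get_kv_cache_shape(num_blocks=1, block_size=16,
                               num_kv_heads=1, head_size=1)
    return len(shape) == 5 and shape[0] == 1

def flash_attn_like(num_blocks, block_size, num_kv_heads, head_size):
    return (2, num_blocks, block_size, num_kv_heads, head_size)

def flashinfer_like(num_blocks, block_size, num_kv_heads, head_size):
    return (num_blocks, 2, block_size, num_kv_heads, head_size)

assert not is_blocks_first(flash_attn_like)  # K/V-first -> K and V registered separately
assert is_blocks_first(flashinfer_like)      # block-first -> K and V share a region

split_k_and_v then falls out as the complement: MLA, Pallas, and block-first layouts all keep K and V in a single region.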
@@ -876,9 +904,6 @@ class NixlConnectorWorker:
use_mla=self.use_mla,
)
self.backend_name = backend.get_name()
attn_backend = AttentionBackendEnum[self.backend_name]
self._use_flashinfer = attn_backend == AttentionBackendEnum.FLASHINFER
self._use_pallas = attn_backend == AttentionBackendEnum.PALLAS
self.kv_cache_layout = get_kv_cache_layout()
self.host_buffer_kv_cache_layout = self.kv_cache_layout
logger.debug("Detected attention backend %s", self.backend_name)
@@ -896,7 +921,9 @@ class NixlConnectorWorker:
remote_tp_size=self._tp_size, # shared state
is_mla=self.use_mla,
total_num_kv_heads=self.model_config.get_total_num_kv_heads(),
attn_backend=backend,
)
self._use_pallas = self.kv_topo._use_pallas
def _nixl_handshake(
self,
@@ -1076,7 +1103,7 @@ class NixlConnectorWorker:
# (roughly 8KB vs 5KB).
# Conversely for FlashInfer, K and V are registered in the same region
# to better exploit the memory layout (ie num_blocks is the first dim).
split_k_and_v = not (self.use_mla or self._use_pallas or self._use_flashinfer)
split_k_and_v = self.kv_topo.split_k_and_v
tensor_size_bytes = None
# Enable different block lengths for different layers when MLA is used.
self.block_len_per_layer = list[int]()
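
Sketch of the practical effect on registration (plain torch, no NIXL calls; sizes and the single-layer setup are assumptions for illustration):

import torch

num_blocks, block_size, heads, hdim = 4, 16, 2, 8

# K/V-first cache: the K and V halves are far apart in memory, so they are
# described as two regions per layer.
kv_first = torch.zeros(2, num_blocks, block_size, heads, hdim, dtype=torch.float16)
k_region = (kv_first[0].data_ptr(), kv_first[0].numel() * kv_first.element_size())
v_region = (kv_first[1].data_ptr(), kv_first[1].numel() * kv_first.element_size())

# Block-first cache: each block already holds its K and V contiguously, so a
# single region per layer covers both and matches the layout.
blocks_first = torch.zeros(num_blocks, 2, block_size, heads, hdim, dtype=torch.float16)
whole_region = (blocks_first.data_ptr(),
                blocks_first.numel() * blocks_first.element_size())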
@@ -1141,7 +1168,7 @@ class NixlConnectorWorker:
self.device_kv_caches = kv_caches
self.dst_num_blocks[self.engine_id] = self.num_blocks
if self._use_flashinfer:
if self.kv_topo.is_kv_layout_blocks_first:
for i in range(len(self.slot_size_per_layer)):
assert self.slot_size_per_layer[i] % 2 == 0
self.slot_size_per_layer[i] //= 2
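
Worked numbers for the halving, purely illustrative (fp16 and the test's head geometry assumed):

dtype_bytes, num_kv_heads, head_size = 2, 4, 64

# For a block-first layout the per-layer slot size is derived from a block
# region that holds K and V together, so each token slot counts both:
slot_size_k_and_v = 2 * num_kv_heads * head_size * dtype_bytes  # 1024 bytes
assert slot_size_k_and_v % 2 == 0
slot_size_one_of_kv = slot_size_k_and_v // 2                    # 512 bytes: just K (or V)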
@@ -1169,7 +1196,7 @@ class NixlConnectorWorker:
# (addr, len, device id)
blocks_data.append((addr, kv_block_len, self.device_id))
if self._use_flashinfer:
if self.kv_topo.is_kv_layout_blocks_first:
# Separate and interleave K/V regions to maintain the same
# descs ordering. This is needed for selecting contiguous heads
# when split across TP ranks.
@@ -1331,7 +1358,7 @@ class NixlConnectorWorker:
# (addr, len, device id)
blocks_data.append((addr, kv_block_len, nixl_agent_meta.device_id))
if self._use_flashinfer:
if self.kv_topo.is_kv_layout_blocks_first:
# With FlashInfer index V separately to allow head splitting.
for block_id in range(nixl_agent_meta.num_blocks):
block_offset = block_id * nixl_agent_meta.block_lens[i]
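
One plausible shape of the descriptor split for a block-first region: each block's bytes are [K | V], so K descriptors can be emitted for all blocks and V descriptors follow at a half-block offset. Addresses, ordering, and the helper below are illustrative, not the connector's API:

def kv_descs_blocks_first(base_addr: int, num_blocks: int,
                          block_len: int, device_id: int):
    # Each block region is [K | V]; emit all K descriptors, then all V ones,
    # keeping a stable ordering so a TP rank can grab contiguous head slices.
    half = block_len // 2
    k_descs = [(base_addr + b * block_len, half, device_id)
               for b in range(num_blocks)]
    v_descs = [(base_addr + b * block_len + half, half, device_id)
               for b in range(num_blocks)]
    return k_descs + v_descs

# e.g. two 1024-byte blocks at a made-up base address:
descs = kv_descs_blocks_first(0x10000, num_blocks=2, block_len=1024, device_id=0)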
@@ -1414,7 +1441,7 @@ class NixlConnectorWorker:
remote_block_size = remote_block_len // (
self.slot_size_per_layer[0] * tp_ratio
)
if self._use_flashinfer:
if self.kv_topo.is_kv_layout_blocks_first:
# With flashinfer, KV are sent in the same message.
remote_block_size //= 2
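
Worked numbers for the remote block-size check (all values assumed; homogeneous TP so tp_ratio is 1):

dtype_bytes, num_kv_heads, head_size, block_size = 2, 4, 64, 16
tp_ratio = 1                                          # local TP size / remote TP size
slot_size = num_kv_heads * head_size * dtype_bytes    # 512 bytes per token, K or V only

# A block-first remote block carries K and V for block_size tokens:
remote_block_len = block_size * 2 * slot_size         # 16384 bytes
remote_block_size = remote_block_len // (slot_size * tp_ratio)  # 32
remote_block_size //= 2          # K and V share the region -> back to 16 tokens
assert remote_block_size == block_size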
@@ -1494,7 +1521,7 @@ class NixlConnectorWorker:
- cache.index_copy_(0, indices, permuted_blocks) # copy permuted kv back
"""
split_k_and_v = not (self.use_mla or self._use_pallas or self._use_flashinfer)
split_k_and_v = self.kv_topo.split_k_and_v
inv_order = [0, 2, 1, 3]
sample_cache = list(self.device_kv_caches.values())[0][0]
target_shape = list(sample_cache.shape)
@@ -1874,7 +1901,7 @@ class NixlConnectorWorker:
For FlashInfer, this is half the length of the whole block, as K and V
share the same region.
"""
if self._use_flashinfer:
if self.kv_topo.is_kv_layout_blocks_first:
# For indexing only half (either just the K or V part).
block_len = self.block_len_per_layer[layer_idx] // 2
else: