[NIXL] Generalize block-first backend layouts (FlashInfer-like) (#28282)
commit a7ef3eb0cd (parent f9a4087182)
@@ -1096,7 +1096,8 @@ def _run_abort_timeout_test(llm: LLM, timeout: int):
     llm.llm_engine.engine_core.shutdown()
 
 
-def test_register_kv_caches(dist_init):
+@pytest.mark.parametrize("attn_backend", ["FLASH_ATTN", "TRITON_ATTN"])
+def test_register_kv_caches(dist_init, attn_backend, monkeypatch):
     """
     Test that register_kv_caches() properly calls nixl_wrapper methods with
     correct data.
@@ -1108,10 +1109,22 @@ def test_register_kv_caches(dist_init):
     block layout info
     """
 
+    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
+
     vllm_config = create_vllm_config()
 
+    # Import the appropriate backend based on the parameter
+    if attn_backend == "FLASH_ATTN":
+        from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
+
+        backend_cls = FlashAttentionBackend
+    else:  # TRITON_ATTN
+        from vllm.v1.attention.backends.triton_attn import TritonAttentionBackend
+
+        backend_cls = TritonAttentionBackend
+
     # Create test kv cache tensors using proper backend shape
-    kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape(
+    kv_cache_shape = backend_cls.get_kv_cache_shape(
         num_blocks=2, block_size=16, num_kv_heads=4, head_size=64
     )
     shared_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
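The parametrize-plus-monkeypatch pattern used above is plain pytest. Below is a minimal, self-contained sketch of the same pattern with made-up names (not the vLLM test helpers):

import os

import pytest

# Hypothetical backend table, for illustration only.
_FAKE_BACKENDS = {"FLASH_ATTN": "flash-attn backend", "TRITON_ATTN": "triton backend"}


@pytest.mark.parametrize("attn_backend", ["FLASH_ATTN", "TRITON_ATTN"])
def test_backend_is_pinned(monkeypatch, attn_backend):
    # Pin the backend through the environment (as the real test does with
    # VLLM_ATTENTION_BACKEND), then resolve the implementation from the parameter.
    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
    assert os.environ["VLLM_ATTENTION_BACKEND"] == attn_backend
    assert attn_backend in _FAKE_BACKENDS

Running pytest -k FLASH_ATTN exercises just one of the two parametrized cases.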
@@ -21,6 +21,7 @@ import torch
 import zmq
 
 from vllm import envs
+from vllm.attention import AttentionBackend
 from vllm.attention.backends.registry import AttentionBackendEnum
 from vllm.attention.selector import get_attn_backend
 from vllm.config import VllmConfig
@@ -669,6 +670,33 @@ class NixlConnectorWorker:
     remote_tp_size: dict[EngineId, int]
     is_mla: bool
     total_num_kv_heads: int
+    attn_backend: type[AttentionBackend]
+
+    def __post_init__(self):
+        # Figure out whether the first dimension of the cache is K/V
+        # or num_blocks. This is used to register the memory regions correctly.
+        kv_cache_shape = self.attn_backend.get_kv_cache_shape(
+            num_blocks=1, block_size=16, num_kv_heads=1, head_size=1
+        )
+        # Non-MLA backends caches have 5 dims [2, num_blocks, H,N,D],
+        # we just mock num_blocks to 1 for the dimension check below.
+        self._is_kv_layout_blocks_first = (
+            len(kv_cache_shape) == 5 and kv_cache_shape[0] == 1
+        )
+
+        attn_backend = AttentionBackendEnum[self.attn_backend.get_name()]
+        self._use_pallas = attn_backend == AttentionBackendEnum.PALLAS
+
+    @property
+    def is_kv_layout_blocks_first(self) -> bool:
+        return self._is_kv_layout_blocks_first
+
+    @property
+    def split_k_and_v(self) -> bool:
+        # Whether to register regions for K and V separately (when present).
+        return not (
+            self.is_mla or self._use_pallas or self.is_kv_layout_blocks_first
+        )
 
     def tp_ratio(
         self,
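With num_blocks mocked to 1, the probe above distinguishes the two 5-D conventions by their leading dimension. A minimal sketch of that check against illustrative shapes (the tuples show the two conventions; they are not values returned by real backends):

def is_blocks_first(kv_cache_shape: tuple[int, ...]) -> bool:
    # Same test as __post_init__: a 5-D cache whose first dim is the mocked num_blocks.
    return len(kv_cache_shape) == 5 and kv_cache_shape[0] == 1


# Illustrative shapes for num_blocks=1, block_size=16, num_kv_heads=1, head_size=1:
blocks_first = (1, 2, 16, 1, 1)  # FlashInfer-like: [num_blocks, 2, block_size, H, D]
kv_first = (2, 1, 16, 1, 1)      # FlashAttention-like: [2, num_blocks, block_size, H, D]

assert is_blocks_first(blocks_first)
assert not is_blocks_first(kv_first)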
@@ -876,9 +904,6 @@ class NixlConnectorWorker:
             use_mla=self.use_mla,
         )
         self.backend_name = backend.get_name()
-        attn_backend = AttentionBackendEnum[self.backend_name]
-        self._use_flashinfer = attn_backend == AttentionBackendEnum.FLASHINFER
-        self._use_pallas = attn_backend == AttentionBackendEnum.PALLAS
         self.kv_cache_layout = get_kv_cache_layout()
         self.host_buffer_kv_cache_layout = self.kv_cache_layout
         logger.debug("Detected attention backend %s", self.backend_name)
@@ -896,7 +921,9 @@ class NixlConnectorWorker:
             remote_tp_size=self._tp_size,  # shared state
             is_mla=self.use_mla,
             total_num_kv_heads=self.model_config.get_total_num_kv_heads(),
+            attn_backend=backend,
         )
+        self._use_pallas = self.kv_topo._use_pallas
 
     def _nixl_handshake(
         self,
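The object constructed above derives its private flags in a dataclass __post_init__ hook and exposes them through properties. A generic, self-contained sketch of that pattern with hypothetical names (LayoutInfo is not the connector's class):

from dataclasses import dataclass


@dataclass
class LayoutInfo:
    is_mla: bool
    num_dims: int
    leading_dim: int

    def __post_init__(self) -> None:
        # Derived once at construction, like _is_kv_layout_blocks_first above.
        self._blocks_first = self.num_dims == 5 and self.leading_dim == 1

    @property
    def blocks_first(self) -> bool:
        return self._blocks_first

    @property
    def split_k_and_v(self) -> bool:
        return not (self.is_mla or self.blocks_first)


info = LayoutInfo(is_mla=False, num_dims=5, leading_dim=1)
assert info.blocks_first and not info.split_k_and_v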
@@ -1076,7 +1103,7 @@ class NixlConnectorWorker:
         # (roughly 8KB vs 5KB).
         # Conversely for FlashInfer, K and V are registered in the same region
         # to better exploit the memory layout (ie num_blocks is the first dim).
-        split_k_and_v = not (self.use_mla or self._use_pallas or self._use_flashinfer)
+        split_k_and_v = self.kv_topo.split_k_and_v
         tensor_size_bytes = None
         # Enable different block lengths for different layers when MLA is used.
         self.block_len_per_layer = list[int]()
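A rough illustration of what the split_k_and_v switch changes during registration (the layer count is made up and this is not the actual NIXL registration code): with split K/V each layer contributes two memory regions, while a blocks-first or MLA cache needs only one region per layer.

def regions_per_layer(split_k_and_v: bool) -> int:
    # Two regions (the K tensor and the V tensor) when split, one otherwise.
    return 2 if split_k_and_v else 1


num_layers = 32  # made-up value
print(num_layers * regions_per_layer(True))   # 64 regions for KV-first backends
print(num_layers * regions_per_layer(False))  # 32 regions for blocks-first / MLA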
@@ -1141,7 +1168,7 @@ class NixlConnectorWorker:
 
         self.device_kv_caches = kv_caches
         self.dst_num_blocks[self.engine_id] = self.num_blocks
-        if self._use_flashinfer:
+        if self.kv_topo.is_kv_layout_blocks_first:
             for i in range(len(self.slot_size_per_layer)):
                 assert self.slot_size_per_layer[i] % 2 == 0
                 self.slot_size_per_layer[i] //= 2
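A worked example with made-up sizes of why the halving above is safe: in a blocks-first cache a block stores K and V back to back, so a per-token slot size computed from the whole block counts both halves.

# Made-up geometry, fp16 elements (2 bytes each).
block_size, num_kv_heads, head_size, itemsize = 16, 4, 64, 2

bytes_per_block = 2 * block_size * num_kv_heads * head_size * itemsize  # K and V together
slot_both_halves = bytes_per_block // block_size  # 1024 bytes per token (K + V)
assert slot_both_halves % 2 == 0
slot_one_half = slot_both_halves // 2             # 512 bytes: just K (or just V)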
@@ -1169,7 +1196,7 @@ class NixlConnectorWorker:
             # (addr, len, device id)
             blocks_data.append((addr, kv_block_len, self.device_id))
 
-        if self._use_flashinfer:
+        if self.kv_topo.is_kv_layout_blocks_first:
             # Separate and interleave K/V regions to maintain the same
             # descs ordering. This is needed for selecting contiguous heads
             # when split across TP ranks.
@@ -1331,7 +1358,7 @@ class NixlConnectorWorker:
                 # (addr, len, device id)
                 blocks_data.append((addr, kv_block_len, nixl_agent_meta.device_id))
 
-            if self._use_flashinfer:
+            if self.kv_topo.is_kv_layout_blocks_first:
                 # With FlashInfer index V separately to allow head splitting.
                 for block_id in range(nixl_agent_meta.num_blocks):
                     block_offset = block_id * nixl_agent_meta.block_lens[i]
@@ -1414,7 +1441,7 @@ class NixlConnectorWorker:
         remote_block_size = remote_block_len // (
             self.slot_size_per_layer[0] * tp_ratio
         )
-        if self._use_flashinfer:
+        if self.kv_topo.is_kv_layout_blocks_first:
             # With flashinfer, KV are sent in the same message.
             remote_block_size //= 2
 
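A back-of-the-envelope check of the handshake arithmetic above, assuming for simplicity that both sides run the same TP degree (tp_ratio == 1) and reusing the made-up sizes from before: dividing a joint K+V block length by a single-half slot size double-counts the tokens, hence the final halving.

# Made-up sizes; tp_ratio assumed to be 1 (equal TP on both sides).
block_size, num_kv_heads, head_size, itemsize = 16, 4, 64, 2
tp_ratio = 1

slot_size = num_kv_heads * head_size * itemsize       # one K (or V) slot per token
remote_block_len = 2 * block_size * slot_size         # remote block holds K and V

remote_block_size = remote_block_len // (slot_size * tp_ratio)  # 32: double-counted
remote_block_size //= 2                               # blocks-first: halve it
assert remote_block_size == block_size                # recovers 16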
@@ -1494,7 +1521,7 @@ class NixlConnectorWorker:
         - cache.index_copy_(0, indices, permuted_blocks) # copy permuted kv back
 
         """
-        split_k_and_v = not (self.use_mla or self._use_pallas or self._use_flashinfer)
+        split_k_and_v = self.kv_topo.split_k_and_v
         inv_order = [0, 2, 1, 3]
         sample_cache = list(self.device_kv_caches.values())[0][0]
         target_shape = list(sample_cache.shape)
@@ -1874,7 +1901,7 @@ class NixlConnectorWorker:
         For FlashInfer, this is half the length of the whole block, as K and V
         share the same region.
         """
-        if self._use_flashinfer:
+        if self.kv_topo.is_kv_layout_blocks_first:
            # For indexing only half (either just the K or V part).
            block_len = self.block_len_per_layer[layer_idx] // 2
        else: