[P/D][Nixl] Make kv cache register compatible with hybrid memory allocator (#23079)

Signed-off-by: sfeng33 <4florafeng@gmail.com>
Flora Feng 2025-08-21 22:30:48 -07:00 committed by GitHub
parent 17373dcd93
commit 53415653ff
3 changed files with 150 additions and 95 deletions

tests/v1/kv_connector/unit/test_nixl_connector.py

@@ -14,6 +14,7 @@ from unittest.mock import patch
import pytest
import ray
import torch
from vllm import LLM
from vllm.config import KVTransferConfig
@@ -22,6 +23,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
NixlConnectorWorker)
from vllm.forward_context import ForwardContext
from vllm.sampling_params import SamplingParams
from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
from .utils import create_request, create_scheduler, create_vllm_config
@@ -98,7 +100,6 @@ class FakeNixlWrapper:
def set_cycles_before_xfer_done(self, cycles: int):
"""Set the number of cycles before a transfer is considered done."""
self._cycles_before_xfer_done = cycles
@contextlib.contextmanager
@@ -562,3 +563,86 @@ def _run_abort_timeout_test(llm_kwargs: dict, timeout: int):
sampling_params)
# Request-0 times out and is cleared!
assert '0' not in req_to_blocks
def test_register_kv_caches(dist_init):
"""
Test that register_kv_caches() properly calls nixl_wrapper methods with
correct data.
This test verifies:
1. nixl_wrapper.get_reg_descs() is called with caches_data containing
tensor metadata
2. nixl_wrapper.get_xfer_descs() is called with blocks_data containing
block layout info
"""
vllm_config = create_vllm_config()
# Create test kv cache tensors using proper backend shape
kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape(num_blocks=2,
block_size=16,
num_kv_heads=4,
head_size=64)
shared_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
unique_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
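# layer0 and layer2 deliberately share a tensor, mimicking the hybrid memory
# allocator case where multiple layers are backed by the same allocation.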
kv_caches = {
"layer0": shared_tensor,
"layer1": unique_tensor,
"layer2": shared_tensor,
}
# Store tensor info for validation
expected_tensor_size = shared_tensor[0].element_size(
) * shared_tensor[0].numel()
expected_base_addrs = [
shared_tensor[0].data_ptr(), shared_tensor[1].data_ptr(),
unique_tensor[0].data_ptr(), unique_tensor[1].data_ptr()
]
with patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper") as mock_nixl_wrapper, \
patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.threading.Event"), \
patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.threading.Thread"): # noqa: E501
# Create connector
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
connector.connector_worker = FakeNixlConnectorWorker(
vllm_config, connector.engine_id, hand_shake_latency=0)
# Get the mock instance
mock_wrapper_instance = mock_nixl_wrapper.return_value
connector.connector_worker.nixl_wrapper = mock_wrapper_instance
# Execute register_kv_caches
connector.register_kv_caches(kv_caches)
# Verify get_reg_descs was called with caches_data
assert mock_wrapper_instance.get_reg_descs.called
caches_data, _ = mock_wrapper_instance.get_reg_descs.call_args[0]
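# 2 unique tensors, each split into separate K and V regions -> 4 entries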
assert len(caches_data) == 4
for i, cache_entry in enumerate(caches_data):
base_addr, size, _tp_rank, _ = cache_entry
assert size == expected_tensor_size, \
f"Entry {i}: Expected tensor size {expected_tensor_size}, " \
f"got {size}"
assert base_addr == expected_base_addrs[i], \
f"Entry {i}: Expected base address {expected_base_addrs[i]}, " \
f"got {base_addr}"
# Verify get_xfer_descs was called with blocks_data
assert mock_wrapper_instance.get_xfer_descs.called
blocks_data, _ = mock_wrapper_instance.get_xfer_descs.call_args[0]
# Validate blocks_data structure and size
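# 4 registered regions x num_blocks=2 -> 8 block descriptors; each block is
# half of a region (expected_tensor_size // 2) since num_blocks == 2.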
expected_blocks_count = 8
assert len(blocks_data) == expected_blocks_count, \
f"Expected {expected_blocks_count} blocks, " \
f"got {len(blocks_data)}"
expected_block_len = expected_tensor_size // 2
for i, block_entry in enumerate(blocks_data):
block_start_addr, block_len, tp_rank = block_entry
assert block_len == expected_block_len, \
f"Block entry {i}: Expected block len {expected_block_len}, " \
f"got {block_len}"

vllm/distributed/kv_transfer/kv_connector/v1/base.py

@@ -131,8 +131,8 @@ class KVConnectorBase_V1(ABC):
Initialize with the KV caches. Useful for pre-registering the
KV Caches in the KVConnector (e.g. for NIXL).
Args: kv_caches:
dictionary of layer names, kv cache
Args:
kv_caches: dictionary of layer names, kv cache
"""
return

vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py

@@ -686,9 +686,6 @@ class NixlConnectorWorker:
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
"""Register the KV Cache data in nixl."""
_, first_kv_cache = next(iter(kv_caches.items()))
kv_elem_size = first_kv_cache.element_size()
if self.use_host_buffer:
self.initialize_host_xfer_buffer(kv_caches=kv_caches)
assert len(self.host_xfer_buffers) == len(kv_caches), (
@@ -701,66 +698,16 @@
"host_xfer_buffer should not be initialized when "
f"kv_buffer_device is {self.kv_buffer_device}")
# TODO(tms): Find a more robust way to detect and handle MLA
# NOTE (NickLucche) To move blocks efficiently with NIXL, the expected
# KV memory layout is HND, as opposed to the default NHD. Note that it
# will only affect the strides. For MLA instead, we require no
# such thing and resort to the standard layout.
use_mla = len(first_kv_cache.shape) == 3
if self.device_type == "tpu":
assert not use_mla, f"{self.kv_buffer_device} does not support MLA."
assert self._use_pallas_v1, f"attn backend: {self.backend_name}"
# tpu (v1) kv shape per layer:
# (num_blocks, block_size, num_kv_heads * 2, head_size)
self.num_blocks = first_kv_cache.shape[0]
block_rank = 3 # [block_size, kv_heads, head_dim]
block_shape = first_kv_cache.shape[-block_rank:]
block_size, n_kv_heads_x_2, head_dim = block_shape
self.slot_size_bytes = kv_elem_size * n_kv_heads_x_2 * head_dim
elif self.device_type == "cuda":
assert use_mla == self.use_mla
# TODO (NickLucche) not compatible with hybrid allocator.
# Enforce check once it goes live, as a single kv layout
# is expected for xfers.
if use_mla:
# MLA case.
self.num_blocks = first_kv_cache.shape[0]
block_rank = 2 # [block_size, latent_dim]
block_shape = first_kv_cache.shape[-block_rank:]
block_size, kv_latent_dim = block_shape
self.slot_size_bytes = kv_elem_size * kv_latent_dim
else:
# [2 (k and v), num_blocks, ...]
if self._use_flashinfer:
# FlashInfer swaps 2<->num_blocks dimensions.
self.num_blocks = first_kv_cache.shape[0]
block_rank = 4 # [2, block_size, kv_heads, head_dim]
else:
self.num_blocks = first_kv_cache.shape[1]
block_rank = 3 # [block_size, kv_heads, head_dim]
block_shape = first_kv_cache.shape[-block_rank:]
block_size, n_kv_heads, head_dim = block_shape[-3:]
# head size in bytes.
self.slot_size_bytes = kv_elem_size * n_kv_heads * head_dim
assert block_size == self.block_size
else:
raise RuntimeError(
f"{self.device_type} ({self.backend_name}) is not supported.")
# TODO(tms): self.block_len needs to be per-layer for sliding window,
# hybrid attn, etc
# block size in bytes
self.block_len = kv_elem_size * math.prod(block_shape)
logger.info(
"Registering KV_Caches. use_mla: %s, kv_buffer_device: %s, "
"use_host_buffer: %s, num_blocks: %s, block_shape: %s, "
"per_layer_kv_cache_shape: %s", use_mla, self.kv_buffer_device,
self.use_host_buffer, self.num_blocks, block_shape,
first_kv_cache.shape)
self.dst_num_blocks[self.engine_id] = self.num_blocks
self.device_kv_caches = kv_caches
kv_caches_base_addr = []
"use_host_buffer: %s", self.use_mla, self.kv_buffer_device,
self.use_host_buffer)
caches_data = []
# With hybrid allocator, layers can share a kv cache tensor
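# (multiple layer names can map to the same backing tensor), so register each
# unique base address only once.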
seen_base_addresses = []
xfer_buffers = (self.host_xfer_buffers
if self.use_host_buffer else kv_caches)
# Note(tms): I modified this from the original region setup code.
# K and V are now in different regions. Advantage is that we can
@@ -770,42 +717,35 @@
# (roughly 8KB vs 5KB).
# Conversely for FlashInfer, K and V are transferred in the same tensor
# to better exploit the memory layout (ie num_blocks is the first dim).
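# In short: two regions per layer (K and V) for FlashAttention-style layouts,
# one combined region per layer for FlashInfer, MLA and pallas.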
for cache_or_caches in xfer_buffers.values():
# Normalize to always be a list of caches
cache_list = [cache_or_caches] if use_mla \
or self._use_pallas_v1 or self._use_flashinfer \
else cache_or_caches
split_k_and_v = not (self.use_mla or self._use_pallas_v1
or self._use_flashinfer)
tensor_size_bytes = None
for layer_name, cache_or_caches in xfer_buffers.items():
cache_list = cache_or_caches if split_k_and_v else [
cache_or_caches
]
for cache in cache_list:
base_addr = cache.data_ptr()
region_len = self.num_blocks * self.block_len
# NOTE: use tp_rank for device_id since multi-node TP
# is rarely used.
caches_data.append((base_addr, region_len, self.tp_rank, ""))
kv_caches_base_addr.append(base_addr)
self.kv_caches_base_addr[self.engine_id] = kv_caches_base_addr
if base_addr in seen_base_addresses:
continue
seen_base_addresses.append(base_addr)
curr_tensor_size_bytes = cache.numel() * cache.element_size()
if tensor_size_bytes is None:
tensor_size_bytes = curr_tensor_size_bytes
self.num_blocks = cache.shape[0]
assert tensor_size_bytes == curr_tensor_size_bytes, \
"All kv cache tensors must have the same size"
caches_data.append(
(base_addr, tensor_size_bytes, self.tp_rank, ""))
self.kv_caches_base_addr[self.engine_id] = seen_base_addresses
self.num_regions = len(caches_data)
self.num_layers = len(xfer_buffers.keys())
# TODO(mgoin): remove this once we have hybrid memory allocator
# Optimization for models with local attention (Llama 4)
if self.vllm_config.model_config.hf_config.model_type == "llama4":
from transformers import Llama4TextConfig
assert isinstance(self.vllm_config.model_config.hf_text_config,
Llama4TextConfig)
llama4_config = self.vllm_config.model_config.hf_text_config
no_rope_layers = llama4_config.no_rope_layers
chunk_size = llama4_config.attention_chunk_size
chunk_block_size = math.ceil(chunk_size / self.block_size)
for layer_idx in range(self.num_layers):
# no_rope_layers[layer_idx] == 0 means NoPE (global)
# Any other value means RoPE (local chunked)
is_local_attention = no_rope_layers[layer_idx] != 0
block_window = chunk_block_size if is_local_attention else None
self.block_window_per_layer.append(block_window)
logger.debug("Llama 4 block window per layer mapping: %s",
self.block_window_per_layer)
assert len(self.block_window_per_layer) == self.num_layers
descs = self.nixl_wrapper.get_reg_descs(caches_data,
self.nixl_memory_type)
logger.debug("Registering descs: %s", caches_data)
@@ -813,9 +753,20 @@
logger.debug("Done registering descs")
self._registered_descs.append(descs)
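# Derive the transfer geometry from the registered tensor sizes rather than
# from backend-specific shapes: each region holds num_blocks blocks, and each
# block holds block_size token slots.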
assert tensor_size_bytes is not None
assert self.num_blocks != 0
assert tensor_size_bytes % self.num_blocks == 0
self.block_len = tensor_size_bytes // self.num_blocks
self.slot_size_bytes = self.block_len // self.block_size
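# FlashInfer packs K and V into the same region, so the slot derived above
# covers both; halve it to get the per-K (or per-V) slot size.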
if self._use_flashinfer:
assert self.slot_size_bytes % 2 == 0
self.slot_size_bytes //= 2
self.device_kv_caches = kv_caches
self.dst_num_blocks[self.engine_id] = self.num_blocks
# Register local/src descr for NIXL xfer.
blocks_data = []
for base_addr in self.kv_caches_base_addr[self.engine_id]:
for base_addr in seen_base_addresses:
# NOTE With heter-TP, more blocks are prepared than what are
# needed as self.num_blocks >= nixl_agent_meta.num_blocks. We
# could create fewer, but then _get_block_descs_ids needs to
@@ -836,6 +787,26 @@
self.src_xfer_side_handle = self.nixl_wrapper.prep_xfer_dlist(
"NIXL_INIT_AGENT", descs)
# TODO(mgoin): Hybrid memory allocator is currently disabled for
# models with local attention (Llama 4). Can remove this once enabled.
if self.vllm_config.model_config.hf_config.model_type == "llama4":
from transformers import Llama4TextConfig
assert isinstance(self.vllm_config.model_config.hf_text_config,
Llama4TextConfig)
llama4_config = self.vllm_config.model_config.hf_text_config
no_rope_layers = llama4_config.no_rope_layers
chunk_size = llama4_config.attention_chunk_size
chunk_block_size = math.ceil(chunk_size / self.block_size)
for layer_idx in range(self.num_layers):
# no_rope_layers[layer_idx] == 0 means NoPE (global)
# Any other value means RoPE (local chunked)
is_local_attention = no_rope_layers[layer_idx] != 0
block_window = chunk_block_size if is_local_attention else None
self.block_window_per_layer.append(block_window)
logger.debug("Llama 4 block window per layer mapping: %s",
self.block_window_per_layer)
assert len(self.block_window_per_layer) == self.num_layers
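# Illustrative example: attention_chunk_size=8192 with block_size=16 gives
# chunk_block_size=512, so RoPE (chunked) layers get a 512-block window and
# NoPE (global) layers get None.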
# After KV Caches registered, listen for new connections.
metadata = NixlAgentMetadata(
engine_id=self.engine_id,