From 53415653ff24be03e7c90f5b42ef9cb3f72aad71 Mon Sep 17 00:00:00 2001 From: Flora Feng <4florafeng@gmail.com> Date: Thu, 21 Aug 2025 22:30:48 -0700 Subject: [PATCH] [P/D][Nixl] Make kv cache register compatible with hybrid memory allocator (#23079) Signed-off-by: sfeng33 <4florafeng@gmail.com> --- .../kv_connector/unit/test_nixl_connector.py | 86 +++++++++- .../kv_transfer/kv_connector/v1/base.py | 4 +- .../kv_connector/v1/nixl_connector.py | 155 +++++++----------- 3 files changed, 150 insertions(+), 95 deletions(-) diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index e6859ea73827..040b44dc5d2c 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -14,6 +14,7 @@ from unittest.mock import patch import pytest import ray +import torch from vllm import LLM from vllm.config import KVTransferConfig @@ -22,6 +23,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import ( NixlConnectorWorker) from vllm.forward_context import ForwardContext from vllm.sampling_params import SamplingParams +from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend from .utils import create_request, create_scheduler, create_vllm_config @@ -98,7 +100,6 @@ class FakeNixlWrapper: def set_cycles_before_xfer_done(self, cycles: int): """Set the number of cycles before a transfer is considered done.""" - self._cycles_before_xfer_done = cycles @contextlib.contextmanager @@ -562,3 +563,86 @@ def _run_abort_timeout_test(llm_kwargs: dict, timeout: int): sampling_params) # Request-0 times out and is cleared! assert '0' not in req_to_blocks + + +def test_register_kv_caches(dist_init): + """ + Test that register_kv_caches() properly calls nixl_wrapper methods with + correct data. + + This test verifies: + 1. nixl_wrapper.get_reg_descs() is called with caches_data containing + tensor metadata + 2. nixl_wrapper.get_xfer_descs() is called with blocks_data containing + block layout info + """ + + vllm_config = create_vllm_config() + + # Create test kv cache tensors using proper backend shape + kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape(num_blocks=2, + block_size=16, + num_kv_heads=4, + head_size=64) + shared_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16) + unique_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16) + kv_caches = { + "layer0": shared_tensor, + "layer1": unique_tensor, + "layer2": shared_tensor, + } + + # Store tensor info for validation + expected_tensor_size = shared_tensor[0].element_size( + ) * shared_tensor[0].numel() + expected_base_addrs = [ + shared_tensor[0].data_ptr(), shared_tensor[1].data_ptr(), + unique_tensor[0].data_ptr(), unique_tensor[1].data_ptr() + ] + + with patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper") as mock_nixl_wrapper, \ + patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.threading.Event"), \ + patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.threading.Thread"): # noqa: E501 + + # Create connector + connector = NixlConnector(vllm_config, KVConnectorRole.WORKER) + connector.connector_worker = FakeNixlConnectorWorker( + vllm_config, connector.engine_id, hand_shake_latency=0) + + # Get the mock instance + mock_wrapper_instance = mock_nixl_wrapper.return_value + connector.connector_worker.nixl_wrapper = mock_wrapper_instance + + # Execute register_kv_caches + connector.register_kv_caches(kv_caches) + + # Verify get_reg_descs was called with caches_data + assert mock_wrapper_instance.get_reg_descs.called + caches_data, _ = mock_wrapper_instance.get_reg_descs.call_args[0] + assert len(caches_data) == 4 + + for i, cache_entry in enumerate(caches_data): + base_addr, size, _tp_rank, _ = cache_entry + assert size == expected_tensor_size, \ + f"Entry {i}: Expected tensor size {expected_tensor_size}, " \ + f"got {size}" + assert base_addr == expected_base_addrs[i], \ + f"Entry {i}: Expected base address {expected_base_addrs[i]}, " \ + f"got {base_addr}" + + # Verify get_xfer_descs was called with blocks_data + assert mock_wrapper_instance.get_xfer_descs.called + blocks_data, _ = mock_wrapper_instance.get_xfer_descs.call_args[0] + + # Validate blocks_data structure and size + expected_blocks_count = 8 + assert len(blocks_data) == expected_blocks_count, \ + f"Expected {expected_blocks_count} blocks, " \ + f"got {len(blocks_data)}" + + expected_block_len = expected_tensor_size // 2 + for i, block_entry in enumerate(blocks_data): + block_start_addr, block_len, tp_rank = block_entry + assert block_len == expected_block_len, \ + f"Block entry {i}: Expected block len {expected_block_len}, " \ + f"got {block_len}" diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py index 07fcdecac627..5601ee74be11 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -131,8 +131,8 @@ class KVConnectorBase_V1(ABC): Initialize with the KV caches. Useful for pre-registering the KV Caches in the KVConnector (e.g. for NIXL). - Args: kv_caches: - dictionary of layer names, kv cache + Args: + kv_caches: dictionary of layer names, kv cache """ return diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 4f51229ffbd2..6608d2a4a9e0 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -686,9 +686,6 @@ class NixlConnectorWorker: def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): """Register the KV Cache data in nixl.""" - _, first_kv_cache = next(iter(kv_caches.items())) - kv_elem_size = first_kv_cache.element_size() - if self.use_host_buffer: self.initialize_host_xfer_buffer(kv_caches=kv_caches) assert len(self.host_xfer_buffers) == len(kv_caches), ( @@ -701,66 +698,16 @@ class NixlConnectorWorker: "host_xfer_buffer should not be initialized when " f"kv_buffer_device is {self.kv_buffer_device}") - # TODO(tms): Find a more robust way to detect and handle MLA - # NOTE (NickLucche) To move blocks efficiently with NIXL, the expected - # KV memory layout is HND, as opposed to the default NHD. Note that it - # will only affects the strides. For MLA instead, we make require no - # such thing and resort to the standard layout. - use_mla = len(first_kv_cache.shape) == 3 - if self.device_type == "tpu": - assert not use_mla, f"{self.kv_buffer_device} does not support MLA." - assert self._use_pallas_v1, f"attn backend: {self.backend_name}" - # tpu (v1) kv shape per layer: - # (num_blocks, block_size, num_kv_heads * 2, head_size) - self.num_blocks = first_kv_cache.shape[0] - block_rank = 3 # [block_size, kv_heads, head_dim] - block_shape = first_kv_cache.shape[-block_rank:] - block_size, n_kv_heads_x_2, head_dim = block_shape - self.slot_size_bytes = kv_elem_size * n_kv_heads_x_2 * head_dim - elif self.device_type == "cuda": - assert use_mla == self.use_mla - # TODO (NickLucche) not compatible with hybrid allocator. - # Enforce check once it goes live, as a single kv layout - # is expected for xfers. - if use_mla: - # MLA case. - self.num_blocks = first_kv_cache.shape[0] - block_rank = 2 # [block_size, latent_dim] - block_shape = first_kv_cache.shape[-block_rank:] - block_size, kv_latent_dim = block_shape - self.slot_size_bytes = kv_elem_size * kv_latent_dim - else: - # [2 (k and v), num_blocks, ...] - if self._use_flashinfer: - # FlashInfer swaps 2<->num_blocks dimensions. - self.num_blocks = first_kv_cache.shape[0] - block_rank = 4 # [2, block_size, kv_heads, head_dim] - else: - self.num_blocks = first_kv_cache.shape[1] - block_rank = 3 # [block_size, kv_heads, head_dim] - block_shape = first_kv_cache.shape[-block_rank:] - block_size, n_kv_heads, head_dim = block_shape[-3:] - # head size in bytes. - self.slot_size_bytes = kv_elem_size * n_kv_heads * head_dim - assert block_size == self.block_size - else: - raise RuntimeError( - f"{self.device_type} ({self.backend_name}) is not supported.") - - # TODO(tms): self.block_len needs to be per-layer for sliding window, - # hybrid attn, etc - # block size in bytes - self.block_len = kv_elem_size * math.prod(block_shape) logger.info( "Registering KV_Caches. use_mla: %s, kv_buffer_device: %s, " - "use_host_buffer: %s, num_blocks: %s, block_shape: %s, " - "per_layer_kv_cache_shape: %s", use_mla, self.kv_buffer_device, - self.use_host_buffer, self.num_blocks, block_shape, - first_kv_cache.shape) - self.dst_num_blocks[self.engine_id] = self.num_blocks - self.device_kv_caches = kv_caches - kv_caches_base_addr = [] + "use_host_buffer: %s", self.use_mla, self.kv_buffer_device, + self.use_host_buffer) + caches_data = [] + # With hybrid allocator, layers can share a kv cache tensor + seen_base_addresses = [] + xfer_buffers = (self.host_xfer_buffers + if self.use_host_buffer else kv_caches) # Note(tms): I modified this from the original region setup code. # K and V are now in different regions. Advantage is that we can @@ -770,42 +717,35 @@ class NixlConnectorWorker: # (roughly 8KB vs 5KB). # Conversely for FlashInfer, K and V are transferred in the same tensor # to better exploit the memory layout (ie num_blocks is the first dim). - for cache_or_caches in xfer_buffers.values(): - # Normalize to always be a list of caches - cache_list = [cache_or_caches] if use_mla \ - or self._use_pallas_v1 or self._use_flashinfer \ - else cache_or_caches + split_k_and_v = not (self.use_mla or self._use_pallas_v1 + or self._use_flashinfer) + tensor_size_bytes = None + for layer_name, cache_or_caches in xfer_buffers.items(): + cache_list = cache_or_caches if split_k_and_v else [ + cache_or_caches + ] + for cache in cache_list: base_addr = cache.data_ptr() - region_len = self.num_blocks * self.block_len - # NOTE: use tp_rank for device_id since multi-node TP - # is rarely used. - caches_data.append((base_addr, region_len, self.tp_rank, "")) - kv_caches_base_addr.append(base_addr) - self.kv_caches_base_addr[self.engine_id] = kv_caches_base_addr + if base_addr in seen_base_addresses: + continue + + seen_base_addresses.append(base_addr) + curr_tensor_size_bytes = cache.numel() * cache.element_size() + + if tensor_size_bytes is None: + tensor_size_bytes = curr_tensor_size_bytes + self.num_blocks = cache.shape[0] + + assert tensor_size_bytes == curr_tensor_size_bytes, \ + "All kv cache tensors must have the same size" + caches_data.append( + (base_addr, tensor_size_bytes, self.tp_rank, "")) + + self.kv_caches_base_addr[self.engine_id] = seen_base_addresses self.num_regions = len(caches_data) self.num_layers = len(xfer_buffers.keys()) - # TODO(mgoin): remove this once we have hybrid memory allocator - # Optimization for models with local attention (Llama 4) - if self.vllm_config.model_config.hf_config.model_type == "llama4": - from transformers import Llama4TextConfig - assert isinstance(self.vllm_config.model_config.hf_text_config, - Llama4TextConfig) - llama4_config = self.vllm_config.model_config.hf_text_config - no_rope_layers = llama4_config.no_rope_layers - chunk_size = llama4_config.attention_chunk_size - chunk_block_size = math.ceil(chunk_size / self.block_size) - for layer_idx in range(self.num_layers): - # no_rope_layers[layer_idx] == 0 means NoPE (global) - # Any other value means RoPE (local chunked) - is_local_attention = no_rope_layers[layer_idx] != 0 - block_window = chunk_block_size if is_local_attention else None - self.block_window_per_layer.append(block_window) - logger.debug("Llama 4 block window per layer mapping: %s", - self.block_window_per_layer) - assert len(self.block_window_per_layer) == self.num_layers - descs = self.nixl_wrapper.get_reg_descs(caches_data, self.nixl_memory_type) logger.debug("Registering descs: %s", caches_data) @@ -813,9 +753,20 @@ class NixlConnectorWorker: logger.debug("Done registering descs") self._registered_descs.append(descs) + assert tensor_size_bytes is not None + assert self.num_blocks != 0 + assert tensor_size_bytes % self.num_blocks == 0 + self.block_len = tensor_size_bytes // self.num_blocks + self.slot_size_bytes = self.block_len // self.block_size + if self._use_flashinfer: + assert self.slot_size_bytes % 2 == 0 + self.slot_size_bytes /= 2 + self.device_kv_caches = kv_caches + self.dst_num_blocks[self.engine_id] = self.num_blocks + # Register local/src descr for NIXL xfer. blocks_data = [] - for base_addr in self.kv_caches_base_addr[self.engine_id]: + for base_addr in seen_base_addresses: # NOTE With heter-TP, more blocks are prepared than what are # needed as self.num_blocks >= nixl_agent_meta.num_blocks. We # could create fewer, but then _get_block_descs_ids needs to @@ -836,6 +787,26 @@ class NixlConnectorWorker: self.src_xfer_side_handle = self.nixl_wrapper.prep_xfer_dlist( "NIXL_INIT_AGENT", descs) + # TODO(mgoin): Hybrid memory allocator is currently diabled for + # models with local attention (Llama 4). Can remove this once enabled. + if self.vllm_config.model_config.hf_config.model_type == "llama4": + from transformers import Llama4TextConfig + assert isinstance(self.vllm_config.model_config.hf_text_config, + Llama4TextConfig) + llama4_config = self.vllm_config.model_config.hf_text_config + no_rope_layers = llama4_config.no_rope_layers + chunk_size = llama4_config.attention_chunk_size + chunk_block_size = math.ceil(chunk_size / self.block_size) + for layer_idx in range(self.num_layers): + # no_rope_layers[layer_idx] == 0 means NoPE (global) + # Any other value means RoPE (local chunked) + is_local_attention = no_rope_layers[layer_idx] != 0 + block_window = chunk_block_size if is_local_attention else None + self.block_window_per_layer.append(block_window) + logger.debug("Llama 4 block window per layer mapping: %s", + self.block_window_per_layer) + assert len(self.block_window_per_layer) == self.num_layers + # After KV Caches registered, listen for new connections. metadata = NixlAgentMetadata( engine_id=self.engine_id,