[P/D][Nixl] Make kv cache register compatible with hybrid memory allocator (#23079)
Signed-off-by: sfeng33 <4florafeng@gmail.com>
commit 53415653ff (parent 17373dcd93)
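For context: with the hybrid KV cache allocator, several layers can be backed by the same underlying KV cache tensor, so the per-layer kv_caches dict handed to register_kv_caches() may contain duplicate storage. A minimal, hypothetical sketch of that situation (plain torch tensors and made-up layer names, not vLLM's real allocator), showing why registration has to deduplicate by base address:

import torch

# Hypothetical illustration: two attention layers share one KV cache tensor,
# a third gets its own. Shape follows the (2, num_blocks, ...) convention.
shared = torch.zeros(2, 4, 16, 8, 64, dtype=torch.float16)
unique = torch.zeros(2, 4, 16, 8, 64, dtype=torch.float16)
kv_caches = {"layer0": shared, "layer1": unique, "layer2": shared}

# Registering one NIXL region per dict entry would register `shared` twice;
# deduplicating by data_ptr() leaves one region per distinct tensor.
unique_ptrs = {c.data_ptr() for c in kv_caches.values()}
assert len(unique_ptrs) == 2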
@@ -14,6 +14,7 @@ from unittest.mock import patch

import pytest
import ray
import torch

from vllm import LLM
from vllm.config import KVTransferConfig
@@ -22,6 +23,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
    NixlConnectorWorker)
from vllm.forward_context import ForwardContext
from vllm.sampling_params import SamplingParams
from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend

from .utils import create_request, create_scheduler, create_vllm_config
@@ -98,7 +100,6 @@ class FakeNixlWrapper:

    def set_cycles_before_xfer_done(self, cycles: int):
        """Set the number of cycles before a transfer is considered done."""
        self._cycles_before_xfer_done = cycles


@contextlib.contextmanager
@@ -562,3 +563,86 @@ def _run_abort_timeout_test(llm_kwargs: dict, timeout: int):
                     sampling_params)
    # Request-0 times out and is cleared!
    assert '0' not in req_to_blocks


def test_register_kv_caches(dist_init):
    """
    Test that register_kv_caches() properly calls nixl_wrapper methods with
    correct data.

    This test verifies:
    1. nixl_wrapper.get_reg_descs() is called with caches_data containing
       tensor metadata
    2. nixl_wrapper.get_xfer_descs() is called with blocks_data containing
       block layout info
    """

    vllm_config = create_vllm_config()

    # Create test kv cache tensors using proper backend shape
    kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape(num_blocks=2,
                                                              block_size=16,
                                                              num_kv_heads=4,
                                                              head_size=64)
    shared_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
    unique_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
    kv_caches = {
        "layer0": shared_tensor,
        "layer1": unique_tensor,
        "layer2": shared_tensor,
    }

    # Store tensor info for validation
    expected_tensor_size = shared_tensor[0].element_size(
    ) * shared_tensor[0].numel()
    expected_base_addrs = [
        shared_tensor[0].data_ptr(), shared_tensor[1].data_ptr(),
        unique_tensor[0].data_ptr(), unique_tensor[1].data_ptr()
    ]

    with patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper") as mock_nixl_wrapper, \
         patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.threading.Event"), \
         patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.threading.Thread"):  # noqa: E501

        # Create connector
        connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
        connector.connector_worker = FakeNixlConnectorWorker(
            vllm_config, connector.engine_id, hand_shake_latency=0)

        # Get the mock instance
        mock_wrapper_instance = mock_nixl_wrapper.return_value
        connector.connector_worker.nixl_wrapper = mock_wrapper_instance

        # Execute register_kv_caches
        connector.register_kv_caches(kv_caches)

        # Verify get_reg_descs was called with caches_data
        assert mock_wrapper_instance.get_reg_descs.called
        caches_data, _ = mock_wrapper_instance.get_reg_descs.call_args[0]
        assert len(caches_data) == 4

        for i, cache_entry in enumerate(caches_data):
            base_addr, size, _tp_rank, _ = cache_entry
            assert size == expected_tensor_size, \
                f"Entry {i}: Expected tensor size {expected_tensor_size}, " \
                f"got {size}"
            assert base_addr == expected_base_addrs[i], \
                f"Entry {i}: Expected base address {expected_base_addrs[i]}, " \
                f"got {base_addr}"

        # Verify get_xfer_descs was called with blocks_data
        assert mock_wrapper_instance.get_xfer_descs.called
        blocks_data, _ = mock_wrapper_instance.get_xfer_descs.call_args[0]

        # Validate blocks_data structure and size
        expected_blocks_count = 8
        assert len(blocks_data) == expected_blocks_count, \
            f"Expected {expected_blocks_count} blocks, " \
            f"got {len(blocks_data)}"

        expected_block_len = expected_tensor_size // 2
        for i, block_entry in enumerate(blocks_data):
            block_start_addr, block_len, tp_rank = block_entry
            assert block_len == expected_block_len, \
                f"Block entry {i}: Expected block len {expected_block_len}, " \
                f"got {block_len}"
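The expected values in the test above follow directly from the FlashAttention KV layout (2, num_blocks, block_size, num_kv_heads, head_size) and the dedup behaviour: three layers map to two distinct tensors, each split into a K half and a V half, giving four registered regions; with num_blocks=2 that is eight block descriptors, each one block's worth of a K or V half. A small arithmetic sketch, assuming float16 (2 bytes per element):

# Sketch of the expected sizes in test_register_kv_caches, assuming the
# FlashAttention layout (2, num_blocks, block_size, num_kv_heads, head_size).
elem_size = 2                       # float16
num_blocks, block_size, num_kv_heads, head_size = 2, 16, 4, 64

# Each K (or V) half of one layer's tensor is a separate NIXL region.
tensor_size = elem_size * num_blocks * block_size * num_kv_heads * head_size
assert tensor_size == 16384         # expected_tensor_size in the test

# 3 layers -> 2 distinct tensors (layer0/layer2 share one) -> 2 K + 2 V halves.
num_regions = 2 * 2
assert num_regions == 4                         # len(caches_data)

# One block descriptor per (region, block): 4 regions * 2 blocks = 8.
assert num_regions * num_blocks == 8            # expected_blocks_count
assert tensor_size // num_blocks == 8192        # expected_block_len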
@@ -131,8 +131,8 @@ class KVConnectorBase_V1(ABC):
        Initialize with the KV caches. Useful for pre-registering the
        KV Caches in the KVConnector (e.g. for NIXL).

        Args: kv_caches:
            dictionary of layer names, kv cache
        Args:
            kv_caches: dictionary of layer names, kv cache
        """
        return
@@ -686,9 +686,6 @@ class NixlConnectorWorker:
    def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
        """Register the KV Cache data in nixl."""

        _, first_kv_cache = next(iter(kv_caches.items()))
        kv_elem_size = first_kv_cache.element_size()

        if self.use_host_buffer:
            self.initialize_host_xfer_buffer(kv_caches=kv_caches)
            assert len(self.host_xfer_buffers) == len(kv_caches), (
@@ -701,66 +698,16 @@ class NixlConnectorWorker:
                "host_xfer_buffer should not be initialized when "
                f"kv_buffer_device is {self.kv_buffer_device}")

        # TODO(tms): Find a more robust way to detect and handle MLA
        # NOTE (NickLucche) To move blocks efficiently with NIXL, the expected
        # KV memory layout is HND, as opposed to the default NHD. Note that it
        # will only affects the strides. For MLA instead, we make require no
        # such thing and resort to the standard layout.
        use_mla = len(first_kv_cache.shape) == 3
        if self.device_type == "tpu":
            assert not use_mla, f"{self.kv_buffer_device} does not support MLA."
            assert self._use_pallas_v1, f"attn backend: {self.backend_name}"
            # tpu (v1) kv shape per layer:
            # (num_blocks, block_size, num_kv_heads * 2, head_size)
            self.num_blocks = first_kv_cache.shape[0]
            block_rank = 3  # [block_size, kv_heads, head_dim]
            block_shape = first_kv_cache.shape[-block_rank:]
            block_size, n_kv_heads_x_2, head_dim = block_shape
            self.slot_size_bytes = kv_elem_size * n_kv_heads_x_2 * head_dim
        elif self.device_type == "cuda":
            assert use_mla == self.use_mla
            # TODO (NickLucche) not compatible with hybrid allocator.
            # Enforce check once it goes live, as a single kv layout
            # is expected for xfers.
            if use_mla:
                # MLA case.
                self.num_blocks = first_kv_cache.shape[0]
                block_rank = 2  # [block_size, latent_dim]
                block_shape = first_kv_cache.shape[-block_rank:]
                block_size, kv_latent_dim = block_shape
                self.slot_size_bytes = kv_elem_size * kv_latent_dim
            else:
                # [2 (k and v), num_blocks, ...]
                if self._use_flashinfer:
                    # FlashInfer swaps 2<->num_blocks dimensions.
                    self.num_blocks = first_kv_cache.shape[0]
                    block_rank = 4  # [2, block_size, kv_heads, head_dim]
                else:
                    self.num_blocks = first_kv_cache.shape[1]
                    block_rank = 3  # [block_size, kv_heads, head_dim]
                block_shape = first_kv_cache.shape[-block_rank:]
                block_size, n_kv_heads, head_dim = block_shape[-3:]
                # head size in bytes.
                self.slot_size_bytes = kv_elem_size * n_kv_heads * head_dim
            assert block_size == self.block_size
        else:
            raise RuntimeError(
                f"{self.device_type} ({self.backend_name}) is not supported.")

        # TODO(tms): self.block_len needs to be per-layer for sliding window,
        # hybrid attn, etc
        # block size in bytes
        self.block_len = kv_elem_size * math.prod(block_shape)
        logger.info(
            "Registering KV_Caches. use_mla: %s, kv_buffer_device: %s, "
            "use_host_buffer: %s, num_blocks: %s, block_shape: %s, "
            "per_layer_kv_cache_shape: %s", use_mla, self.kv_buffer_device,
            self.use_host_buffer, self.num_blocks, block_shape,
            first_kv_cache.shape)
        self.dst_num_blocks[self.engine_id] = self.num_blocks
        self.device_kv_caches = kv_caches
        kv_caches_base_addr = []
            "use_host_buffer: %s", self.use_mla, self.kv_buffer_device,
            self.use_host_buffer)

        caches_data = []
        # With hybrid allocator, layers can share a kv cache tensor
        seen_base_addresses = []
        xfer_buffers = (self.host_xfer_buffers
                        if self.use_host_buffer else kv_caches)

        # Note(tms): I modified this from the original region setup code.
        # K and V are now in different regions. Advantage is that we can
@@ -770,42 +717,35 @@ class NixlConnectorWorker:
        # (roughly 8KB vs 5KB).
        # Conversely for FlashInfer, K and V are transferred in the same tensor
        # to better exploit the memory layout (ie num_blocks is the first dim).
        for cache_or_caches in xfer_buffers.values():
            # Normalize to always be a list of caches
            cache_list = [cache_or_caches] if use_mla \
                or self._use_pallas_v1 or self._use_flashinfer \
                else cache_or_caches
        split_k_and_v = not (self.use_mla or self._use_pallas_v1
                             or self._use_flashinfer)
        tensor_size_bytes = None
        for layer_name, cache_or_caches in xfer_buffers.items():
            cache_list = cache_or_caches if split_k_and_v else [
                cache_or_caches
            ]

            for cache in cache_list:
                base_addr = cache.data_ptr()
                region_len = self.num_blocks * self.block_len
                # NOTE: use tp_rank for device_id since multi-node TP
                # is rarely used.
                caches_data.append((base_addr, region_len, self.tp_rank, ""))
                kv_caches_base_addr.append(base_addr)
        self.kv_caches_base_addr[self.engine_id] = kv_caches_base_addr
                if base_addr in seen_base_addresses:
                    continue

                seen_base_addresses.append(base_addr)
                curr_tensor_size_bytes = cache.numel() * cache.element_size()

                if tensor_size_bytes is None:
                    tensor_size_bytes = curr_tensor_size_bytes
                    self.num_blocks = cache.shape[0]

                assert tensor_size_bytes == curr_tensor_size_bytes, \
                    "All kv cache tensors must have the same size"
                caches_data.append(
                    (base_addr, tensor_size_bytes, self.tp_rank, ""))

        self.kv_caches_base_addr[self.engine_id] = seen_base_addresses
        self.num_regions = len(caches_data)
        self.num_layers = len(xfer_buffers.keys())

        # TODO(mgoin): remove this once we have hybrid memory allocator
        # Optimization for models with local attention (Llama 4)
        if self.vllm_config.model_config.hf_config.model_type == "llama4":
            from transformers import Llama4TextConfig
            assert isinstance(self.vllm_config.model_config.hf_text_config,
                              Llama4TextConfig)
            llama4_config = self.vllm_config.model_config.hf_text_config
            no_rope_layers = llama4_config.no_rope_layers
            chunk_size = llama4_config.attention_chunk_size
            chunk_block_size = math.ceil(chunk_size / self.block_size)
            for layer_idx in range(self.num_layers):
                # no_rope_layers[layer_idx] == 0 means NoPE (global)
                # Any other value means RoPE (local chunked)
                is_local_attention = no_rope_layers[layer_idx] != 0
                block_window = chunk_block_size if is_local_attention else None
                self.block_window_per_layer.append(block_window)
            logger.debug("Llama 4 block window per layer mapping: %s",
                         self.block_window_per_layer)
            assert len(self.block_window_per_layer) == self.num_layers

        descs = self.nixl_wrapper.get_reg_descs(caches_data,
                                                self.nixl_memory_type)
        logger.debug("Registering descs: %s", caches_data)
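The core of the change is visible in the hunk above: instead of registering one fixed-length region per layer, the worker now walks the (possibly shared) transfer buffers, skips base addresses it has already seen, and registers one region per distinct tensor, sized from the tensor itself. A standalone sketch of that dedup logic with simplified types and a hypothetical tp_rank, not the actual NixlConnectorWorker method:

import torch

def dedup_kv_regions(xfer_buffers: dict[str, torch.Tensor],
                     split_k_and_v: bool, tp_rank: int = 0):
    """Simplified sketch of the new region setup: one region per distinct
    K/V tensor, regardless of how many layers share it."""
    caches_data, seen_base_addresses = [], []
    tensor_size_bytes = None
    for cache_or_caches in xfer_buffers.values():
        cache_list = cache_or_caches if split_k_and_v else [cache_or_caches]
        for cache in cache_list:
            base_addr = cache.data_ptr()
            if base_addr in seen_base_addresses:
                continue  # another layer already registered this tensor
            seen_base_addresses.append(base_addr)
            curr = cache.numel() * cache.element_size()
            if tensor_size_bytes is None:
                tensor_size_bytes = curr
            assert tensor_size_bytes == curr, \
                "All kv cache tensors must have the same size"
            caches_data.append((base_addr, tensor_size_bytes, tp_rank, ""))
    return caches_data, seen_base_addresses

Applied to the three-layer kv_caches dict from the test, with split_k_and_v=True, this returns four (base_addr, size, tp_rank, "") entries, matching the test's len(caches_data) == 4 check.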
@@ -813,9 +753,20 @@ class NixlConnectorWorker:
        logger.debug("Done registering descs")
        self._registered_descs.append(descs)

        assert tensor_size_bytes is not None
        assert self.num_blocks != 0
        assert tensor_size_bytes % self.num_blocks == 0
        self.block_len = tensor_size_bytes // self.num_blocks
        self.slot_size_bytes = self.block_len // self.block_size
        if self._use_flashinfer:
            assert self.slot_size_bytes % 2 == 0
            self.slot_size_bytes /= 2
        self.device_kv_caches = kv_caches
        self.dst_num_blocks[self.engine_id] = self.num_blocks

        # Register local/src descr for NIXL xfer.
        blocks_data = []
        for base_addr in self.kv_caches_base_addr[self.engine_id]:
        for base_addr in seen_base_addresses:
            # NOTE With heter-TP, more blocks are prepared than what are
            # needed as self.num_blocks >= nixl_agent_meta.num_blocks. We
            # could create fewer, but then _get_block_descs_ids needs to
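The hunk above also moves the block_len and slot_size_bytes computation so it is derived from the deduplicated tensor size rather than from backend-specific shapes. With the test's numbers this works out as follows; a small arithmetic sketch, assuming float16 and the non-MLA FlashAttention layout:

# Arithmetic sketch of the new block_len / slot_size_bytes derivation,
# using the shapes from test_register_kv_caches (float16, non-FlashInfer).
tensor_size_bytes = 16384     # one K (or V) half: 2 blocks * 16 * 4 * 64 * 2B
num_blocks = 2
block_size = 16

block_len = tensor_size_bytes // num_blocks     # bytes per block per region
slot_size_bytes = block_len // block_size       # bytes per token slot
assert block_len == 8192
assert slot_size_bytes == 512                   # = 2B * 4 heads * 64 head_dim

# For FlashInfer, K and V share one region, so the connector halves the
# per-slot size (hence the assert slot_size_bytes % 2 == 0 in the diff).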
@@ -836,6 +787,26 @@ class NixlConnectorWorker:
        self.src_xfer_side_handle = self.nixl_wrapper.prep_xfer_dlist(
            "NIXL_INIT_AGENT", descs)

        # TODO(mgoin): Hybrid memory allocator is currently diabled for
        # models with local attention (Llama 4). Can remove this once enabled.
        if self.vllm_config.model_config.hf_config.model_type == "llama4":
            from transformers import Llama4TextConfig
            assert isinstance(self.vllm_config.model_config.hf_text_config,
                              Llama4TextConfig)
            llama4_config = self.vllm_config.model_config.hf_text_config
            no_rope_layers = llama4_config.no_rope_layers
            chunk_size = llama4_config.attention_chunk_size
            chunk_block_size = math.ceil(chunk_size / self.block_size)
            for layer_idx in range(self.num_layers):
                # no_rope_layers[layer_idx] == 0 means NoPE (global)
                # Any other value means RoPE (local chunked)
                is_local_attention = no_rope_layers[layer_idx] != 0
                block_window = chunk_block_size if is_local_attention else None
                self.block_window_per_layer.append(block_window)
            logger.debug("Llama 4 block window per layer mapping: %s",
                         self.block_window_per_layer)
            assert len(self.block_window_per_layer) == self.num_layers

        # After KV Caches registered, listen for new connections.
        metadata = NixlAgentMetadata(
            engine_id=self.engine_id,
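The final hunk simply relocates the Llama 4 local-attention mapping so it runs after num_layers and the registered regions are known; the mapping itself is unchanged: layers with no_rope_layers[i] == 0 use global (NoPE) attention and get no window, while RoPE layers get a window of ceil(attention_chunk_size / block_size) blocks. A small sketch with hypothetical config values, not read from a real Llama4TextConfig:

import math

# Hypothetical values for illustration only.
no_rope_layers = [1, 1, 1, 0, 1, 1, 1, 0]   # 0 -> NoPE (global attention)
attention_chunk_size = 8192
block_size = 16
chunk_block_size = math.ceil(attention_chunk_size / block_size)  # 512 blocks

block_window_per_layer = [
    chunk_block_size if flag != 0 else None for flag in no_rope_layers
]
assert block_window_per_layer[3] is None    # global layer: no window
assert block_window_per_layer[0] == 512     # local chunked layer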