diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py
index 475cf2285e39..8e421717fea3 100644
--- a/tests/v1/kv_connector/unit/test_nixl_connector.py
+++ b/tests/v1/kv_connector/unit/test_nixl_connector.py
@@ -1096,7 +1096,8 @@ def _run_abort_timeout_test(llm: LLM, timeout: int):
     llm.llm_engine.engine_core.shutdown()
 
 
-def test_register_kv_caches(dist_init):
+@pytest.mark.parametrize("attn_backend", ["FLASH_ATTN", "TRITON_ATTN"])
+def test_register_kv_caches(dist_init, attn_backend, monkeypatch):
     """
     Test that register_kv_caches() properly calls nixl_wrapper methods with
     correct data.
@@ -1108,10 +1109,22 @@ def test_register_kv_caches(dist_init):
         block layout info
     """
 
+    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
+
     vllm_config = create_vllm_config()
 
+    # Import the appropriate backend based on the parameter
+    if attn_backend == "FLASH_ATTN":
+        from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
+
+        backend_cls = FlashAttentionBackend
+    else:  # TRITON_ATTN
+        from vllm.v1.attention.backends.triton_attn import TritonAttentionBackend
+
+        backend_cls = TritonAttentionBackend
+
     # Create test kv cache tensors using proper backend shape
-    kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape(
+    kv_cache_shape = backend_cls.get_kv_cache_shape(
         num_blocks=2, block_size=16, num_kv_heads=4, head_size=64
     )
     shared_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
index 6c20eee1ecbf..375ea79d0e81 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -21,6 +21,7 @@
 import torch
 import zmq
 from vllm import envs
+from vllm.attention import AttentionBackend
 from vllm.attention.backends.registry import AttentionBackendEnum
 from vllm.attention.selector import get_attn_backend
 from vllm.config import VllmConfig
@@ -669,6 +670,33 @@ class NixlConnectorWorker:
     remote_tp_size: dict[EngineId, int]
     is_mla: bool
     total_num_kv_heads: int
+    attn_backend: type[AttentionBackend]
+
+    def __post_init__(self):
+        # Figure out whether the first dimension of the cache is K/V
+        # or num_blocks. This is used to register the memory regions correctly.
+        kv_cache_shape = self.attn_backend.get_kv_cache_shape(
+            num_blocks=1, block_size=16, num_kv_heads=1, head_size=1
+        )
+        # Non-MLA backends caches have 5 dims [2, num_blocks, H,N,D],
+        # we just mock num_blocks to 1 for the dimension check below.
+        self._is_kv_layout_blocks_first = (
+            len(kv_cache_shape) == 5 and kv_cache_shape[0] == 1
+        )
+
+        attn_backend = AttentionBackendEnum[self.attn_backend.get_name()]
+        self._use_pallas = attn_backend == AttentionBackendEnum.PALLAS
+
+    @property
+    def is_kv_layout_blocks_first(self) -> bool:
+        return self._is_kv_layout_blocks_first
+
+    @property
+    def split_k_and_v(self) -> bool:
+        # Whether to register regions for K and V separately (when present).
+        return not (
+            self.is_mla or self._use_pallas or self.is_kv_layout_blocks_first
+        )
 
     def tp_ratio(
         self,
@@ -876,9 +904,6 @@ class NixlConnectorWorker:
             use_mla=self.use_mla,
         )
         self.backend_name = backend.get_name()
-        attn_backend = AttentionBackendEnum[self.backend_name]
-        self._use_flashinfer = attn_backend == AttentionBackendEnum.FLASHINFER
-        self._use_pallas = attn_backend == AttentionBackendEnum.PALLAS
         self.kv_cache_layout = get_kv_cache_layout()
         self.host_buffer_kv_cache_layout = self.kv_cache_layout
         logger.debug("Detected attention backend %s", self.backend_name)
@@ -896,7 +921,9 @@ class NixlConnectorWorker:
             remote_tp_size=self._tp_size,  # shared state
             is_mla=self.use_mla,
             total_num_kv_heads=self.model_config.get_total_num_kv_heads(),
+            attn_backend=backend,
         )
+        self._use_pallas = self.kv_topo._use_pallas
 
     def _nixl_handshake(
         self,
@@ -1076,7 +1103,7 @@ class NixlConnectorWorker:
         # (roughly 8KB vs 5KB).
         # Conversely for FlashInfer, K and V are registered in the same region
         # to better exploit the memory layout (ie num_blocks is the first dim).
-        split_k_and_v = not (self.use_mla or self._use_pallas or self._use_flashinfer)
+        split_k_and_v = self.kv_topo.split_k_and_v
         tensor_size_bytes = None
         # Enable different block lengths for different layers when MLA is used.
         self.block_len_per_layer = list[int]()
@@ -1141,7 +1168,7 @@ class NixlConnectorWorker:
         self.device_kv_caches = kv_caches
         self.dst_num_blocks[self.engine_id] = self.num_blocks
 
-        if self._use_flashinfer:
+        if self.kv_topo.is_kv_layout_blocks_first:
             for i in range(len(self.slot_size_per_layer)):
                 assert self.slot_size_per_layer[i] % 2 == 0
                 self.slot_size_per_layer[i] //= 2
@@ -1169,7 +1196,7 @@ class NixlConnectorWorker:
                 # (addr, len, device id)
                 blocks_data.append((addr, kv_block_len, self.device_id))
 
-            if self._use_flashinfer:
+            if self.kv_topo.is_kv_layout_blocks_first:
                 # Separate and interleave K/V regions to maintain the same
                 # descs ordering. This is needed for selecting contiguous heads
                 # when split across TP ranks.
@@ -1331,7 +1358,7 @@ class NixlConnectorWorker:
                 # (addr, len, device id)
                 blocks_data.append((addr, kv_block_len, nixl_agent_meta.device_id))
 
-            if self._use_flashinfer:
+            if self.kv_topo.is_kv_layout_blocks_first:
                 # With FlashInfer index V separately to allow head splitting.
                 for block_id in range(nixl_agent_meta.num_blocks):
                     block_offset = block_id * nixl_agent_meta.block_lens[i]
@@ -1414,7 +1441,7 @@ class NixlConnectorWorker:
         remote_block_size = remote_block_len // (
             self.slot_size_per_layer[0] * tp_ratio
         )
-        if self._use_flashinfer:
+        if self.kv_topo.is_kv_layout_blocks_first:
            # With flashinfer, KV are sent in the same message.
            remote_block_size //= 2
 
@@ -1494,7 +1521,7 @@ class NixlConnectorWorker:
 
        - cache.index_copy_(0, indices, permuted_blocks) # copy permuted kv back
        """
-        split_k_and_v = not (self.use_mla or self._use_pallas or self._use_flashinfer)
+        split_k_and_v = self.kv_topo.split_k_and_v
         inv_order = [0, 2, 1, 3]
         sample_cache = list(self.device_kv_caches.values())[0][0]
         target_shape = list(sample_cache.shape)
@@ -1874,7 +1901,7 @@ class NixlConnectorWorker:
         For FlashInfer, this is half the length of the whole block, as K and V
         share the same region.
         """
-        if self._use_flashinfer:
+        if self.kv_topo.is_kv_layout_blocks_first:
             # For indexing only half (either just the K or V part).
             block_len = self.block_len_per_layer[layer_idx] // 2
         else:
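Note on the layout probe introduced in __post_init__ above: the check calls get_kv_cache_shape() with num_blocks=1 and treats a leading dimension equal to that mocked block count as "num_blocks comes first", i.e. the FlashInfer-style layout where K and V share one contiguous region. The sketch below is a minimal standalone illustration of that heuristic; the two shape helpers are made-up stand-ins that mimic a K/V-first and a blocks-first cache layout, they are not vLLM's real backend classes.

# Toy sketch of the num_blocks=1 probe; the shape helpers below are
# illustrative stand-ins, not vLLM's actual attention backends.

def kv_first_shape(num_blocks, block_size, num_kv_heads, head_size):
    # FlashAttention-style cache: the K/V axis (size 2) comes first.
    return (2, num_blocks, block_size, num_kv_heads, head_size)

def blocks_first_shape(num_blocks, block_size, num_kv_heads, head_size):
    # FlashInfer-style cache: num_blocks comes first, K and V share one region.
    return (num_blocks, 2, block_size, num_kv_heads, head_size)

def is_kv_layout_blocks_first(get_kv_cache_shape) -> bool:
    # Probe with num_blocks=1: if the leading dim equals the mocked block
    # count, blocks come first and K/V cannot be registered separately.
    shape = get_kv_cache_shape(
        num_blocks=1, block_size=16, num_kv_heads=1, head_size=1
    )
    return len(shape) == 5 and shape[0] == 1

assert is_kv_layout_blocks_first(kv_first_shape) is False   # split K and V
assert is_kv_layout_blocks_first(blocks_first_shape) is True  # one K/V region

split_k_and_v then simply negates this result together with the MLA and Pallas cases, which is what the refactored checks in register_kv_caches and the host-buffer permute path rely on.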