diff --git a/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh b/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh index a9817313cf02..ebc8575e5b39 100755 --- a/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh +++ b/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh @@ -49,6 +49,8 @@ NUM_DECODE_INSTANCES=${NUM_DECODE_INSTANCES:-1} # Default to 1 PREFILLER_TP_SIZE=${PREFILLER_TP_SIZE:-1} DECODER_TP_SIZE=${DECODER_TP_SIZE:-1} GPU_MEMORY_UTILIZATION=${GPU_MEMORY_UTILIZATION:-0.2} +PREFILL_BLOCK_SIZE=${PREFILL_BLOCK_SIZE:-16} +DECODE_BLOCK_SIZE=${DECODE_BLOCK_SIZE:-16} # Find the git repository root directory GIT_ROOT=$(git rev-parse --show-toplevel) @@ -136,6 +138,7 @@ run_tests_for_model() { vllm serve $model_name \ --port $PORT \ --enforce-eager \ + --block-size ${PREFILL_BLOCK_SIZE} \ --gpu-memory-utilization $GPU_MEMORY_UTILIZATION \ --tensor-parallel-size $PREFILLER_TP_SIZE \ --kv-transfer-config '$KV_CONFIG'" @@ -177,6 +180,7 @@ run_tests_for_model() { vllm serve $model_name \ --port $PORT \ --enforce-eager \ + --block-size ${DECODE_BLOCK_SIZE} \ --gpu-memory-utilization $GPU_MEMORY_UTILIZATION \ --kv-transfer-config '$KV_CONFIG'" diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index 8e421717fea3..b7d7a10057b8 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -407,6 +407,7 @@ class FakeNixlConnectorWorker(NixlConnectorWorker): # `self.kv_cache_layout` is only forced to HND when vllm engine # is started. We mock HND here. kv_cache_layout="HND", + block_size=self.block_size, ), remote_tp_size=remote_tp_size, ) @@ -652,6 +653,7 @@ class TestNixlHandshake: block_lens=worker.block_len_per_layer, attn_backend_name=worker.backend_name, kv_cache_layout=mismatched_layout, + block_size=worker.block_size, ) with pytest.raises(RuntimeError): @@ -706,6 +708,7 @@ class TestNixlHandshake: block_lens=[i * 2 for i in worker.block_len_per_layer], attn_backend_name=worker.backend_name, kv_cache_layout="HND", + block_size=worker.block_size, ) # We don't check layout for homogeneous TP and MLA for now, as the diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 3d4547c51453..a70c98b63713 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -108,6 +108,7 @@ class NixlAgentMetadata(KVConnectorHandshakeMetadata): block_lens: list[int] attn_backend_name: str kv_cache_layout: str + block_size: int @dataclass @@ -709,6 +710,9 @@ class NixlConnectorWorker: self.is_mla or self._use_pallas or self.is_kv_layout_blocks_first ) + block_size: int + remote_block_size: dict[EngineId, int] + def tp_ratio( self, remote_tp_size: int, @@ -725,6 +729,19 @@ class NixlConnectorWorker: ) return self.tp_size // remote_tp_size + def block_size_ratio( + self, + remote_block_size: int, + ) -> float: + """ + Calculate the block size ratio between local and remote TP. + """ + assert self.block_size % remote_block_size == 0, ( + f"Local block size {self.block_size} is not divisible " + f"by remote block size {remote_block_size} or vice versa." + ) + return self.block_size // remote_block_size + def tp_ratio_from_engine_id( self, remote_engine_id: EngineId, @@ -732,6 +749,13 @@ class NixlConnectorWorker: remote_tp_size = self.remote_tp_size[remote_engine_id] return self.tp_ratio(remote_tp_size) + def block_size_ratio_from_engine_id( + self, + remote_engine_id: EngineId, + ) -> float: + remote_block_size = self.remote_block_size[remote_engine_id] + return self.block_size_ratio(remote_block_size) + def is_kv_replicated(self, engine_id: EngineId) -> bool: """ Whether the KV cache is replicated across TP workers due to the @@ -866,6 +890,7 @@ class NixlConnectorWorker: # nixl_prepped_dlist_handle. self.src_xfer_side_handle: int = 0 + self.src_xfer_side_handles: dict[int, int] = {} # Map of engine_id -> nixl_prepped_dlist_handle (int)]. self.dst_xfer_side_handles: dict[EngineId, int] = {} @@ -925,6 +950,7 @@ class NixlConnectorWorker: logger.debug("Detected kv cache layout %s", self.kv_cache_layout) self._tp_size: dict[EngineId, int] = {self.engine_id: self.world_size} + self._block_size: dict[EngineId, int] = {self.engine_id: self.block_size} # With heterogeneous TP, P must wait for all assigned D TP workers to # finish reading before safely freeing the blocks. self.consumer_notification_counts_by_req = defaultdict[ReqId, int](int) @@ -936,6 +962,8 @@ class NixlConnectorWorker: remote_tp_size=self._tp_size, # shared state is_mla=self.use_mla, total_num_kv_heads=self.model_config.get_total_num_kv_heads(), + block_size=self.block_size, + remote_block_size=self._block_size, attn_backend=backend, ) self._use_pallas = self.kv_topo._use_pallas @@ -987,9 +1015,13 @@ class NixlConnectorWorker: ) # Register Remote agent. + assert metadata.block_size <= self.block_size, ( + "nP > nD is not supported yet." + ) remote_agent_name = self.add_remote_agent( metadata, p_remote_rank, remote_tp_size ) + setup_agent_time = time.perf_counter() logger.debug( "NIXL handshake: add agent took: %s", @@ -1217,43 +1249,10 @@ class NixlConnectorWorker: self.num_regions *= 2 # Register local/src descr for NIXL xfer. - blocks_data = [] - for i, base_addr in enumerate(seen_base_addresses): - kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=i) - # NOTE With heter-TP, more blocks are prepared than what are - # needed as self.num_blocks >= nixl_agent_meta.num_blocks. We - # could create fewer, but then _get_block_descs_ids needs to - # select agent_meta.num_blocks instead of self.num_blocks for - # local descr, and that makes handling regular flow less clean. - for block_id in range(self.num_blocks): - block_offset = block_id * self.block_len_per_layer[i] - addr = base_addr + block_offset - # (addr, len, device id) - blocks_data.append((addr, kv_block_len, self.device_id)) + self.seen_base_addresses = seen_base_addresses + self.src_xfer_side_handle = self.register_local_xfer_handler(self.block_size) - if self.kv_topo.is_kv_layout_blocks_first: - # Separate and interleave K/V regions to maintain the same - # descs ordering. This is needed for selecting contiguous heads - # when split across TP ranks. - for block_id in range(self.num_blocks): - block_offset = block_id * self.block_len_per_layer[i] - addr = base_addr + block_offset - # Register addresses for V cache (K registered first). - v_addr = addr + kv_block_len - blocks_data.append((v_addr, kv_block_len, self.device_id)) - logger.debug( - "Created %s blocks for src engine %s and rank %s on device id %s", - len(blocks_data), - self.engine_id, - self.tp_rank, - self.device_id, - ) - - descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type) - # NIXL_INIT_AGENT to be used for preparations of local descs. - self.src_xfer_side_handle = self.nixl_wrapper.prep_xfer_dlist( - "NIXL_INIT_AGENT", descs - ) + self.src_xfer_side_handles[self.block_size] = self.src_xfer_side_handle # TODO(mgoin): Hybrid memory allocator is currently disabled for # models with local attention (Llama 4). Can remove this once enabled. @@ -1289,8 +1288,62 @@ class NixlConnectorWorker: kv_cache_layout=self.kv_cache_layout if not self.use_host_buffer else self.host_buffer_kv_cache_layout, + block_size=self.block_size, ) + def register_local_xfer_handler( + self, + block_size: int, + ) -> int: + """ + Function used for register local xfer handler with local block_size or + Remote block_size. + + When local block_size is same as remote block_size, we use local block_size + to register local_xfer_handler during init. + + When remote block size is less than local block size, we need to use + register another local_xfer_handler using remote block len to ensure + data copy correctness. + """ + block_size_ratio = self.block_size // block_size + blocks_data = [] + for i, base_addr in enumerate(self.seen_base_addresses): + # The new block_len is using prefill block_len; + # and num_blocks is multiple with N + kv_block_len = ( + self.get_backend_aware_kv_block_len(layer_idx=i) // block_size_ratio + ) + block_len_per_layer = self.block_len_per_layer[i] // block_size_ratio + num_blocks = self.num_blocks * block_size_ratio + for block_id in range(num_blocks): + block_offset = block_id * block_len_per_layer + addr = base_addr + block_offset + # (addr, len, device id) + blocks_data.append((addr, kv_block_len, self.device_id)) + + if self.kv_topo.is_kv_layout_blocks_first: + # Separate and interleave K/V regions to maintain the same + # descs ordering. This is needed for selecting contiguous heads + # when split across TP ranks. + for block_id in range(num_blocks): + block_offset = block_id * block_len_per_layer + addr = base_addr + block_offset + # Register addresses for V cache (K registered first). + v_addr = addr + kv_block_len + blocks_data.append((v_addr, kv_block_len, self.device_id)) + logger.debug( + "Created %s blocks for src engine %s and rank %s on device id %s", + len(blocks_data), + self.engine_id, + self.tp_rank, + self.device_id, + ) + + descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type) + # NIXL_INIT_AGENT to be used for preparations of local descs. + return self.nixl_wrapper.prep_xfer_dlist("NIXL_INIT_AGENT", descs) + def add_remote_agent( self, nixl_agent_meta: NixlAgentMetadata, @@ -1349,6 +1402,8 @@ class NixlConnectorWorker: ### Register remote agent metadata if engine_id not in self._tp_size: self._tp_size[engine_id] = remote_tp_size + if engine_id not in self._block_size: + self._block_size[engine_id] = nixl_agent_meta.block_size remote_agent_name = self.nixl_wrapper.add_remote_agent( nixl_agent_meta.agent_metadata @@ -1359,6 +1414,13 @@ class NixlConnectorWorker: # Create dst descs and xfer side handles. TP workers have same #blocks # so we only register once per engine_id. + # Example: + # block_size_ratio > 1: + # remote: | 0| 1| 2| 3| 4| 5| 6| 7| 8| 9|10|11|12| + # local origin:| 0| 1| 8| 12| + # local mapped:| 0| 1| 2| 3| 4| 5| 6| 7| 8| 9|10|11|12|13|14|15| + block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(engine_id) + if engine_id not in self.dst_num_blocks: self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks @@ -1381,8 +1443,14 @@ class NixlConnectorWorker: # Register all remote blocks, but only the corresponding kv heads. for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr): kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=i) + remote_kv_block_len = kv_block_len // block_size_ratio + if block_size_ratio > 1: + # using remote kv_block_len as transfer unit + kv_block_len = remote_kv_block_len rank_offset = ( - self.tp_rank % tp_ratio * kv_block_len if not replicates_kv_cache else 0 + self.tp_rank % tp_ratio * remote_kv_block_len + if not replicates_kv_cache + else 0 ) for block_id in range(nixl_agent_meta.num_blocks): block_offset = block_id * nixl_agent_meta.block_lens[i] @@ -1417,6 +1485,13 @@ class NixlConnectorWorker: remote_agent_name, descs ) + if block_size_ratio > 1: + # when prefill with smaller block_size, we need to init a + # new handler with same block_len to match + self.src_xfer_side_handles[nixl_agent_meta.block_size] = ( + self.register_local_xfer_handler(nixl_agent_meta.block_size) + ) + return remote_agent_name def _validate_remote_agent_handshake( @@ -1433,6 +1508,9 @@ class NixlConnectorWorker: assert nixl_agent_meta.attn_backend_name == self.backend_name tp_ratio = self.kv_topo.tp_ratio_from_engine_id(remote_engine_id) + block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id( + remote_engine_id + ) assert tp_ratio > 0, "Decode TP cannot be smaller than prefill TP" assert not self._use_pallas or tp_ratio == 1, ( "TPU (pallas_v1) DOES NOT support heterogeneous TP yet." @@ -1463,33 +1541,26 @@ class NixlConnectorWorker: remote_block_len = nixl_agent_meta.block_lens[0] if self.use_mla or self.kv_topo.is_kv_replicated(remote_engine_id): # With replicated KV cache, only the number of blocks can differ. - assert self.block_len_per_layer == nixl_agent_meta.block_lens, ( - "KV cache sizes must match between P and D when replicated" - ) - remote_block_size = remote_block_len // (self.slot_size_per_layer[0]) + for i in range(len(self.block_len_per_layer)): + assert ( + self.block_len_per_layer[i] // block_size_ratio + == nixl_agent_meta.block_lens[i] + ), "KV cache sizes must match between P and D when replicated" else: # When MLA is not used, this is a list of the same block length for block_len in nixl_agent_meta.block_lens: assert block_len == remote_block_len, ( "All remote layers must have the same block size" ) - remote_block_size = remote_block_len // ( - self.slot_size_per_layer[0] * tp_ratio - ) - if self.kv_topo.is_kv_layout_blocks_first: - # With flashinfer, KV are sent in the same message. - remote_block_size //= 2 - assert remote_block_len == self.block_len_per_layer[0] * tp_ratio, ( + assert ( + remote_block_len + == (self.block_len_per_layer[0] * tp_ratio) // block_size_ratio + ), ( "Remote P worker KV layer cache must be of shape [2, N, " "local_kv_heads*tp_ratio, block_size, head_dim] and same dtype." ) - assert self.block_size == remote_block_size, ( - "Remote P worker with different page/block size is not supported " - f"{self.block_size=}, {remote_block_size=}" - ) - # TP workers have same #blocks. assert self.dst_num_blocks[remote_engine_id] == nixl_agent_meta.num_blocks @@ -1576,6 +1647,56 @@ class NixlConnectorWorker: ) cache.index_copy_(0, indices, permuted_blocks) + def blocksize_post_process(self, block_ids_per_ratio: dict[float, list[list[int]]]): + def _process_local_gt_remote(blocks_to_update, block_size_ratio): + n_kv_heads, block_size, head_size = blocks_to_update.shape[1:] + remote_block_size = block_size // block_size_ratio + n_blocks = block_size_ratio + # actual permute is to convert + # for local blocksize > remote blocksize + # ex: local blocksize = 16 tokens, remote blocksize = 4 tokens + # local block[0] = remote block[0, 1, 2, 3] + # remote is |h0-b0|h1-b0|h2-b0|h3-b0|h0-b1|h1-b1|h2-b1|h3-b1|... + # local is |h0-b0..................|h1-b0..................|... + # permute is to: + # 1. view => view remote as n_blocks * remote_shape(H,remoteN,D) + # 2. permute => (H, nblocks, remoteN, D) + # 3. flatten => (H, localN, D) + permuted_blocks = ( + blocks_to_update.reshape( + -1, n_blocks, n_kv_heads, remote_block_size, head_size + ) + .permute(0, 2, 1, 3, 4) + .flatten(2, 3) + ) + return permuted_blocks + + if len(self.device_kv_caches) == 0: + return + split_k_and_v = not ( + self.use_mla or self._use_pallas or self.kv_topo.is_kv_layout_blocks_first + ) + sample_cache = list(self.device_kv_caches.values())[0][0] + for block_size_ratio, block_ids_list in block_ids_per_ratio.items(): + assert block_size_ratio > 1, "Only nP < nD supported currently." + block_ids_list = [[item for sublist in block_ids_list for item in sublist]] + + for block_ids in block_ids_list: + indices = torch.tensor(block_ids, device=sample_cache.device) + + for _, cache_or_caches in self.device_kv_caches.items(): + cache_list = cache_or_caches if split_k_and_v else [cache_or_caches] + for cache in cache_list: + blocks_to_update = cache.index_select(0, indices) + # because kv_cache is always using original layout NHD as + # virtual shape while stride can be either HND / NHD at + # initialization. + # we need to firstly get physical view of the tensor + permuted_blocks = _process_local_gt_remote( + blocks_to_update.permute(0, 2, 1, 3), block_size_ratio + ).permute(0, 2, 1, 3) + cache.index_copy_(0, indices, permuted_blocks) + def get_finished(self) -> tuple[set[str], set[str]]: """ Get requests that are done sending or recving on this specific worker. @@ -1599,6 +1720,7 @@ class NixlConnectorWorker: ) block_ids_to_permute = [] + block_ids_for_blocksize_post_process = defaultdict(list) for req_id in done_recving: # clean up metadata for completed requests meta = self._recving_metadata.pop(req_id, None) @@ -1607,6 +1729,20 @@ class NixlConnectorWorker: self.sync_recved_kv_to_device(req_id, meta) if self.enable_permute_local_kv: block_ids_to_permute += meta.local_physical_block_ids + + # post processing for heteroblocksize + block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id( + meta.remote_engine_id + ) + if ( + not self.use_mla + and block_size_ratio > 1 + and self.kv_cache_layout == "HND" + ): + block_ids_for_blocksize_post_process[block_size_ratio].append( + meta.local_block_ids + ) + self.blocksize_post_process(block_ids_for_blocksize_post_process) if len(block_ids_to_permute) > 0: self.permute_device_kv(block_ids_to_permute) @@ -1781,6 +1917,24 @@ class NixlConnectorWorker: dst_engine_id: str, request_id: str, ): + block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(dst_engine_id) + if block_size_ratio > 1: + local_block_ids = self.get_mapped_blocks( + np.asarray(local_block_ids), block_size_ratio + ) + if len(local_block_ids) > len(remote_block_ids): + # NOTE: + # get_mapped_blocks will always expand block_ids for n times. + # ex: + # prefill block_ids with block_size as 4: + # [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + # Local decode block_ids with block_size as 16: [1, 2, 3] + # expland ecode block_ids with get_mapped_blocks from [1, 2, 3] to + # [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12] + # Then we clip local to align with prefill + # [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12] to + # [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + local_block_ids = local_block_ids[: len(remote_block_ids)] # NOTE(rob): having the staging blocks be on the READER side is # not going to work well (since we will have to call rearrange tensors). # after we detect the txn is complete (which means we cannot make the @@ -1823,7 +1977,10 @@ class NixlConnectorWorker: remote_block_ids = remote_block_ids[-num_local_blocks:] # Get side handles. - local_xfer_side_handle = self.src_xfer_side_handle + remote_block_size = self.kv_topo.remote_block_size[dst_engine_id] + local_xfer_side_handle = self.src_xfer_side_handles.get( + remote_block_size, self.src_xfer_side_handle + ) remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id] # NOTE (nicolo) With homogeneous TP, each TP worker loads KV from @@ -1833,13 +1990,17 @@ class NixlConnectorWorker: # Get descs ids. local_block_descs_ids: np.ndarray remote_block_descs_ids: np.ndarray + if not self.block_window_per_layer: # Default case: assume global attention remote_block_descs_ids = self._get_block_descs_ids( - dst_engine_id, remote_block_ids + dst_engine_id, + remote_block_ids, ) local_block_descs_ids = self._get_block_descs_ids( - self.engine_id, local_block_ids + self.engine_id, + local_block_ids, + block_size_ratio=block_size_ratio, ) else: # TODO(mgoin): remove this once we have hybrid memory allocator @@ -1860,10 +2021,15 @@ class NixlConnectorWorker: # Get descs ids for the layer. layer_local_desc_ids = self._get_block_descs_ids( - self.engine_id, layer_local_block_ids, layer_idx + dst_engine_id, + layer_local_block_ids, + layer_idx, ) layer_remote_desc_ids = self._get_block_descs_ids( - dst_engine_id, layer_remote_block_ids, layer_idx + self.engine_id, + layer_remote_block_ids, + layer_idx, + block_size_ratio=block_size_ratio, ) local_descs_list.append(layer_local_desc_ids) @@ -1905,8 +2071,31 @@ class NixlConnectorWorker: self.nixl_wrapper.release_xfer_handle(handle) self._failed_recv_reqs.add(request_id) + def get_mapped_blocks(self, block_ids, block_size_ratio): + """ + Calculates the new set of block IDs by mapping every element + in the (potentially sparse) input array. + Example: block_ids=[0, 2], block_size_ratio=2 + get_mapped_blocks 0 1 [2 3] 4 5 + # remote is |h0-b0|h1-b0||h0-b1|h1-b1||h0-b1|h1-b1|| + # local is |h0-b0......||h1-b0......||h2-b0........ + local_block_ids 0 [1] 2 + """ + if block_ids.size == 0: + return np.array([], dtype=np.int64) + + start_ids = block_ids * block_size_ratio + offsets = np.arange(block_size_ratio) + mapped_2d = start_ids[:, None] + offsets[None, :] + + return mapped_2d.flatten().astype(np.int64) + def _get_block_descs_ids( - self, engine_id: str, block_ids: list[int], layer_idx: int | None = None + self, + engine_id: str, + block_ids: list[int], + layer_idx: int | None = None, + block_size_ratio: float | None = None, ) -> np.ndarray: """ Get the descs ids for a set of block ids. @@ -1929,6 +2118,8 @@ class NixlConnectorWorker: region_ids = np.arange(layer_idx, layer_idx + 1) num_blocks = self.dst_num_blocks[engine_id] + if block_size_ratio is not None: + num_blocks = int(num_blocks * block_size_ratio) # Compute the desc ids for each block. region_ids = region_ids[:, None]