diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py index b02d9a657407b..b95c8df3469b3 100644 --- a/tests/v1/worker/test_gpu_model_runner.py +++ b/tests/v1/worker/test_gpu_model_runner.py @@ -985,8 +985,10 @@ def test_hybrid_block_table_initialization(): req_index = 0 block_table.append_row(kvcache_manager_blocks, req_index) # Get expected kernel blocks from the implementation for verification. - expected_kernel_blocks = block_table._map_to_kernel_blocks( - np.array(kvcache_manager_blocks) + expected_kernel_blocks = block_table.map_to_kernel_blocks( + np.array(kvcache_manager_blocks), + block_table.blocks_per_kv_block, + block_table._kernel_block_arange, ) # Verify block table state assert block_table.num_blocks_per_row[req_index] == len(expected_kernel_blocks) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 42433c717cf26..3d4547c514532 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -49,6 +49,7 @@ from vllm.platforms import current_platform from vllm.utils.network_utils import make_zmq_path, make_zmq_socket from vllm.v1.attention.backends.utils import get_kv_cache_layout from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.worker.block_table import BlockTable if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionMetadata @@ -112,6 +113,8 @@ class NixlAgentMetadata(KVConnectorHandshakeMetadata): @dataclass class ReqMeta: local_block_ids: list[int] + # To be used when logical block size does not match the kernel block size + local_physical_block_ids: list[int] remote_block_ids: list[int] remote_host: str remote_port: int @@ -139,6 +142,7 @@ class NixlConnectorMetadata(KVConnectorMetadata): assert load_remote_cache ^ save_to_host _req = ReqMeta( local_block_ids=local_block_ids, + 
local_physical_block_ids=local_block_ids, remote_block_ids=kv_transfer_params["remote_block_ids"], remote_engine_id=kv_transfer_params["remote_engine_id"], remote_host=kv_transfer_params["remote_host"], @@ -935,6 +939,7 @@ class NixlConnectorWorker: attn_backend=backend, ) self._use_pallas = self.kv_topo._use_pallas + self._physical_blocks_per_logical_kv_block = 1 def _nixl_handshake( self, @@ -1133,6 +1138,22 @@ class NixlConnectorWorker: if base_addr in seen_base_addresses: continue + # TODO (NickLucche): Get kernel_block_size in a cleaner way + # NHD default "view" for non-MLA cache + kernel_block_size = cache.shape[-2] if self.use_mla else cache.shape[-3] + + if self.block_size != kernel_block_size: + logger.info_once( + "User-specified logical block size (%s) does not match" + " physical kernel block size (%s). Using the latter. ", + self.block_size, + kernel_block_size, + ) + self._physical_blocks_per_logical_kv_block = ( + self.block_size // kernel_block_size + ) + self.block_size = kernel_block_size + seen_base_addresses.append(base_addr) curr_tensor_size_bytes = cache.numel() * cache.element_size() @@ -1479,7 +1500,7 @@ class NixlConnectorWorker: assert self.use_host_buffer assert self.copy_blocks is not None - local_block_ids = meta.local_block_ids + local_block_ids = meta.local_physical_block_ids self.copy_blocks( self.host_xfer_buffers, self.device_kv_caches, @@ -1492,7 +1513,7 @@ class NixlConnectorWorker: "synced recved kv of request[%s] to device kv buffer," "local_block_ids: %s. 
", req_id, - ",".join(map(str, meta.local_block_ids)), + ",".join(map(str, local_block_ids)), ) def save_kv_to_host(self, metadata: NixlConnectorMetadata): @@ -1501,19 +1522,22 @@ class NixlConnectorWorker: assert self.copy_blocks is not None for req_id, meta in metadata.reqs_to_save.items(): + meta.local_physical_block_ids = self._logical_to_kernel_block_ids( + meta.local_block_ids + ) if logger.isEnabledFor(logging.DEBUG): logger.debug( "save_load_kv for request[%s] to host xfer buffer." "local_block_ids: %s. ", req_id, - ",".join(map(str, meta.local_block_ids)), + ",".join(map(str, meta.local_physical_block_ids)), ) # blocking self.copy_blocks( self.device_kv_caches, self.host_xfer_buffers, - meta.local_block_ids, - meta.local_block_ids, + meta.local_physical_block_ids, + meta.local_physical_block_ids, "d2h", ) @@ -1582,7 +1606,7 @@ class NixlConnectorWorker: if self.use_host_buffer: self.sync_recved_kv_to_device(req_id, meta) if self.enable_permute_local_kv: - block_ids_to_permute += meta.local_block_ids + block_ids_to_permute += meta.local_physical_block_ids if len(block_ids_to_permute) > 0: self.permute_device_kv(block_ids_to_permute) @@ -1669,7 +1693,7 @@ class NixlConnectorWorker: req_id, xfer_state, ) - # mark all blocks for this request as invalid + # mark all (logical)blocks for this request as invalid if meta := self._recving_metadata.pop(req_id, None): self._invalid_block_ids.update(meta.local_block_ids) self._recving_metadata.pop(req_id, None) @@ -1686,13 +1710,19 @@ class NixlConnectorWorker: We check for these trnxs to complete in each step(). """ for req_id, meta in metadata.reqs_to_recv.items(): + meta.local_physical_block_ids = self._logical_to_kernel_block_ids( + meta.local_block_ids + ) + meta.remote_block_ids = self._logical_to_kernel_block_ids( + meta.remote_block_ids + ) remote_engine_id = meta.remote_engine_id logger.debug( "start_load_kv for request %s from remote engine %s. " "Num local_block_ids: %s. Num remote_block_ids: %s. 
def _logical_to_kernel_block_ids(self, block_ids: list[int]) -> list[int]:
    """Translate logical (user-facing) block ids into kernel block ids.

    Required when the logical block size configured by the user differs
    from the block size the attention backend actually operates on; each
    logical block then expands into
    ``self._physical_blocks_per_logical_kv_block`` consecutive kernel
    blocks.

    Args:
        block_ids: logical KV-cache block ids.

    Returns:
        The corresponding kernel (physical) block ids; the input list is
        returned unchanged when the two block sizes coincide.
    """
    ratio = self._physical_blocks_per_logical_kv_block
    if ratio == 1:
        # Logical and physical block sizes match: identity mapping.
        return block_ids
    # Per-logical-block offsets, shaped (1, ratio) for broadcasting.
    offsets = np.arange(ratio).reshape(1, -1)
    kernel_ids = BlockTable.map_to_kernel_blocks(np.array(block_ids), ratio, offsets)
    return kernel_ids.tolist()
diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py index c28bf542f85c5..9f6c19e464308 100644 --- a/vllm/v1/worker/block_table.py +++ b/vllm/v1/worker/block_table.py @@ -98,7 +98,9 @@ class BlockTable: return if self.use_hybrid_blocks: - block_ids = self._map_to_kernel_blocks(np.array(block_ids)) + block_ids = self.map_to_kernel_blocks( + np.array(block_ids), self.blocks_per_kv_block, self._kernel_block_arange + ) num_blocks = len(block_ids) start = self.num_blocks_per_row[row_idx] @@ -188,7 +190,12 @@ class BlockTable: self.block_table.gpu.fill_(0) self.block_table.cpu.fill_(0) - def _map_to_kernel_blocks(self, kv_manager_block_ids: np.ndarray) -> np.ndarray: + @staticmethod + def map_to_kernel_blocks( + kv_manager_block_ids: np.ndarray, + blocks_per_kv_block: int, + kernel_block_arange: np.ndarray, + ) -> np.ndarray: """Convert kv_manager_block_id IDs to kernel block IDs. Example: @@ -203,12 +210,12 @@ class BlockTable: # kv_manager_block_id 1 → kernel block id [2, 3] # kv_manager_block_id 2 → kernel block id [4, 5] """ - if not self.use_hybrid_blocks: + if blocks_per_kv_block == 1: return kv_manager_block_ids kernel_block_ids = ( - kv_manager_block_ids.reshape(-1, 1) * self.blocks_per_kv_block - + self._kernel_block_arange + kv_manager_block_ids.reshape(-1, 1) * blocks_per_kv_block + + kernel_block_arange ) return kernel_block_ids.reshape(-1)