diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 8f16babfe2aeb..de9cbc660666a 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -14,6 +14,7 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Optional import msgspec +import numpy as np import torch import zmq @@ -1191,8 +1192,8 @@ class NixlConnectorWorker: # workers will issue xfers to parts of the P worker remote kv caches. # Get descs ids. - local_block_descs_ids: list[int] = [] - remote_block_descs_ids: list[int] = [] + local_block_descs_ids: np.ndarray + remote_block_descs_ids: np.ndarray if not self.block_window_per_layer: # Default case: assume global attention remote_block_descs_ids = self._get_block_descs_ids( @@ -1202,6 +1203,8 @@ class NixlConnectorWorker: else: # TODO(mgoin): remove this once we have hybrid memory allocator # Optimization for models with local attention (Llama 4) + local_descs_list = [] + remote_descs_list = [] for layer_idx, block_window in enumerate( self.block_window_per_layer): # For each layer: @@ -1221,8 +1224,11 @@ class NixlConnectorWorker: layer_remote_desc_ids = self._get_block_descs_ids( dst_engine_id, layer_remote_block_ids, layer_idx) - local_block_descs_ids.extend(layer_local_desc_ids) - remote_block_descs_ids.extend(layer_remote_desc_ids) + local_descs_list.append(layer_local_desc_ids) + remote_descs_list.append(layer_remote_desc_ids) + + local_block_descs_ids = np.concatenate(local_descs_list) + remote_block_descs_ids = np.concatenate(remote_descs_list) assert len(local_block_descs_ids) == len(remote_block_descs_ids) @@ -1247,14 +1253,14 @@ class NixlConnectorWorker: def _get_block_descs_ids(self, engine_id: str, block_ids: list[int], - layer_idx: Optional[int] = None) -> list[int]: + layer_idx: Optional[int] = None) -> np.ndarray: """ Get the descs ids for a set of block ids. If layer_idx is provided, we use the region_ids for the given layer. Otherwise, we use all regions. """ if layer_idx is None: - region_ids = range(self.num_regions) + region_ids = np.arange(self.num_regions) else: assert layer_idx < self.num_layers if self.num_layers < self.num_regions: @@ -1262,20 +1268,19 @@ class NixlConnectorWorker: # the regions are organized as [K0, V0, K1, V1, ...] # and we select K_i and V_i assert 2 * self.num_layers == self.num_regions - region_ids = range(2 * layer_idx, 2 * layer_idx + 2) + region_ids = np.arange(2 * layer_idx, 2 * layer_idx + 2) else: # Otherwise, we assume we have MLA and select i-th layer assert self.num_layers == self.num_regions - region_ids = range(layer_idx, layer_idx + 1) + region_ids = np.arange(layer_idx, layer_idx + 1) num_blocks = self.dst_num_blocks[engine_id] # Compute the desc ids for each block. - descs_ids: list[int] = [] - for reg_id in region_ids: - for block_id in block_ids: - descs_ids.append(reg_id * num_blocks + block_id) - return descs_ids + region_ids = region_ids[:, None] + block_ids = np.array(block_ids)[None, :] + descs_ids = region_ids * num_blocks + block_ids + return descs_ids.flatten() def get_backend_aware_kv_block_len(self): """