diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index c0c03efcdbf4..e6c83a0fc5bd 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -96,7 +96,8 @@ class NixlConnector(KVConnectorBase_V1): self.connector_worker: Optional[NixlConnectorWorker] = None elif role == KVConnectorRole.WORKER: self.connector_scheduler = None - self.connector_worker = NixlConnectorWorker(str(self.engine_id)) + self.connector_worker = NixlConnectorWorker( + vllm_config, str(self.engine_id)) ############################################################ # Scheduler Side Methods @@ -302,7 +303,7 @@ class NixlConnectorScheduler: class NixlConnectorWorker: """Implementation of Worker side methods""" - def __init__(self, engine_id: str): + def __init__(self, vllm_config: VllmConfig, engine_id: str): if NixlWrapper is None: logger.error("NIXL is not available") raise RuntimeError("NIXL is not available") @@ -329,6 +330,7 @@ class NixlConnectorWorker: # Number of NIXL regions. Currently one region per cache # (so 1 per layer for MLA, otherwise 2 per layer) self.num_regions = 0 + self.num_layers = 0 # nixl_prepped_dlist_handle (int). self.src_xfer_side_handle: int = 0 @@ -355,6 +357,14 @@ class NixlConnectorWorker: # Background thread for establishing new connections. self._nixl_handshake_listener_t: Optional[threading.Thread] = None + self.vllm_config = vllm_config + self.block_size = vllm_config.cache_config.block_size + + # TODO(mgoin): remove this once we have hybrid memory allocator + # Optimization for models with local attention (Llama 4) + # List of block window sizes for each layer for local attention + self.block_window_per_layer: list[Optional[int]] = [] + @staticmethod def _nixl_handshake_listener(metadata: NixlAgentMetadata, ready_event: threading.Event, rank: int): @@ -465,6 +475,27 @@ class NixlConnectorWorker: kv_caches_base_addr.append(base_addr) self.kv_caches_base_addr[self.engine_id] = kv_caches_base_addr self.num_regions = len(caches_data) + self.num_layers = len(self.kv_caches.keys()) + + # TODO(mgoin): remove this once we have hybrid memory allocator + # Optimization for models with local attention (Llama 4) + if self.vllm_config.model_config.hf_config.model_type == "llama4": + from transformers import Llama4TextConfig + assert isinstance(self.vllm_config.model_config.hf_text_config, + Llama4TextConfig) + llama4_config = self.vllm_config.model_config.hf_text_config + no_rope_layers = llama4_config.no_rope_layers + chunk_size = llama4_config.attention_chunk_size + chunk_block_size = math.ceil(chunk_size / self.block_size) + for layer_idx in range(self.num_layers): + # no_rope_layers[layer_idx] == 0 means NoPE (global) + # Any other value means RoPE (local chunked) + is_local_attention = no_rope_layers[layer_idx] != 0 + block_window = chunk_block_size if is_local_attention else None + self.block_window_per_layer.append(block_window) + logger.debug("Llama 4 block window per layer mapping: %s", + self.block_window_per_layer) + assert len(self.block_window_per_layer) == self.num_layers descs = self.nixl_wrapper.get_reg_descs(caches_data, "VRAM") logger.debug("Registering descs: %s", caches_data) @@ -699,10 +730,39 @@ class NixlConnectorWorker: remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id] # Get descs ids. - remote_block_descs_ids = self._get_block_descs_ids( - dst_engine_id, remote_block_ids) - local_block_descs_ids = self._get_block_descs_ids( - self.engine_id, local_block_ids) + local_block_descs_ids: list[int] = [] + remote_block_descs_ids: list[int] = [] + if not self.block_window_per_layer: + # Default case: assume global attention + remote_block_descs_ids = self._get_block_descs_ids( + dst_engine_id, remote_block_ids) + local_block_descs_ids = self._get_block_descs_ids( + self.engine_id, local_block_ids) + else: + # TODO(mgoin): remove this once we have hybrid memory allocator + # Optimization for models with local attention (Llama 4) + for layer_idx, block_window in enumerate( + self.block_window_per_layer): + # For each layer: + if block_window is None: + # If not chunked, we just use the + # full block lists (global attention) + layer_local_block_ids = local_block_ids + layer_remote_block_ids = remote_block_ids + else: + # If chunked, get the last block_window blocks + layer_local_block_ids = local_block_ids[-block_window:] + layer_remote_block_ids = remote_block_ids[-block_window:] + + # Get descs ids for the layer. + layer_local_desc_ids = self._get_block_descs_ids( + self.engine_id, layer_local_block_ids, layer_idx) + layer_remote_desc_ids = self._get_block_descs_ids( + dst_engine_id, layer_remote_block_ids, layer_idx) + + local_block_descs_ids.extend(layer_local_desc_ids) + remote_block_descs_ids.extend(layer_remote_desc_ids) + assert len(local_block_descs_ids) == len(remote_block_descs_ids) # Prepare transfer with Nixl. @@ -721,12 +781,31 @@ class NixlConnectorWorker: # Use handle to check completion in future step(). self._recving_transfers[request_id].append(handle) - def _get_block_descs_ids(self, engine_id: str, - block_ids: list[int]) -> list[int]: - """Get the descs ids for a set of block ids.""" + def _get_block_descs_ids(self, + engine_id: str, + block_ids: list[int], + layer_idx: Optional[int] = None) -> list[int]: + """ + Get the descs ids for a set of block ids. + If layer_idx is provided, we use the region_ids for the given layer. + Otherwise, we use all regions. + """ + + if layer_idx is None: + region_ids = range(self.num_regions) + else: + assert layer_idx < self.num_layers + if self.num_layers < self.num_regions: + # If we have more regions than layers, we assume that + # the regions are organized as [K0, V0, K1, V1, ...] + # and we select K_i and V_i + assert 2 * self.num_layers == self.num_regions + region_ids = range(2 * layer_idx, 2 * layer_idx + 2) + else: + # Otherwise, we assume we have MLA and select i-th layer + assert self.num_layers == self.num_regions + region_ids = range(layer_idx, layer_idx + 1) - # range(1) for MLA, range(2) otherwise. - region_ids = range(self.num_regions) num_blocks = self.dst_num_blocks[engine_id] # Compute the desc ids for each block.