[V1][P/D] Local attention optimization for NIXL (#18170)

Signed-off-by: mgoin <mgoin64@gmail.com>
Author: Michael Goin
Date: 2025-05-16 21:16:33 -04:00 (committed by GitHub)
parent fabe89bbc4
commit fd195b194e

@@ -96,7 +96,8 @@ class NixlConnector(KVConnectorBase_V1):
             self.connector_worker: Optional[NixlConnectorWorker] = None
         elif role == KVConnectorRole.WORKER:
             self.connector_scheduler = None
-            self.connector_worker = NixlConnectorWorker(str(self.engine_id))
+            self.connector_worker = NixlConnectorWorker(
+                vllm_config, str(self.engine_id))

     ############################################################
     # Scheduler Side Methods
@@ -302,7 +303,7 @@ class NixlConnectorScheduler:
 class NixlConnectorWorker:
     """Implementation of Worker side methods"""

-    def __init__(self, engine_id: str):
+    def __init__(self, vllm_config: VllmConfig, engine_id: str):
         if NixlWrapper is None:
             logger.error("NIXL is not available")
             raise RuntimeError("NIXL is not available")
@@ -329,6 +330,7 @@ class NixlConnectorWorker:
         # Number of NIXL regions. Currently one region per cache
         # (so 1 per layer for MLA, otherwise 2 per layer)
         self.num_regions = 0
+        self.num_layers = 0

         # nixl_prepped_dlist_handle (int).
         self.src_xfer_side_handle: int = 0
@@ -355,6 +357,14 @@ class NixlConnectorWorker:
         # Background thread for establishing new connections.
         self._nixl_handshake_listener_t: Optional[threading.Thread] = None

+        self.vllm_config = vllm_config
+        self.block_size = vllm_config.cache_config.block_size
+
+        # TODO(mgoin): remove this once we have hybrid memory allocator
+        # Optimization for models with local attention (Llama 4)
+        # List of block window sizes for each layer for local attention
+        self.block_window_per_layer: list[Optional[int]] = []
+
     @staticmethod
     def _nixl_handshake_listener(metadata: NixlAgentMetadata,
                                  ready_event: threading.Event, rank: int):
@@ -465,6 +475,27 @@ class NixlConnectorWorker:
                 kv_caches_base_addr.append(base_addr)
         self.kv_caches_base_addr[self.engine_id] = kv_caches_base_addr
         self.num_regions = len(caches_data)
+        self.num_layers = len(self.kv_caches.keys())
+
+        # TODO(mgoin): remove this once we have hybrid memory allocator
+        # Optimization for models with local attention (Llama 4)
+        if self.vllm_config.model_config.hf_config.model_type == "llama4":
+            from transformers import Llama4TextConfig
+            assert isinstance(self.vllm_config.model_config.hf_text_config,
+                              Llama4TextConfig)
+            llama4_config = self.vllm_config.model_config.hf_text_config
+            no_rope_layers = llama4_config.no_rope_layers
+            chunk_size = llama4_config.attention_chunk_size
+            chunk_block_size = math.ceil(chunk_size / self.block_size)
+            for layer_idx in range(self.num_layers):
+                # no_rope_layers[layer_idx] == 0 means NoPE (global)
+                # Any other value means RoPE (local chunked)
+                is_local_attention = no_rope_layers[layer_idx] != 0
+                block_window = chunk_block_size if is_local_attention else None
+                self.block_window_per_layer.append(block_window)
+            logger.debug("Llama 4 block window per layer mapping: %s",
+                         self.block_window_per_layer)
+            assert len(self.block_window_per_layer) == self.num_layers

         descs = self.nixl_wrapper.get_reg_descs(caches_data, "VRAM")
         logger.debug("Registering descs: %s", caches_data)
@@ -699,10 +730,39 @@ class NixlConnectorWorker:
         remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id]

         # Get descs ids.
-        remote_block_descs_ids = self._get_block_descs_ids(
-            dst_engine_id, remote_block_ids)
-        local_block_descs_ids = self._get_block_descs_ids(
-            self.engine_id, local_block_ids)
+        local_block_descs_ids: list[int] = []
+        remote_block_descs_ids: list[int] = []
+
+        if not self.block_window_per_layer:
+            # Default case: assume global attention
+            remote_block_descs_ids = self._get_block_descs_ids(
+                dst_engine_id, remote_block_ids)
+            local_block_descs_ids = self._get_block_descs_ids(
+                self.engine_id, local_block_ids)
+        else:
+            # TODO(mgoin): remove this once we have hybrid memory allocator
+            # Optimization for models with local attention (Llama 4)
+            for layer_idx, block_window in enumerate(
+                    self.block_window_per_layer):
+                # For each layer:
+                if block_window is None:
+                    # If not chunked, we just use the
+                    # full block lists (global attention)
+                    layer_local_block_ids = local_block_ids
+                    layer_remote_block_ids = remote_block_ids
+                else:
+                    # If chunked, get the last block_window blocks
+                    layer_local_block_ids = local_block_ids[-block_window:]
+                    layer_remote_block_ids = remote_block_ids[-block_window:]
+
+                # Get descs ids for the layer.
+                layer_local_desc_ids = self._get_block_descs_ids(
+                    self.engine_id, layer_local_block_ids, layer_idx)
+                layer_remote_desc_ids = self._get_block_descs_ids(
+                    dst_engine_id, layer_remote_block_ids, layer_idx)
+
+                local_block_descs_ids.extend(layer_local_desc_ids)
+                remote_block_descs_ids.extend(layer_remote_desc_ids)
+
         assert len(local_block_descs_ids) == len(remote_block_descs_ids)

         # Prepare transfer with Nixl.
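
A toy sketch of the per-layer slicing this hunk performs, with made-up block ids and window sizes. For a local-attention layer only the last `block_window` blocks can still be read by future tokens, so only those need to be transferred; global layers keep the full block lists:

```python
local_block_ids = [10, 11, 12, 13, 14, 15]   # blocks on the decode side (made up)
remote_block_ids = [3, 7, 2, 9, 4, 8]        # matching blocks on the prefill side

block_window_per_layer = [None, 2, 2, None]  # None = global layer, 2 = local window

for layer_idx, block_window in enumerate(block_window_per_layer):
    if block_window is None:
        layer_local, layer_remote = local_block_ids, remote_block_ids
    else:
        layer_local = local_block_ids[-block_window:]
        layer_remote = remote_block_ids[-block_window:]
    print(layer_idx, layer_local, layer_remote)
# 0 [10, 11, 12, 13, 14, 15] [3, 7, 2, 9, 4, 8]
# 1 [14, 15] [4, 8]
# 2 [14, 15] [4, 8]
# 3 [10, 11, 12, 13, 14, 15] [3, 7, 2, 9, 4, 8]
```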
@@ -721,12 +781,31 @@ class NixlConnectorWorker:
         # Use handle to check completion in future step().
         self._recving_transfers[request_id].append(handle)

-    def _get_block_descs_ids(self, engine_id: str,
-                             block_ids: list[int]) -> list[int]:
-        """Get the descs ids for a set of block ids."""
+    def _get_block_descs_ids(self,
+                             engine_id: str,
+                             block_ids: list[int],
+                             layer_idx: Optional[int] = None) -> list[int]:
+        """
+        Get the descs ids for a set of block ids.
+        If layer_idx is provided, we use the region_ids for the given layer.
+        Otherwise, we use all regions.
+        """
+        if layer_idx is None:
+            region_ids = range(self.num_regions)
+        else:
+            assert layer_idx < self.num_layers
+            if self.num_layers < self.num_regions:
+                # If we have more regions than layers, we assume that
+                # the regions are organized as [K0, V0, K1, V1, ...]
+                # and we select K_i and V_i
+                assert 2 * self.num_layers == self.num_regions
+                region_ids = range(2 * layer_idx, 2 * layer_idx + 2)
+            else:
+                # Otherwise, we assume we have MLA and select i-th layer
+                assert self.num_layers == self.num_regions
+                region_ids = range(layer_idx, layer_idx + 1)

-        # range(1) for MLA, range(2) otherwise.
-        region_ids = range(self.num_regions)
         num_blocks = self.dst_num_blocks[engine_id]

         # Compute the desc ids for each block.
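
A self-contained sketch of the region selection added above. The region-major descriptor layout (`desc_id = region_id * num_blocks + block_id`) is an assumption for illustration only; the diff ends before the actual desc-id computation:

```python
from typing import Optional

def block_descs_ids(num_regions: int, num_layers: int, num_blocks: int,
                    block_ids: list[int],
                    layer_idx: Optional[int] = None) -> list[int]:
    if layer_idx is None:
        # All regions (every layer's K and V cache).
        region_ids = range(num_regions)
    elif num_layers < num_regions:
        # Regions interleaved as [K0, V0, K1, V1, ...]: pick K_i and V_i.
        region_ids = range(2 * layer_idx, 2 * layer_idx + 2)
    else:
        # MLA: one region per layer, pick the i-th one.
        region_ids = range(layer_idx, layer_idx + 1)
    # Assumed region-major flattening of descriptor ids.
    return [reg_id * num_blocks + blk_id
            for reg_id in region_ids for blk_id in block_ids]

# 2 layers with separate K and V caches (4 regions), 100 blocks per region.
print(block_descs_ids(4, 2, 100, [5, 6]))               # all regions
print(block_descs_ids(4, 2, 100, [5, 6], layer_idx=1))  # K1 and V1 only
# [5, 6, 105, 106, 205, 206, 305, 306]
# [205, 206, 305, 306]
```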