mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 09:55:46 +08:00
[V1][P/D] Local attention optimization for NIXL (#18170)
Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
parent
fabe89bbc4
commit
fd195b194e
@ -96,7 +96,8 @@ class NixlConnector(KVConnectorBase_V1):
|
||||
self.connector_worker: Optional[NixlConnectorWorker] = None
|
||||
elif role == KVConnectorRole.WORKER:
|
||||
self.connector_scheduler = None
|
||||
self.connector_worker = NixlConnectorWorker(str(self.engine_id))
|
||||
self.connector_worker = NixlConnectorWorker(
|
||||
vllm_config, str(self.engine_id))
|
||||
|
||||
############################################################
|
||||
# Scheduler Side Methods
|
||||
@ -302,7 +303,7 @@ class NixlConnectorScheduler:
|
||||
class NixlConnectorWorker:
|
||||
"""Implementation of Worker side methods"""
|
||||
|
||||
def __init__(self, engine_id: str):
|
||||
def __init__(self, vllm_config: VllmConfig, engine_id: str):
|
||||
if NixlWrapper is None:
|
||||
logger.error("NIXL is not available")
|
||||
raise RuntimeError("NIXL is not available")
|
||||
@ -329,6 +330,7 @@ class NixlConnectorWorker:
|
||||
# Number of NIXL regions. Currently one region per cache
|
||||
# (so 1 per layer for MLA, otherwise 2 per layer)
|
||||
self.num_regions = 0
|
||||
self.num_layers = 0
|
||||
|
||||
# nixl_prepped_dlist_handle (int).
|
||||
self.src_xfer_side_handle: int = 0
|
||||
@ -355,6 +357,14 @@ class NixlConnectorWorker:
|
||||
# Background thread for establishing new connections.
|
||||
self._nixl_handshake_listener_t: Optional[threading.Thread] = None
|
||||
|
||||
self.vllm_config = vllm_config
|
||||
self.block_size = vllm_config.cache_config.block_size
|
||||
|
||||
# TODO(mgoin): remove this once we have hybrid memory allocator
|
||||
# Optimization for models with local attention (Llama 4)
|
||||
# List of block window sizes for each layer for local attention
|
||||
self.block_window_per_layer: list[Optional[int]] = []
|
||||
|
||||
@staticmethod
|
||||
def _nixl_handshake_listener(metadata: NixlAgentMetadata,
|
||||
ready_event: threading.Event, rank: int):
|
||||
@ -465,6 +475,27 @@ class NixlConnectorWorker:
|
||||
kv_caches_base_addr.append(base_addr)
|
||||
self.kv_caches_base_addr[self.engine_id] = kv_caches_base_addr
|
||||
self.num_regions = len(caches_data)
|
||||
self.num_layers = len(self.kv_caches.keys())
|
||||
|
||||
# TODO(mgoin): remove this once we have hybrid memory allocator
|
||||
# Optimization for models with local attention (Llama 4)
|
||||
if self.vllm_config.model_config.hf_config.model_type == "llama4":
|
||||
from transformers import Llama4TextConfig
|
||||
assert isinstance(self.vllm_config.model_config.hf_text_config,
|
||||
Llama4TextConfig)
|
||||
llama4_config = self.vllm_config.model_config.hf_text_config
|
||||
no_rope_layers = llama4_config.no_rope_layers
|
||||
chunk_size = llama4_config.attention_chunk_size
|
||||
chunk_block_size = math.ceil(chunk_size / self.block_size)
|
||||
for layer_idx in range(self.num_layers):
|
||||
# no_rope_layers[layer_idx] == 0 means NoPE (global)
|
||||
# Any other value means RoPE (local chunked)
|
||||
is_local_attention = no_rope_layers[layer_idx] != 0
|
||||
block_window = chunk_block_size if is_local_attention else None
|
||||
self.block_window_per_layer.append(block_window)
|
||||
logger.debug("Llama 4 block window per layer mapping: %s",
|
||||
self.block_window_per_layer)
|
||||
assert len(self.block_window_per_layer) == self.num_layers
|
||||
|
||||
descs = self.nixl_wrapper.get_reg_descs(caches_data, "VRAM")
|
||||
logger.debug("Registering descs: %s", caches_data)
|
||||
@ -699,10 +730,39 @@ class NixlConnectorWorker:
|
||||
remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id]
|
||||
|
||||
# Get descs ids.
|
||||
remote_block_descs_ids = self._get_block_descs_ids(
|
||||
dst_engine_id, remote_block_ids)
|
||||
local_block_descs_ids = self._get_block_descs_ids(
|
||||
self.engine_id, local_block_ids)
|
||||
local_block_descs_ids: list[int] = []
|
||||
remote_block_descs_ids: list[int] = []
|
||||
if not self.block_window_per_layer:
|
||||
# Default case: assume global attention
|
||||
remote_block_descs_ids = self._get_block_descs_ids(
|
||||
dst_engine_id, remote_block_ids)
|
||||
local_block_descs_ids = self._get_block_descs_ids(
|
||||
self.engine_id, local_block_ids)
|
||||
else:
|
||||
# TODO(mgoin): remove this once we have hybrid memory allocator
|
||||
# Optimization for models with local attention (Llama 4)
|
||||
for layer_idx, block_window in enumerate(
|
||||
self.block_window_per_layer):
|
||||
# For each layer:
|
||||
if block_window is None:
|
||||
# If not chunked, we just use the
|
||||
# full block lists (global attention)
|
||||
layer_local_block_ids = local_block_ids
|
||||
layer_remote_block_ids = remote_block_ids
|
||||
else:
|
||||
# If chunked, get the last block_window blocks
|
||||
layer_local_block_ids = local_block_ids[-block_window:]
|
||||
layer_remote_block_ids = remote_block_ids[-block_window:]
|
||||
|
||||
# Get descs ids for the layer.
|
||||
layer_local_desc_ids = self._get_block_descs_ids(
|
||||
self.engine_id, layer_local_block_ids, layer_idx)
|
||||
layer_remote_desc_ids = self._get_block_descs_ids(
|
||||
dst_engine_id, layer_remote_block_ids, layer_idx)
|
||||
|
||||
local_block_descs_ids.extend(layer_local_desc_ids)
|
||||
remote_block_descs_ids.extend(layer_remote_desc_ids)
|
||||
|
||||
assert len(local_block_descs_ids) == len(remote_block_descs_ids)
|
||||
|
||||
# Prepare transfer with Nixl.
|
||||
@ -721,12 +781,31 @@ class NixlConnectorWorker:
|
||||
# Use handle to check completion in future step().
|
||||
self._recving_transfers[request_id].append(handle)
|
||||
|
||||
def _get_block_descs_ids(self, engine_id: str,
|
||||
block_ids: list[int]) -> list[int]:
|
||||
"""Get the descs ids for a set of block ids."""
|
||||
def _get_block_descs_ids(self,
|
||||
engine_id: str,
|
||||
block_ids: list[int],
|
||||
layer_idx: Optional[int] = None) -> list[int]:
|
||||
"""
|
||||
Get the descs ids for a set of block ids.
|
||||
If layer_idx is provided, we use the region_ids for the given layer.
|
||||
Otherwise, we use all regions.
|
||||
"""
|
||||
|
||||
if layer_idx is None:
|
||||
region_ids = range(self.num_regions)
|
||||
else:
|
||||
assert layer_idx < self.num_layers
|
||||
if self.num_layers < self.num_regions:
|
||||
# If we have more regions than layers, we assume that
|
||||
# the regions are organized as [K0, V0, K1, V1, ...]
|
||||
# and we select K_i and V_i
|
||||
assert 2 * self.num_layers == self.num_regions
|
||||
region_ids = range(2 * layer_idx, 2 * layer_idx + 2)
|
||||
else:
|
||||
# Otherwise, we assume we have MLA and select i-th layer
|
||||
assert self.num_layers == self.num_regions
|
||||
region_ids = range(layer_idx, layer_idx + 1)
|
||||
|
||||
# range(1) for MLA, range(2) otherwise.
|
||||
region_ids = range(self.num_regions)
|
||||
num_blocks = self.dst_num_blocks[engine_id]
|
||||
|
||||
# Compute the desc ids for each block.
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user