mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-13 23:15:39 +08:00
[V1][P/D] Local attention optimization for NIXL (#18170)
Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
parent
fabe89bbc4
commit
fd195b194e
@ -96,7 +96,8 @@ class NixlConnector(KVConnectorBase_V1):
|
|||||||
self.connector_worker: Optional[NixlConnectorWorker] = None
|
self.connector_worker: Optional[NixlConnectorWorker] = None
|
||||||
elif role == KVConnectorRole.WORKER:
|
elif role == KVConnectorRole.WORKER:
|
||||||
self.connector_scheduler = None
|
self.connector_scheduler = None
|
||||||
self.connector_worker = NixlConnectorWorker(str(self.engine_id))
|
self.connector_worker = NixlConnectorWorker(
|
||||||
|
vllm_config, str(self.engine_id))
|
||||||
|
|
||||||
############################################################
|
############################################################
|
||||||
# Scheduler Side Methods
|
# Scheduler Side Methods
|
||||||
@ -302,7 +303,7 @@ class NixlConnectorScheduler:
|
|||||||
class NixlConnectorWorker:
|
class NixlConnectorWorker:
|
||||||
"""Implementation of Worker side methods"""
|
"""Implementation of Worker side methods"""
|
||||||
|
|
||||||
def __init__(self, engine_id: str):
|
def __init__(self, vllm_config: VllmConfig, engine_id: str):
|
||||||
if NixlWrapper is None:
|
if NixlWrapper is None:
|
||||||
logger.error("NIXL is not available")
|
logger.error("NIXL is not available")
|
||||||
raise RuntimeError("NIXL is not available")
|
raise RuntimeError("NIXL is not available")
|
||||||
@ -329,6 +330,7 @@ class NixlConnectorWorker:
|
|||||||
# Number of NIXL regions. Currently one region per cache
|
# Number of NIXL regions. Currently one region per cache
|
||||||
# (so 1 per layer for MLA, otherwise 2 per layer)
|
# (so 1 per layer for MLA, otherwise 2 per layer)
|
||||||
self.num_regions = 0
|
self.num_regions = 0
|
||||||
|
self.num_layers = 0
|
||||||
|
|
||||||
# nixl_prepped_dlist_handle (int).
|
# nixl_prepped_dlist_handle (int).
|
||||||
self.src_xfer_side_handle: int = 0
|
self.src_xfer_side_handle: int = 0
|
||||||
@ -355,6 +357,14 @@ class NixlConnectorWorker:
|
|||||||
# Background thread for establishing new connections.
|
# Background thread for establishing new connections.
|
||||||
self._nixl_handshake_listener_t: Optional[threading.Thread] = None
|
self._nixl_handshake_listener_t: Optional[threading.Thread] = None
|
||||||
|
|
||||||
|
self.vllm_config = vllm_config
|
||||||
|
self.block_size = vllm_config.cache_config.block_size
|
||||||
|
|
||||||
|
# TODO(mgoin): remove this once we have hybrid memory allocator
|
||||||
|
# Optimization for models with local attention (Llama 4)
|
||||||
|
# List of block window sizes for each layer for local attention
|
||||||
|
self.block_window_per_layer: list[Optional[int]] = []
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _nixl_handshake_listener(metadata: NixlAgentMetadata,
|
def _nixl_handshake_listener(metadata: NixlAgentMetadata,
|
||||||
ready_event: threading.Event, rank: int):
|
ready_event: threading.Event, rank: int):
|
||||||
@ -465,6 +475,27 @@ class NixlConnectorWorker:
|
|||||||
kv_caches_base_addr.append(base_addr)
|
kv_caches_base_addr.append(base_addr)
|
||||||
self.kv_caches_base_addr[self.engine_id] = kv_caches_base_addr
|
self.kv_caches_base_addr[self.engine_id] = kv_caches_base_addr
|
||||||
self.num_regions = len(caches_data)
|
self.num_regions = len(caches_data)
|
||||||
|
self.num_layers = len(self.kv_caches.keys())
|
||||||
|
|
||||||
|
# TODO(mgoin): remove this once we have hybrid memory allocator
|
||||||
|
# Optimization for models with local attention (Llama 4)
|
||||||
|
if self.vllm_config.model_config.hf_config.model_type == "llama4":
|
||||||
|
from transformers import Llama4TextConfig
|
||||||
|
assert isinstance(self.vllm_config.model_config.hf_text_config,
|
||||||
|
Llama4TextConfig)
|
||||||
|
llama4_config = self.vllm_config.model_config.hf_text_config
|
||||||
|
no_rope_layers = llama4_config.no_rope_layers
|
||||||
|
chunk_size = llama4_config.attention_chunk_size
|
||||||
|
chunk_block_size = math.ceil(chunk_size / self.block_size)
|
||||||
|
for layer_idx in range(self.num_layers):
|
||||||
|
# no_rope_layers[layer_idx] == 0 means NoPE (global)
|
||||||
|
# Any other value means RoPE (local chunked)
|
||||||
|
is_local_attention = no_rope_layers[layer_idx] != 0
|
||||||
|
block_window = chunk_block_size if is_local_attention else None
|
||||||
|
self.block_window_per_layer.append(block_window)
|
||||||
|
logger.debug("Llama 4 block window per layer mapping: %s",
|
||||||
|
self.block_window_per_layer)
|
||||||
|
assert len(self.block_window_per_layer) == self.num_layers
|
||||||
|
|
||||||
descs = self.nixl_wrapper.get_reg_descs(caches_data, "VRAM")
|
descs = self.nixl_wrapper.get_reg_descs(caches_data, "VRAM")
|
||||||
logger.debug("Registering descs: %s", caches_data)
|
logger.debug("Registering descs: %s", caches_data)
|
||||||
@ -699,10 +730,39 @@ class NixlConnectorWorker:
|
|||||||
remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id]
|
remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id]
|
||||||
|
|
||||||
# Get descs ids.
|
# Get descs ids.
|
||||||
|
local_block_descs_ids: list[int] = []
|
||||||
|
remote_block_descs_ids: list[int] = []
|
||||||
|
if not self.block_window_per_layer:
|
||||||
|
# Default case: assume global attention
|
||||||
remote_block_descs_ids = self._get_block_descs_ids(
|
remote_block_descs_ids = self._get_block_descs_ids(
|
||||||
dst_engine_id, remote_block_ids)
|
dst_engine_id, remote_block_ids)
|
||||||
local_block_descs_ids = self._get_block_descs_ids(
|
local_block_descs_ids = self._get_block_descs_ids(
|
||||||
self.engine_id, local_block_ids)
|
self.engine_id, local_block_ids)
|
||||||
|
else:
|
||||||
|
# TODO(mgoin): remove this once we have hybrid memory allocator
|
||||||
|
# Optimization for models with local attention (Llama 4)
|
||||||
|
for layer_idx, block_window in enumerate(
|
||||||
|
self.block_window_per_layer):
|
||||||
|
# For each layer:
|
||||||
|
if block_window is None:
|
||||||
|
# If not chunked, we just use the
|
||||||
|
# full block lists (global attention)
|
||||||
|
layer_local_block_ids = local_block_ids
|
||||||
|
layer_remote_block_ids = remote_block_ids
|
||||||
|
else:
|
||||||
|
# If chunked, get the last block_window blocks
|
||||||
|
layer_local_block_ids = local_block_ids[-block_window:]
|
||||||
|
layer_remote_block_ids = remote_block_ids[-block_window:]
|
||||||
|
|
||||||
|
# Get descs ids for the layer.
|
||||||
|
layer_local_desc_ids = self._get_block_descs_ids(
|
||||||
|
self.engine_id, layer_local_block_ids, layer_idx)
|
||||||
|
layer_remote_desc_ids = self._get_block_descs_ids(
|
||||||
|
dst_engine_id, layer_remote_block_ids, layer_idx)
|
||||||
|
|
||||||
|
local_block_descs_ids.extend(layer_local_desc_ids)
|
||||||
|
remote_block_descs_ids.extend(layer_remote_desc_ids)
|
||||||
|
|
||||||
assert len(local_block_descs_ids) == len(remote_block_descs_ids)
|
assert len(local_block_descs_ids) == len(remote_block_descs_ids)
|
||||||
|
|
||||||
# Prepare transfer with Nixl.
|
# Prepare transfer with Nixl.
|
||||||
@ -721,12 +781,31 @@ class NixlConnectorWorker:
|
|||||||
# Use handle to check completion in future step().
|
# Use handle to check completion in future step().
|
||||||
self._recving_transfers[request_id].append(handle)
|
self._recving_transfers[request_id].append(handle)
|
||||||
|
|
||||||
def _get_block_descs_ids(self, engine_id: str,
|
def _get_block_descs_ids(self,
|
||||||
block_ids: list[int]) -> list[int]:
|
engine_id: str,
|
||||||
"""Get the descs ids for a set of block ids."""
|
block_ids: list[int],
|
||||||
|
layer_idx: Optional[int] = None) -> list[int]:
|
||||||
|
"""
|
||||||
|
Get the descs ids for a set of block ids.
|
||||||
|
If layer_idx is provided, we use the region_ids for the given layer.
|
||||||
|
Otherwise, we use all regions.
|
||||||
|
"""
|
||||||
|
|
||||||
# range(1) for MLA, range(2) otherwise.
|
if layer_idx is None:
|
||||||
region_ids = range(self.num_regions)
|
region_ids = range(self.num_regions)
|
||||||
|
else:
|
||||||
|
assert layer_idx < self.num_layers
|
||||||
|
if self.num_layers < self.num_regions:
|
||||||
|
# If we have more regions than layers, we assume that
|
||||||
|
# the regions are organized as [K0, V0, K1, V1, ...]
|
||||||
|
# and we select K_i and V_i
|
||||||
|
assert 2 * self.num_layers == self.num_regions
|
||||||
|
region_ids = range(2 * layer_idx, 2 * layer_idx + 2)
|
||||||
|
else:
|
||||||
|
# Otherwise, we assume we have MLA and select i-th layer
|
||||||
|
assert self.num_layers == self.num_regions
|
||||||
|
region_ids = range(layer_idx, layer_idx + 1)
|
||||||
|
|
||||||
num_blocks = self.dst_num_blocks[engine_id]
|
num_blocks = self.dst_num_blocks[engine_id]
|
||||||
|
|
||||||
# Compute the desc ids for each block.
|
# Compute the desc ids for each block.
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user