mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-27 20:30:54 +08:00
[Feature][P/D]: Optimize NIXL Connector xfer Launch (#23887)
Signed-off-by: ycyaw66 <497410282@qq.com> Co-authored-by: ycyaw66 <497410282@qq.com>
This commit is contained in:
parent
a742322092
commit
6adaed42f4
@ -14,6 +14,7 @@ from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
import msgspec
|
||||
import numpy as np
|
||||
import torch
|
||||
import zmq
|
||||
|
||||
@ -1191,8 +1192,8 @@ class NixlConnectorWorker:
|
||||
# workers will issue xfers to parts of the P worker remote kv caches.
|
||||
|
||||
# Get descs ids.
|
||||
local_block_descs_ids: list[int] = []
|
||||
remote_block_descs_ids: list[int] = []
|
||||
local_block_descs_ids: np.ndarray
|
||||
remote_block_descs_ids: np.ndarray
|
||||
if not self.block_window_per_layer:
|
||||
# Default case: assume global attention
|
||||
remote_block_descs_ids = self._get_block_descs_ids(
|
||||
@ -1202,6 +1203,8 @@ class NixlConnectorWorker:
|
||||
else:
|
||||
# TODO(mgoin): remove this once we have hybrid memory allocator
|
||||
# Optimization for models with local attention (Llama 4)
|
||||
local_descs_list = []
|
||||
remote_descs_list = []
|
||||
for layer_idx, block_window in enumerate(
|
||||
self.block_window_per_layer):
|
||||
# For each layer:
|
||||
@ -1221,8 +1224,11 @@ class NixlConnectorWorker:
|
||||
layer_remote_desc_ids = self._get_block_descs_ids(
|
||||
dst_engine_id, layer_remote_block_ids, layer_idx)
|
||||
|
||||
local_block_descs_ids.extend(layer_local_desc_ids)
|
||||
remote_block_descs_ids.extend(layer_remote_desc_ids)
|
||||
local_descs_list.append(layer_local_desc_ids)
|
||||
remote_descs_list.append(layer_remote_desc_ids)
|
||||
|
||||
local_block_descs_ids = np.concatenate(local_descs_list)
|
||||
remote_block_descs_ids = np.concatenate(remote_descs_list)
|
||||
|
||||
assert len(local_block_descs_ids) == len(remote_block_descs_ids)
|
||||
|
||||
@ -1247,14 +1253,14 @@ class NixlConnectorWorker:
|
||||
def _get_block_descs_ids(self,
|
||||
engine_id: str,
|
||||
block_ids: list[int],
|
||||
layer_idx: Optional[int] = None) -> list[int]:
|
||||
layer_idx: Optional[int] = None) -> np.ndarray:
|
||||
"""
|
||||
Get the descs ids for a set of block ids.
|
||||
If layer_idx is provided, we use the region_ids for the given layer.
|
||||
Otherwise, we use all regions.
|
||||
"""
|
||||
if layer_idx is None:
|
||||
region_ids = range(self.num_regions)
|
||||
region_ids = np.arange(self.num_regions)
|
||||
else:
|
||||
assert layer_idx < self.num_layers
|
||||
if self.num_layers < self.num_regions:
|
||||
@ -1262,20 +1268,19 @@ class NixlConnectorWorker:
|
||||
# the regions are organized as [K0, V0, K1, V1, ...]
|
||||
# and we select K_i and V_i
|
||||
assert 2 * self.num_layers == self.num_regions
|
||||
region_ids = range(2 * layer_idx, 2 * layer_idx + 2)
|
||||
region_ids = np.arange(2 * layer_idx, 2 * layer_idx + 2)
|
||||
else:
|
||||
# Otherwise, we assume we have MLA and select i-th layer
|
||||
assert self.num_layers == self.num_regions
|
||||
region_ids = range(layer_idx, layer_idx + 1)
|
||||
region_ids = np.arange(layer_idx, layer_idx + 1)
|
||||
|
||||
num_blocks = self.dst_num_blocks[engine_id]
|
||||
|
||||
# Compute the desc ids for each block.
|
||||
descs_ids: list[int] = []
|
||||
for reg_id in region_ids:
|
||||
for block_id in block_ids:
|
||||
descs_ids.append(reg_id * num_blocks + block_id)
|
||||
return descs_ids
|
||||
region_ids = region_ids[:, None]
|
||||
block_ids = np.array(block_ids)[None, :]
|
||||
descs_ids = region_ids * num_blocks + block_ids
|
||||
return descs_ids.flatten()
|
||||
|
||||
def get_backend_aware_kv_block_len(self):
|
||||
"""
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user