[Feature][P/D]: Optimize NIXL Connector xfer Launch (#23887)

Signed-off-by: ycyaw66 <497410282@qq.com>
Co-authored-by: ycyaw66 <497410282@qq.com>
WeiQing Chen committed 2025-09-04 03:14:30 +08:00
parent a742322092
commit 6adaed42f4

@@ -14,6 +14,7 @@ from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Optional
 
 import msgspec
+import numpy as np
 import torch
 import zmq
@@ -1191,8 +1192,8 @@ class NixlConnectorWorker:
         # workers will issue xfers to parts of the P worker remote kv caches.
 
         # Get descs ids.
-        local_block_descs_ids: list[int] = []
-        remote_block_descs_ids: list[int] = []
+        local_block_descs_ids: np.ndarray
+        remote_block_descs_ids: np.ndarray
         if not self.block_window_per_layer:
             # Default case: assume global attention
             remote_block_descs_ids = self._get_block_descs_ids(
@@ -1202,6 +1203,8 @@ class NixlConnectorWorker:
         else:
             # TODO(mgoin): remove this once we have hybrid memory allocator
             # Optimization for models with local attention (Llama 4)
+            local_descs_list = []
+            remote_descs_list = []
             for layer_idx, block_window in enumerate(
                     self.block_window_per_layer):
                 # For each layer:
@@ -1221,8 +1224,11 @@ class NixlConnectorWorker:
                 layer_remote_desc_ids = self._get_block_descs_ids(
                     dst_engine_id, layer_remote_block_ids, layer_idx)
-                local_block_descs_ids.extend(layer_local_desc_ids)
-                remote_block_descs_ids.extend(layer_remote_desc_ids)
+                local_descs_list.append(layer_local_desc_ids)
+                remote_descs_list.append(layer_remote_desc_ids)
+            local_block_descs_ids = np.concatenate(local_descs_list)
+            remote_block_descs_ids = np.concatenate(remote_descs_list)
 
         assert len(local_block_descs_ids) == len(remote_block_descs_ids)
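
For clarity, a minimal standalone sketch (not the connector code itself) of the accumulation pattern this hunk adopts: collect one per-layer array, then concatenate once after the loop, instead of extending a Python list on every iteration.

import numpy as np

def gather_desc_ids(per_layer_ids: list[np.ndarray]) -> np.ndarray:
    # A single concatenate at the end replaces repeated list.extend()
    # calls inside the per-layer loop.
    return np.concatenate(per_layer_ids)

assert gather_desc_ids([np.array([0, 1, 2]),
                        np.array([7, 8])]).tolist() == [0, 1, 2, 7, 8]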
@@ -1247,14 +1253,14 @@ class NixlConnectorWorker:
     def _get_block_descs_ids(self,
                              engine_id: str,
                              block_ids: list[int],
-                             layer_idx: Optional[int] = None) -> list[int]:
+                             layer_idx: Optional[int] = None) -> np.ndarray:
         """
         Get the descs ids for a set of block ids.
         If layer_idx is provided, we use the region_ids for the given layer.
         Otherwise, we use all regions.
         """
         if layer_idx is None:
-            region_ids = range(self.num_regions)
+            region_ids = np.arange(self.num_regions)
         else:
             assert layer_idx < self.num_layers
             if self.num_layers < self.num_regions:
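
For context on the branch the next hunk touches, here is a hedged sketch of how region ids map to a layer under the two layouts the code distinguishes; the helper name is illustrative, not part of the connector.

import numpy as np

def region_ids_for_layer(layer_idx: int, num_layers: int,
                         num_regions: int) -> np.ndarray:
    # Illustrative sketch, not the vLLM method itself.
    if num_regions == 2 * num_layers:
        # Regions laid out as [K0, V0, K1, V1, ...]: pick K_i and V_i.
        return np.arange(2 * layer_idx, 2 * layer_idx + 2)
    assert num_regions == num_layers  # MLA: one region per layer
    return np.arange(layer_idx, layer_idx + 1)

assert region_ids_for_layer(1, 4, 8).tolist() == [2, 3]
assert region_ids_for_layer(1, 4, 4).tolist() == [1]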
@@ -1262,20 +1268,19 @@ class NixlConnectorWorker:
                 # the regions are organized as [K0, V0, K1, V1, ...]
                 # and we select K_i and V_i
                 assert 2 * self.num_layers == self.num_regions
-                region_ids = range(2 * layer_idx, 2 * layer_idx + 2)
+                region_ids = np.arange(2 * layer_idx, 2 * layer_idx + 2)
             else:
                 # Otherwise, we assume we have MLA and select i-th layer
                 assert self.num_layers == self.num_regions
-                region_ids = range(layer_idx, layer_idx + 1)
+                region_ids = np.arange(layer_idx, layer_idx + 1)
 
         num_blocks = self.dst_num_blocks[engine_id]
 
         # Compute the desc ids for each block.
-        descs_ids: list[int] = []
-        for reg_id in region_ids:
-            for block_id in block_ids:
-                descs_ids.append(reg_id * num_blocks + block_id)
-        return descs_ids
+        region_ids = region_ids[:, None]
+        block_ids = np.array(block_ids)[None, :]
+        descs_ids = region_ids * num_blocks + block_ids
+        return descs_ids.flatten()
 
     def get_backend_aware_kv_block_len(self):
         """