[Feature][P/D]: Optimize NIXL Connector xfer Launch (#23887)

Signed-off-by: ycyaw66 <497410282@qq.com>
Co-authored-by: ycyaw66 <497410282@qq.com>
This commit is contained in:
WeiQing Chen 2025-09-04 03:14:30 +08:00 committed by GitHub
parent a742322092
commit 6adaed42f4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -14,6 +14,7 @@ from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional from typing import TYPE_CHECKING, Any, Optional
import msgspec import msgspec
import numpy as np
import torch import torch
import zmq import zmq
@@ -1191,8 +1192,8 @@ class NixlConnectorWorker:
# workers will issue xfers to parts of the P worker remote kv caches. # workers will issue xfers to parts of the P worker remote kv caches.
# Get descs ids. # Get descs ids.
local_block_descs_ids: list[int] = [] local_block_descs_ids: np.ndarray
remote_block_descs_ids: list[int] = [] remote_block_descs_ids: np.ndarray
if not self.block_window_per_layer: if not self.block_window_per_layer:
# Default case: assume global attention # Default case: assume global attention
remote_block_descs_ids = self._get_block_descs_ids( remote_block_descs_ids = self._get_block_descs_ids(
@@ -1202,6 +1203,8 @@ class NixlConnectorWorker:
else: else:
# TODO(mgoin): remove this once we have hybrid memory allocator # TODO(mgoin): remove this once we have hybrid memory allocator
# Optimization for models with local attention (Llama 4) # Optimization for models with local attention (Llama 4)
local_descs_list = []
remote_descs_list = []
for layer_idx, block_window in enumerate( for layer_idx, block_window in enumerate(
self.block_window_per_layer): self.block_window_per_layer):
# For each layer: # For each layer:
@@ -1221,8 +1224,11 @@ class NixlConnectorWorker:
layer_remote_desc_ids = self._get_block_descs_ids( layer_remote_desc_ids = self._get_block_descs_ids(
dst_engine_id, layer_remote_block_ids, layer_idx) dst_engine_id, layer_remote_block_ids, layer_idx)
local_block_descs_ids.extend(layer_local_desc_ids) local_descs_list.append(layer_local_desc_ids)
remote_block_descs_ids.extend(layer_remote_desc_ids) remote_descs_list.append(layer_remote_desc_ids)
local_block_descs_ids = np.concatenate(local_descs_list)
remote_block_descs_ids = np.concatenate(remote_descs_list)
assert len(local_block_descs_ids) == len(remote_block_descs_ids) assert len(local_block_descs_ids) == len(remote_block_descs_ids)
@@ -1247,14 +1253,14 @@ class NixlConnectorWorker:
def _get_block_descs_ids(self, def _get_block_descs_ids(self,
engine_id: str, engine_id: str,
block_ids: list[int], block_ids: list[int],
layer_idx: Optional[int] = None) -> list[int]: layer_idx: Optional[int] = None) -> np.ndarray:
""" """
Get the descs ids for a set of block ids. Get the descs ids for a set of block ids.
If layer_idx is provided, we use the region_ids for the given layer. If layer_idx is provided, we use the region_ids for the given layer.
Otherwise, we use all regions. Otherwise, we use all regions.
""" """
if layer_idx is None: if layer_idx is None:
region_ids = range(self.num_regions) region_ids = np.arange(self.num_regions)
else: else:
assert layer_idx < self.num_layers assert layer_idx < self.num_layers
if self.num_layers < self.num_regions: if self.num_layers < self.num_regions:
@@ -1262,20 +1268,19 @@ class NixlConnectorWorker:
# the regions are organized as [K0, V0, K1, V1, ...] # the regions are organized as [K0, V0, K1, V1, ...]
# and we select K_i and V_i # and we select K_i and V_i
assert 2 * self.num_layers == self.num_regions assert 2 * self.num_layers == self.num_regions
region_ids = range(2 * layer_idx, 2 * layer_idx + 2) region_ids = np.arange(2 * layer_idx, 2 * layer_idx + 2)
else: else:
# Otherwise, we assume we have MLA and select i-th layer # Otherwise, we assume we have MLA and select i-th layer
assert self.num_layers == self.num_regions assert self.num_layers == self.num_regions
region_ids = range(layer_idx, layer_idx + 1) region_ids = np.arange(layer_idx, layer_idx + 1)
num_blocks = self.dst_num_blocks[engine_id] num_blocks = self.dst_num_blocks[engine_id]
# Compute the desc ids for each block. # Compute the desc ids for each block.
descs_ids: list[int] = [] region_ids = region_ids[:, None]
for reg_id in region_ids: block_ids = np.array(block_ids)[None, :]
for block_id in block_ids: descs_ids = region_ids * num_blocks + block_ids
descs_ids.append(reg_id * num_blocks + block_id) return descs_ids.flatten()
return descs_ids
def get_backend_aware_kv_block_len(self): def get_backend_aware_kv_block_len(self):
""" """